From 8c5284c13def703948386b4628e2e6a40ba44b2c Mon Sep 17 00:00:00 2001 From: Yuefeng Zhou Date: Wed, 3 May 2017 18:24:43 -0800 Subject: [PATCH 01/43] track variable's persistent momery in the variable op; remove tracking in allocate_persistent calls since we are going to track persistent memory of resources in resource ops. We will soon use resource variables instead of variables, we will need to track that as well. Change: 155036823 --- tensorflow/core/framework/op_kernel.cc | 16 ---------------- tensorflow/core/kernels/variable_ops.h | 12 ++++++++++++ 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index 3d913cdaf0c..6fad379b760 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -656,22 +656,6 @@ Status OpKernelContext::allocate_persistent(DataType type, *out_tensor = out_persistent->AccessTensor(this); } } - if (track_allocations() && persistent.TotalBytes() > 0) { - // TODO(yuefengz): some allocators allocate memory even if the requested - // size is 0. - Allocator* a = get_allocator(attr); - if (a->TracksAllocationSizes()) { - int64 alloc_size = - a->AllocatedSize(const_cast(persistent.tensor_data().data())); - int64 alloc_id = - a->AllocationId(const_cast(persistent.tensor_data().data())); - if (allocate_on_host(attr)) { - record_host_persistent_memory_allocation(alloc_size, alloc_id); - } else { - record_device_persistent_memory_allocation(alloc_size, alloc_id); - } - } - } return s; } diff --git a/tensorflow/core/kernels/variable_ops.h b/tensorflow/core/kernels/variable_ops.h index 8c173a4ba30..25b17b26c8d 100644 --- a/tensorflow/core/kernels/variable_ops.h +++ b/tensorflow/core/kernels/variable_ops.h @@ -76,6 +76,18 @@ class VariableOp : public OpKernel { // As long as the resource manager hasn't been cleared the ref we return // here is valid because it owns a ref on var. ctx->set_output_ref(0, var->mu(), var->tensor()); + if (ctx->track_allocations() && var->tensor()->IsInitialized()) { + AllocatorAttributes attr; + attr.set_gpu_compatible(true); + attr.set_nic_compatible(true); + if (ctx->allocate_on_host(attr)) { + ctx->record_host_persistent_memory_allocation( + var->tensor()->AllocatedBytes()); + } else { + ctx->record_device_persistent_memory_allocation( + var->tensor()->AllocatedBytes()); + } + } var->Unref(); } From 66496e0445b1cd1cf661a7bc5b40bf4e3378012d Mon Sep 17 00:00:00 2001 From: Jianwei Xie Date: Wed, 3 May 2017 21:21:16 -0800 Subject: [PATCH 02/43] Convert VadalidationMonitors to hooks in Experiment for fit/train call. Change: 155045865 --- tensorflow/contrib/learn/python/learn/experiment.py | 4 ++++ tensorflow/contrib/learn/python/learn/experiment_test.py | 7 ++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index 602d33e5f9b..85d45aef7ac 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -647,6 +647,10 @@ class Experiment(object): if _sentinel is not None: raise ValueError("_call_train should be called with keyword args only") + # Estimator in core cannot work with monitors. We need to convert them + # to hooks. For Estimator in contrib, it is converted internally. So, it is + # safe to convert for both cases. + hooks = monitors.replace_monitors_with_hooks(hooks, self._estimator) if self._core_estimator_used: return self._estimator.train(input_fn=input_fn, steps=steps, diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py index 4b5f3a195ce..9ecfc732998 100644 --- a/tensorflow/contrib/learn/python/learn/experiment_test.py +++ b/tensorflow/contrib/learn/python/learn/experiment_test.py @@ -24,7 +24,6 @@ import time from tensorflow.contrib.learn.python.learn import evaluable from tensorflow.contrib.learn.python.learn import experiment -from tensorflow.contrib.learn.python.learn import monitors from tensorflow.contrib.learn.python.learn import run_config from tensorflow.contrib.learn.python.learn import trainable from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib @@ -461,7 +460,8 @@ class ExperimentTest(test.TestCase): self.assertEqual(1, est.eval_count) self.assertEqual(1, len(est.monitors)) self.assertEqual([noop_hook], est.eval_hooks) - self.assertTrue(isinstance(est.monitors[0], monitors.ValidationMonitor)) + self.assertTrue(isinstance(est.monitors[0], + session_run_hook.SessionRunHook)) def test_train_hooks_extend_does_not_mutate_input_hooks(self): for est in self._estimators_for_tests(): @@ -563,7 +563,8 @@ class ExperimentTest(test.TestCase): self.assertEqual(1, est.export_count) self.assertEqual(1, len(est.monitors)) self.assertEqual([noop_hook], est.eval_hooks) - self.assertTrue(isinstance(est.monitors[0], monitors.ValidationMonitor)) + self.assertTrue(isinstance(est.monitors[0], + session_run_hook.SessionRunHook)) def test_train_and_evaluate_with_no_eval_during_training(self): for est in self._estimators_for_tests(): From 34e31cd154dfa896272690413abe2a709b49d42d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 3 May 2017 22:38:54 -0800 Subject: [PATCH 03/43] [XLA] Improve error message for concatenate, expose ComputationBuilder::first_error() Change: 155049706 --- tensorflow/compiler/xla/client/computation_builder.h | 8 ++++++++ tensorflow/compiler/xla/service/shape_inference.cc | 7 +++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index 87ceb43d1fe..6af69eeec12 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -668,6 +668,14 @@ class ComputationBuilder { // then Build() should be used instead. Computation BuildAndNoteError(); + // Returns the first error that was encountered while building the + // computation. When an error is encountered, by default we return a vacuous + // ComputationDataHandle and inform the user of the error that occurred while + // building the computation when they make a final call to Build(). + // + // See also set_die_immediately_on_error(). + Status first_error() const { return first_error_; } + private: using PopulateLiteral = std::function; diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 338d63f1a00..b2ef8ed486b 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -244,8 +244,11 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, } if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) { return InvalidArgument( - "cannot concatenate arrays with different ranks: %lld vs %lld", - ShapeUtil::Rank(*arg_shape), ShapeUtil::Rank(*shape)); + "Cannot concatenate arrays with different ranks: %lld (%s) vs %lld " + "(%s)", + ShapeUtil::Rank(*arg_shape), + ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape), + ShapeUtil::HumanString(*shape).c_str()); } if (arg_shape->element_type() != shape->element_type()) { return InvalidArgument( From b661c668c26e31ab1bd6f74d751390b511c50739 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 3 May 2017 22:48:59 -0800 Subject: [PATCH 04/43] Remove duplicated documentation for SavedModel CLI. Change: 155050138 --- .../docs_src/programmers_guide/index.md | 5 ++ tensorflow/python/tools/saved_model_cli.py | 64 +------------------ 2 files changed, 7 insertions(+), 62 deletions(-) diff --git a/tensorflow/docs_src/programmers_guide/index.md b/tensorflow/docs_src/programmers_guide/index.md index 309b39451fd..acdca2bad4f 100644 --- a/tensorflow/docs_src/programmers_guide/index.md +++ b/tensorflow/docs_src/programmers_guide/index.md @@ -39,6 +39,11 @@ trained graph. The following guide details `MetaGraph` objects: * @{$meta_graph$Exporting and Importing a MetaGraph}. +`SavedModel` is the universal serialization format for Tensorflow models. TensorFlow provides SavedModel CLI (command-line interface) as a tool to inspect and execute a MetaGraph in a SavedModel. The detailed usages and examples are +documented in the following guide: + + * @{$saved_model_cli$SavedModel CLI (Command-Line Interface)}. + To learn about the TensorFlow versioning scheme, consult the following two guides: diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index 2fea29d961e..c9c56a50143 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -14,68 +14,8 @@ # ============================================================================== """Command-line interface to inspect and execute a graph in a SavedModel. -If TensorFlow is installed on your system through pip, the 'saved_model_cli' -binary can be invoked directly from command line. - -At a high level, SavedModel CLI allows users to both inspect and execute -computations on a MetaGraphDef in a SavedModel. These are done through `show` -and `run` commands. Following is the usage of the two commands. SavedModel -CLI will also display these information with -h option. - -'show' command usage: saved_model_cli show [-h] --dir DIR [--tag_set TAG_SET] - [--signature_def SIGNATURE_DEF_KEY] -Examples: -To show all available tag-sets in the SavedModel: - $saved_model_cli show --dir /tmp/saved_model - -To show all available SignatureDef keys in a MetaGraphDef specified by its -tag-set: - $saved_model_cli show --dir /tmp/saved_model --tag_set serve -For a MetaGraphDef with multiple tags in the tag-set, all tags must be passed -in, separated by ',': - $saved_model_cli show --dir /tmp/saved_model --tag_set serve,gpu - -To show all inputs and outputs TensorInfo for a specific SignatureDef specified -by the SignatureDef key in a MetaGraphDef: - $saved_model_cli show --dir /tmp/saved_model --tag_set serve - --signature_def serving_default -Example output: - The given SavedModel SignatureDef contains the following input(s): - inputs['input0'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - inputs['input1'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - The given SavedModel SignatureDef contains the following output(s): - outputs['output'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - Method name is: tensorflow/serving/regress - -To show all available information in the SavedModel: - $saved_model_cli show --dir /tmp/saved_model --all - -usage: saved_model_cli run [-h] --dir DIR --tag_set TAG_SET --signature_def - SIGNATURE_DEF_KEY [--inputs INPUTS] - [--input_exprs INPUT_EXPRS] [--outdir OUTDIR] - [--overwrite] [--tf_debug] - -Examples: -To run input tensors from files through a MetaGraphDef and save the output -tensors to files: - $saved_model_cli run --dir /tmp/saved_model --tag_set serve - --signature_def serving_default --inputs x=/tmp/124.npz - --input_exprs 'x2=np.ones((6,2))' --outdir /tmp/out - -To observe the intermediate Tensor values in the runtime graph, use the ---tf_debug flag, e.g.: - $saved_model_cli run --dir /tmp/saved_model --tag_set serve - --signature_def serving_default --inputs 'x=/tmp/124.npz;x2=/tmp/123.npy' - --outdir /tmp/out --tf_debug - -To build this tool from source, run: - $bazel build tensorflow/python/tools:saved_model_cli +For detailed usages and examples, please refer to: +https://www.tensorflow.org/programmers_guide/saved_model_cli """ From 19d74ef55cf5979d133af6371e2bb644a09072e5 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Thu, 4 May 2017 01:28:15 -0800 Subject: [PATCH 05/43] [XLA] Add a target hook to force input layout equal to output layout in layout_assignment Currently unused. Change: 155058967 --- tensorflow/compiler/xla/service/layout_assignment.cc | 3 ++- tensorflow/compiler/xla/service/layout_assignment.h | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 5e7bd4a7ce8..d413621cfe2 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -705,7 +705,8 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( CHECK(ShapeUtil::IsArray(instruction->shape()) && ShapeUtil::IsArray(operand->shape())); - if (instruction->IsElementwiseOnOperand(operand_no) && + if ((instruction->IsElementwiseOnOperand(operand_no) || + InstructionRequiresInputLayoutEqualToOutputLayout(instruction)) && !ShapeUtil::IsScalar(operand->shape()) && ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(instruction->shape())) { diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 61dc7b12075..4f586c334dc 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -248,6 +248,15 @@ class LayoutAssignment : public HloPassInterface { return Status::OK(); } + // This method can be overriden to mark instructions as requiring the operands + // to have the same layout as the result, for performance or correctness. This + // will propagate constraints through the instruction from the result into the + // operands. + virtual bool InstructionRequiresInputLayoutEqualToOutputLayout( + const HloInstruction* instruction) { + return false; + } + // Construct contraints and assign layouts to all instructions in the // computation satisfying the given ComputationLayout. Layouts constraints are // added, then propagated until all LogicalBuffers in the computation are From 0374fd18c2c6bd19dba265469604f711fe6a91c5 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 4 May 2017 04:55:19 -0800 Subject: [PATCH 06/43] Change Unique op to use gtl::FlatMap instead of std::unordered_map<>. Change: 155070869 --- tensorflow/core/kernels/unique_op.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/kernels/unique_op.cc b/tensorflow/core/kernels/unique_op.cc index f5d4fcec84c..d50e2060acf 100644 --- a/tensorflow/core/kernels/unique_op.cc +++ b/tensorflow/core/kernels/unique_op.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include "tensorflow/core/framework/op_kernel.h" @@ -21,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" namespace tensorflow { @@ -50,8 +50,7 @@ class UniqueOp : public OpKernel { {0}, 1, input.shape(), &idx)); auto idx_vec = idx->template vec(); - std::unordered_map uniq; - uniq.reserve(2 * N); + gtl::FlatMap uniq(N); for (int64 i = 0, j = 0; i < N; ++i) { auto it = uniq.insert(std::make_pair(Tin(i), j)); idx_vec(i) = it.first->second; From 42c7659eddaaeea13ffac6688e1acd56147abaf1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 4 May 2017 07:04:55 -0800 Subject: [PATCH 07/43] Add `categorical_column_with_vocabulary_file`. Move lookup_ops implementation from tensorflow/contrib/lookup to tensorflow/python/feature_column. Change: 155079825 --- tensorflow/contrib/cmake/tf_python.cmake | 1 + tensorflow/contrib/lookup/BUILD | 11 +- tensorflow/contrib/lookup/__init__.py | 2 +- tensorflow/python/BUILD | 3 +- tensorflow/python/feature_column/BUILD | 34 ++ .../python/feature_column/feature_column.py | 233 +++++++++- .../feature_column/feature_column_test.py | 408 +++++++++++++++++- .../feature_column}/lookup_ops.py | 6 +- .../testdata/warriors_vocabulary.txt | 5 + .../testdata/wire_vocabulary.txt | 3 + .../tools/pip_package/pip_smoke_test.py | 1 + 11 files changed, 666 insertions(+), 41 deletions(-) rename tensorflow/{contrib/lookup => python/feature_column}/lookup_ops.py (99%) create mode 100644 tensorflow/python/feature_column/testdata/warriors_vocabulary.txt create mode 100644 tensorflow/python/feature_column/testdata/wire_vocabulary.txt diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 53ebfbb57de..ad3b29c8ea5 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -203,6 +203,7 @@ add_python_module("tensorflow/python/estimator") add_python_module("tensorflow/python/estimator/export") add_python_module("tensorflow/python/estimator/inputs") add_python_module("tensorflow/python/estimator/inputs/queues") +add_python_module("tensorflow/python/feature_column") add_python_module("tensorflow/python/framework") add_python_module("tensorflow/python/grappler") add_python_module("tensorflow/python/kernel_tests") diff --git a/tensorflow/contrib/lookup/BUILD b/tensorflow/contrib/lookup/BUILD index b3316ee8c4f..5966c86dfb9 100644 --- a/tensorflow/contrib/lookup/BUILD +++ b/tensorflow/contrib/lookup/BUILD @@ -13,19 +13,10 @@ py_library( name = "lookup_py", srcs = [ "__init__.py", - "lookup_ops.py", ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:data_flow_ops_gen", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:math_ops", - "//tensorflow/python:string_ops", - "//tensorflow/python:training", - "//tensorflow/python:util", + "//tensorflow/python/feature_column:lookup_ops", ], ) diff --git a/tensorflow/contrib/lookup/__init__.py b/tensorflow/contrib/lookup/__init__.py index dbd64cf0421..a5fcdc7b42d 100644 --- a/tensorflow/contrib/lookup/__init__.py +++ b/tensorflow/contrib/lookup/__init__.py @@ -47,7 +47,7 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,wildcard-import -from tensorflow.contrib.lookup.lookup_ops import * +from tensorflow.python.feature_column.lookup_ops import * # pylint: enable=unused-import,wildcard-import from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 5e938c73f5a..817d157da29 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -82,6 +82,7 @@ py_library( "//third_party/py/numpy", "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/feature_column:feature_column", + "//tensorflow/python/feature_column:lookup_ops", "//tensorflow/python/ops/losses", "//tensorflow/python/ops/distributions", "//tensorflow/python/saved_model", @@ -1021,7 +1022,7 @@ tf_gen_op_wrapper_private_py( require_shape_functions = True, visibility = [ "//learning/brain/python/ops:__pkg__", - "//tensorflow/contrib/lookup:__pkg__", + "//tensorflow/python/feature_column:__pkg__", "//tensorflow/python/kernel_tests:__pkg__", ], ) diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD index d5eb20e997c..d7342738457 100644 --- a/tensorflow/python/feature_column/BUILD +++ b/tensorflow/python/feature_column/BUILD @@ -29,6 +29,7 @@ py_library( srcs = ["feature_column.py"], srcs_version = "PY2AND3", deps = [ + ":lookup_ops", "//tensorflow/python:embedding_ops", "//tensorflow/python:framework", "//tensorflow/python:init_ops", @@ -44,14 +45,47 @@ py_library( ], ) +filegroup( + name = "vocabulary_testdata", + srcs = [ + "testdata/warriors_vocabulary.txt", + "testdata/wire_vocabulary.txt", + ], +) + py_test( name = "feature_column_test", srcs = ["feature_column_test.py"], + data = [":vocabulary_testdata"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ ":feature_column", "//tensorflow/python:client_testlib", "//tensorflow/python:framework", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", "//tensorflow/python:training", ], ) + +# TODO(ptucker,yleon): Move along with 3p/tf/contrib/lookup. +# Test is still in 3p/tf/contrib/lookup. +py_library( + name = "lookup_ops", + srcs = [ + "lookup_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:data_flow_ops_gen", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:string_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + ], +) diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index a96052a3ae5..33bed3abcf1 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -121,6 +121,7 @@ from __future__ import print_function import abc import collections +from tensorflow.python.feature_column import lookup_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib @@ -331,7 +332,9 @@ def numeric_column(key, ``` Args: - key: A string providing key to look up corresponding `Tensor`. + key: A unique string identifying the input feature. It is used as the + column name and the dictionary key for feature parsing configs, feature + `Tensor` objects, and feature columns. shape: An iterable of integers specifies the shape of the `Tensor`. An integer can be given which means a single dimension `Tensor` with given width. The `Tensor` representing the column will have the shape of @@ -443,22 +446,22 @@ def categorical_column_with_hash_bucket(key, ```python keywords = categorical_column_with_hash_bucket("keywords", 10K) - all_feature_columns = [keywords, ...] - linear_prediction = make_linear_model(features, all_feature_columns) + linear_prediction = make_linear_model(features, [keywords, ...]) # or keywords_embedded = embedding_column(keywords, 16) - all_feature_columns = [keywords_embedded, ...] - dense_tensor = make_input_layer(features, all_feature_columns) + dense_tensor = make_input_layer(features, [keywords_embedded, ...]) ``` Args: - key: A string providing key to look up corresponding `Tensor`. + key: A unique string identifying the input feature. It is used as the + column name and the dictionary key for feature parsing configs, feature + `Tensor` objects, and feature columns. hash_bucket_size: An int > 1. The number of buckets. dtype: The type of features. Only string and integer types are supported. Returns: - A `_CategoricalColumnHashed`. + A `_HashedCategoricalColumn`. Raises: ValueError: `hash_bucket_size` is not greater than 1. @@ -476,7 +479,100 @@ def categorical_column_with_hash_bucket(key, raise ValueError('dtype must be string or integer. ' 'dtype: {}, column_name: {}'.format(dtype, key)) - return _CategoricalColumnHashed(key, hash_bucket_size, dtype) + return _HashedCategoricalColumn(key, hash_bucket_size, dtype) + + +def categorical_column_with_vocabulary_file( + key, vocabulary_file, vocabulary_size, num_oov_buckets=0, + default_value=None, dtype=dtypes.string): + """Creates a `_CategoricalColumn` with vocabulary file configuration. + + Use this when your inputs are in string or integer format, and you have a + vocabulary file that maps each value to an integer ID. By default, + out-of-vocabulary values are ignored. Use either (but not both) of + `num_oov_buckets` and `default_value` to specify how to include + out-of-vocabulary values. + + Inputs can be either `Tensor` or `SparseTensor`. If `Tensor`, missing values + can be represented by `-1` for int and `''` for string. Note that these values + are independent of the `default_value` argument. + + Example with `num_oov_buckets`: + File '/us/states.txt' contains 50 lines, each with a 2-character U.S. state + abbreviation. All inputs with values in that file are assigned an ID 0-49, + corresponding to its line number. All other values are hashed and assigned an + ID 50-54. + ```python + states = categorical_column_with_vocabulary_file( + key='keywords', vocabulary_file='/us/states.txt', vocabulary_size=50, + num_oov_buckets=5) + linear_prediction = make_linear_model(features, [states, ...]) + ``` + + Example with `default_value`: + File '/us/states.txt' contains 51 lines - the first line is 'XX', and the + other 50 each have a 2-character U.S. state abbreviation. Both a literal 'XX' + in input, and other values missing from the file, will be assigned ID 0. All + others are assigned the corresponding line number 1-50. + ```python + states = categorical_column_with_vocabulary_file( + key='keywords', vocabulary_file='/us/states.txt', vocabulary_size=51, + default_value=0) + linear_prediction, _, _ = make_linear_model(features, [states, ...]) + + And to make an embedding with either: + ```python + dense_tensor = make_input_layer(features, [embedding_column(states, 3),...]) + ``` + + Args: + key: A unique string identifying the input feature. It is used as the + column name and the dictionary key for feature parsing configs, feature + `Tensor` objects, and feature columns. + vocabulary_file: The vocabulary file name. + vocabulary_size: Number of the elements in the vocabulary. + num_oov_buckets: Non-negative integer, the number of out-of-vocabulary + buckets. All out-of-vocabulary inputs will be assigned IDs in the range + `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of + the input value. A positive `num_oov_buckets` can not be specified with + `default_value`. + default_value: The integer ID value to return for out-of-vocabulary feature + values, defaults to -1. This can not be specified with a positive + `num_oov_buckets`. + dtype: The type of features. Only string and integer types are supported. + + Returns: + A `_CategoricalColumn` with vocabulary file configuration. + + Raises: + ValueError: `vocabulary_file` is missing. + ValueError: `vocabulary_size` is missing or < 1. + ValueError: `num_oov_buckets` is not a non-negative integer. + ValueError: `dtype` is neither string nor integer. + """ + if not vocabulary_file: + raise ValueError('Missing vocabulary_file in {}.'.format(key)) + # `vocabulary_size` isn't required for lookup, but it is for `_num_buckets`. + # TODO(ptucker): Should we fail for vocabulary_size==1? + if (vocabulary_size is None) or (vocabulary_size < 1): + raise ValueError('Invalid vocabulary_size in {}.'.format(key)) + if num_oov_buckets: + if default_value is not None: + raise ValueError( + 'Can\'t specify both num_oov_buckets and default_value in {}.'.format( + key)) + if num_oov_buckets < 0: + raise ValueError('Invalid num_oov_buckets {} in {}.'.format( + num_oov_buckets, key)) + if dtype != dtypes.string and not dtype.is_integer: + raise ValueError('Invalid dtype {} in {}.'.format(dtype, key)) + return _VocabularyCategoricalColumn( + key=key, + vocabulary_file=vocabulary_file, + vocabulary_size=vocabulary_size, + num_oov_buckets=0 if num_oov_buckets is None else num_oov_buckets, + default_value=-1 if default_value is None else default_value, + dtype=dtype) class _FeatureColumn(object): @@ -764,6 +860,67 @@ class _LazyBuilder(object): return transformed +# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py +def _shape_offsets(shape): + """Returns moving offset for each dimension given shape.""" + offsets = [] + for dim in reversed(shape): + if offsets: + offsets.append(dim * offsets[-1]) + else: + offsets.append(dim) + offsets.reverse() + return offsets + + +# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py +def _to_sparse_input(input_tensor, ignore_value=None): + """Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells. + + If `input_tensor` is already a `SparseTensor`, just return it. + + Args: + input_tensor: A string or integer `Tensor`. + ignore_value: Entries in `dense_tensor` equal to this value will be + absent from the resulting `SparseTensor`. If `None`, default value of + `dense_tensor`'s dtype will be used ('' for `str`, -1 for `int`). + + Returns: + A `SparseTensor` with the same shape as `input_tensor`. + + Raises: + ValueError: when `input_tensor`'s rank is `None`. + """ + input_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor( + input_tensor) + if isinstance(input_tensor, sparse_tensor_lib.SparseTensor): + return input_tensor + with ops.name_scope(None, 'to_sparse_input', (input_tensor, ignore_value,)): + input_rank = input_tensor.get_shape().ndims + if input_rank is None: + # TODO(b/32318825): Implement dense_to_sparse_tensor for undefined rank. + raise ValueError('Undefined input_tensor shape.') + if ignore_value is None: + ignore_value = '' if input_tensor.dtype == dtypes.string else -1 + dense_shape = math_ops.cast(array_ops.shape(input_tensor), dtypes.int64) + indices = array_ops.where(math_ops.not_equal( + input_tensor, math_ops.cast(ignore_value, input_tensor.dtype))) + # Flattens the tensor and indices for use with gather. + flat_tensor = array_ops.reshape(input_tensor, [-1]) + flat_indices = indices[:, input_rank - 1] + # Computes the correct flattened indices for 2d (or higher) tensors. + if input_rank > 1: + higher_dims = indices[:, :input_rank - 1] + shape_offsets = array_ops.stack( + _shape_offsets(array_ops.unstack(dense_shape)[1:])) + offsets = math_ops.reduce_sum( + math_ops.multiply(higher_dims, shape_offsets), + reduction_indices=[1]) + flat_indices = math_ops.add(flat_indices, offsets) + values = array_ops.gather(flat_tensor, flat_indices) + return sparse_tensor_lib.SparseTensor(indices, values, dense_shape) + + def _check_feature_columns(feature_columns): if isinstance(feature_columns, dict): raise ValueError('Expected feature_columns to be iterable, found dict.') @@ -951,7 +1108,7 @@ def _check_default_value(shape, default_value, dtype, key): `shape`. dtype: defines the type of values. Default value is `tf.float32`. Must be a non-quantized, real integer or floating point type. - key: A string providing key to look up corresponding `Tensor`. + key: Column name, used only for error messages. Returns: A tuple which will be used as default value. @@ -994,9 +1151,9 @@ def _check_default_value(shape, default_value, dtype, key): default_value, dtype, key)) -class _CategoricalColumnHashed( +class _HashedCategoricalColumn( _CategoricalColumn, - collections.namedtuple('_CategoricalColumnHashed', + collections.namedtuple('_HashedCategoricalColumn', ['key', 'hash_bucket_size', 'dtype'])): """see `categorical_column_with_hash_bucket`.""" @@ -1009,7 +1166,7 @@ class _CategoricalColumnHashed( return {self.key: parsing_ops.VarLenFeature(self.dtype)} def _transform_feature(self, inputs): - input_tensor = inputs.get(self.key) + input_tensor = _to_sparse_input(inputs.get(self.key)) if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor): raise ValueError('SparseColumn input must be a SparseTensor.') @@ -1045,6 +1202,58 @@ class _CategoricalColumnHashed( return _CategoricalColumn.IdWeightPair(inputs.get(self), None) +class _VocabularyCategoricalColumn( + _CategoricalColumn, collections.namedtuple('_VocabularyCategoricalColumn', ( + 'key', 'vocabulary_file', 'vocabulary_size', 'num_oov_buckets', 'dtype', + 'default_value' + ))): + """See `categorical_column_with_vocabulary_file`.""" + + @property + def name(self): + return self.key + + @property + def _parse_example_config(self): + return {self.key: parsing_ops.VarLenFeature(self.dtype)} + + def _transform_feature(self, inputs): + input_tensor = _to_sparse_input(inputs.get(self.key)) + + if self.dtype.is_integer != input_tensor.dtype.is_integer: + raise ValueError( + 'Column dtype and SparseTensors dtype must be compatible. ' + 'key: {}, column dtype: {}, tensor dtype: {}'.format( + self.key, self.dtype, input_tensor.dtype)) + + key_dtype = self.dtype + if input_tensor.dtype.is_integer: + # `index_table_from_file` requires 64-bit integer keys. + key_dtype = dtypes.int64 + input_tensor = math_ops.to_int64(input_tensor) + elif input_tensor.dtype != dtypes.string: + raise ValueError('input tensors dtype must be string or integer. ' + 'dtype: {}, column_name: {}'.format( + input_tensor.dtype, self.key)) + + return lookup_ops.index_table_from_file( + vocabulary_file=self.vocabulary_file, + num_oov_buckets=self.num_oov_buckets, + vocab_size=self.vocabulary_size, + default_value=self.default_value, + key_dtype=key_dtype, + name='{}_lookup'.format(self.key)).lookup(input_tensor) + + @property + def _num_buckets(self): + """Returns number of buckets in this sparse feature.""" + return self.vocabulary_size + self.num_oov_buckets + + def _get_sparse_tensors( + self, inputs, weight_collections=None, trainable=None): + return _CategoricalColumn.IdWeightPair(inputs.get(self), None) + + # TODO(zakaria): Move this to embedding_ops and make it public. def _safe_embedding_lookup_sparse(embedding_weights, sparse_ids, diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index bc626533104..d85142abcfb 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -28,6 +28,7 @@ from tensorflow.python.client import session from tensorflow.python.feature_column import feature_column as fc from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import data_flow_ops @@ -552,7 +553,7 @@ class BucketizedColumnTest(test.TestCase): self.assertAllClose([[81.], [141.]], predictions.eval()) -class SparseColumnHashedTest(test.TestCase): +class HashedCategoricalColumnTest(test.TestCase): def test_defaults(self): a = fc.categorical_column_with_hash_bucket('aaa', 10) @@ -578,11 +579,14 @@ class SparseColumnHashedTest(test.TestCase): def test_deep_copy(self): """Tests deepcopy of categorical_column_with_hash_bucket.""" - column = fc.categorical_column_with_hash_bucket('aaa', 10) - column_copy = copy.deepcopy(column) - self.assertEqual('aaa', column_copy.name) - self.assertEqual(10, column_copy.hash_bucket_size) - self.assertEqual(dtypes.string, column_copy.dtype) + original = fc.categorical_column_with_hash_bucket('aaa', 10) + for column in (original, copy.deepcopy(original)): + self.assertEqual('aaa', column.name) + self.assertEqual(10, column.hash_bucket_size) + # pylint: disable=protected-access + self.assertEqual(10, column._num_buckets) + # pylint: enable=protected-access + self.assertEqual(dtypes.string, column.dtype) def test_parse_config(self): a = fc.categorical_column_with_hash_bucket('aaa', 10) @@ -681,14 +685,45 @@ class SparseColumnHashedTest(test.TestCase): def test_get_sparse_tensors(self): hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10) - wire_tensor = sparse_tensor.SparseTensor( - values=['omar', 'stringer', 'marlo'], - indices=[[0, 0], [1, 0], [1, 1]], - dense_shape=[2, 2]) - builder = fc._LazyBuilder({'wire': wire_tensor}) - self.assertEqual( - builder.get(hashed_sparse), - hashed_sparse._get_sparse_tensors(builder).id_tensor) + builder = fc._LazyBuilder({ + 'wire': sparse_tensor.SparseTensor( + values=['omar', 'stringer', 'marlo'], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + }) + id_weight_pair = hashed_sparse._get_sparse_tensors(builder) + self.assertIsNone(id_weight_pair.weight_tensor) + self.assertEqual(builder.get(hashed_sparse), id_weight_pair.id_tensor) + + def test_get_sparse_tensors_dense_input(self): + hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10) + builder = fc._LazyBuilder({ + 'wire': (('omar', ''), ('stringer', 'marlo')) + }) + id_weight_pair = hashed_sparse._get_sparse_tensors(builder) + self.assertIsNone(id_weight_pair.weight_tensor) + self.assertEqual(builder.get(hashed_sparse), id_weight_pair.id_tensor) + + def test_make_linear_model(self): + wire_column = fc.categorical_column_with_hash_bucket('wire', 4) + self.assertEqual(4, wire_column._num_buckets) + with ops.Graph().as_default(): + predictions = fc.make_linear_model({ + wire_column.name: sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + }, (wire_column,)) + bias = get_linear_model_bias() + wire_var = get_linear_model_column_var(wire_column) + with _initialized_session(): + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval() + # 'marlo' -> 3: wire_var[3] = 4 + # 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6 + self.assertAllClose(((4.,), (6.,)), predictions.eval()) def get_linear_model_bias(): @@ -1158,5 +1193,350 @@ class MakeInputLayerTest(test.TestCase): self.assertAllClose([[1., 3.]], net2.eval()) +class VocabularyCategoricalColumnTest(test.TestCase): + + def setUp(self): + super(VocabularyCategoricalColumnTest, self).setUp() + + # Contains ints, Golden State Warriors jersey numbers: 30, 35, 11, 23, 22 + self._warriors_vocabulary_file_name = test.test_src_dir_path( + 'python/feature_column/testdata/warriors_vocabulary.txt') + self._warriors_vocabulary_size = 5 + + # Contains strings, character names from 'The Wire': omar, stringer, marlo + self._wire_vocabulary_file_name = test.test_src_dir_path( + 'python/feature_column/testdata/wire_vocabulary.txt') + self._wire_vocabulary_size = 3 + + def _assert_sparse_tensor_value(self, expected, actual): + self.assertEqual(np.int64, np.array(actual.indices).dtype) + self.assertAllEqual(expected.indices, actual.indices) + + self.assertEqual( + np.array(expected.values).dtype, np.array(actual.values).dtype) + self.assertAllEqual(expected.values, actual.values) + + self.assertEqual(np.int64, np.array(actual.dense_shape).dtype) + self.assertAllEqual(expected.dense_shape, actual.dense_shape) + + def test_defaults(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file='path_to_file', vocabulary_size=3) + self.assertEqual('aaa', column.name) + # pylint: disable=protected-access + self.assertEqual(3, column._num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.string) + }, column._parse_example_config) + # pylint: enable=protected-access + + def test_all_constructor_args(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file='path_to_file', vocabulary_size=3, + num_oov_buckets=4, dtype=dtypes.int32) + # pylint: disable=protected-access + self.assertEqual(7, column._num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int32) + }, column._parse_example_config) + # pylint: enable=protected-access + + def test_deep_copy(self): + """Tests deepcopy of categorical_column_with_hash_bucket.""" + original = fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file='path_to_file', vocabulary_size=3, + num_oov_buckets=4, dtype=dtypes.int32) + for column in (original, copy.deepcopy(original)): + self.assertEqual('aaa', column.name) + # pylint: disable=protected-access + self.assertEqual(7, column._num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int32) + }, column._parse_example_config) + # pylint: enable=protected-access + + def test_vocabulary_file_none(self): + with self.assertRaisesRegexp(ValueError, 'Missing vocabulary_file'): + fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file=None, vocabulary_size=3) + + def test_vocabulary_file_empty_string(self): + with self.assertRaisesRegexp(ValueError, 'Missing vocabulary_file'): + fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file='', vocabulary_size=3) + + def test_invalid_vocabulary_file(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file='file_does_not_exist', vocabulary_size=10) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + # pylint: disable=protected-access + column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + with self.assertRaisesRegexp(errors.OpError, 'file_does_not_exist'): + with self.test_session(): + data_flow_ops.tables_initializer().run() + + def test_invalid_vocabulary_size(self): + with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'): + fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=None) + with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'): + fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=-1) + with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'): + fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=0) + + def test_too_large_vocabulary_size(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size + 1) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + # pylint: disable=protected-access + column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + with self.assertRaisesRegexp(errors.OpError, 'Invalid vocab_size'): + with self.test_session(): + data_flow_ops.tables_initializer().run() + + def test_invalid_num_oov_buckets(self): + with self.assertRaisesRegexp(ValueError, 'Invalid num_oov_buckets'): + fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file='path', vocabulary_size=3, + num_oov_buckets=-1) + + def test_invalid_dtype(self): + with self.assertRaisesRegexp(ValueError, 'Invalid dtype'): + fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file='path', vocabulary_size=3, + dtype=dtypes.float64) + + def test_invalid_buckets_and_default_value(self): + with self.assertRaisesRegexp( + ValueError, 'both num_oov_buckets and default_value'): + fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size, + num_oov_buckets=100, + default_value=2) + + def test_get_sparse_tensors(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors( + fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + self._assert_sparse_tensor_value( + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, -1, 0), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_dense_input(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({ + 'aaa': (('marlo', ''), ('skywalker', 'omar')) + })) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + self._assert_sparse_tensor_value( + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=np.array((2, -1, 0), dtype=np.int64), + dense_shape=(2, 2)), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_default_value_in_vocabulary(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size, + default_value=2) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors( + fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + self._assert_sparse_tensor_value( + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, 2, 0), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_with_oov_buckets(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size, + num_oov_buckets=100) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1), (1, 2)), + values=('marlo', 'skywalker', 'omar', 'heisenberg'), + dense_shape=(2, 3)) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors( + fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + self._assert_sparse_tensor_value( + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, 33, 0, 62), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_small_vocabulary_size(self): + # 'marlo' is the last entry in our vocabulary file, so be setting + # `vocabulary_size` to 1 less than number of entries in file, we take + # 'marlo' out of the vocabulary. + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size - 1) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors( + fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + self._assert_sparse_tensor_value( + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((-1, -1, 0), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_int32(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._warriors_vocabulary_file_name, + vocabulary_size=self._warriors_vocabulary_size, + dtype=dtypes.int32) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1), (2, 2)), + values=(11, 100, 30, 22), + dense_shape=(3, 3)) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors( + fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + self._assert_sparse_tensor_value( + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, -1, 0, 4), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_int32_dense_input(self): + default_value = -100 + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._warriors_vocabulary_file_name, + vocabulary_size=self._warriors_vocabulary_size, + dtype=dtypes.int32, + default_value=default_value) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({ + 'aaa': ((11, -1, -1), (100, 30, -1), (-1, -1, 22)) + })) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + self._assert_sparse_tensor_value( + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1), (2, 2)), + values=np.array((2, default_value, 0, 4), dtype=np.int64), + dense_shape=(3, 3)), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_int32_with_oov_buckets(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._warriors_vocabulary_file_name, + vocabulary_size=self._warriors_vocabulary_size, + dtype=dtypes.int32, + num_oov_buckets=100) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1), (2, 2)), + values=(11, 100, 30, 22), + dense_shape=(3, 3)) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors( + fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + self._assert_sparse_tensor_value( + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, 60, 0, 4), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_make_linear_model(self): + wire_column = fc.categorical_column_with_vocabulary_file( + key='wire', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size, + num_oov_buckets=1) + self.assertEqual(4, wire_column._num_buckets) + with ops.Graph().as_default(): + predictions = fc.make_linear_model({ + wire_column.name: sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + }, (wire_column,)) + bias = get_linear_model_bias() + wire_var = get_linear_model_column_var(wire_column) + with _initialized_session(): + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval() + # 'marlo' -> 2: wire_var[2] = 3 + # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5 + self.assertAllClose(((3.,), (5.,)), predictions.eval()) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/python/feature_column/lookup_ops.py similarity index 99% rename from tensorflow/contrib/lookup/lookup_ops.py rename to tensorflow/python/feature_column/lookup_ops.py index 9dc7414cd07..13a67fa5183 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/python/feature_column/lookup_ops.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Lookup table Operations.""" -# pylint: disable=g-bad-name +"""Lookup table operations.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -608,7 +608,7 @@ class HasherSpec(collections.namedtuple("HasherSpec", ["hasher", "key"])): __slots__ = () -FastHashSpec = HasherSpec("fasthash", None) +FastHashSpec = HasherSpec("fasthash", None) # pylint: disable=invalid-name class StrongHashSpec(HasherSpec): diff --git a/tensorflow/python/feature_column/testdata/warriors_vocabulary.txt b/tensorflow/python/feature_column/testdata/warriors_vocabulary.txt new file mode 100644 index 00000000000..6c917fa6999 --- /dev/null +++ b/tensorflow/python/feature_column/testdata/warriors_vocabulary.txt @@ -0,0 +1,5 @@ +30 +35 +11 +23 +22 diff --git a/tensorflow/python/feature_column/testdata/wire_vocabulary.txt b/tensorflow/python/feature_column/testdata/wire_vocabulary.txt new file mode 100644 index 00000000000..32c6b5692a0 --- /dev/null +++ b/tensorflow/python/feature_column/testdata/wire_vocabulary.txt @@ -0,0 +1,3 @@ +omar +stringer +marlo diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py index 61c3fe55405..0438ce68469 100644 --- a/tensorflow/tools/pip_package/pip_smoke_test.py +++ b/tensorflow/tools/pip_package/pip_smoke_test.py @@ -45,6 +45,7 @@ BLACKLIST = [ "//tensorflow/python:compare_test_proto_py", "//tensorflow/core:image_testdata", "//tensorflow/core/kernels/cloud:bigquery_reader_ops", + "//tensorflow/python/feature_column:vocabulary_testdata", "//tensorflow/python:framework/test_file_system.so", # contrib "//tensorflow/contrib/session_bundle:session_bundle_half_plus_two", From 1e4899035ed82c85f6c85b7349211528f161402c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 4 May 2017 08:29:59 -0800 Subject: [PATCH 08/43] Use six.string_types instead of str in estimator/export. Change: 155087824 --- tensorflow/python/estimator/export/export.py | 6 +- .../python/estimator/export/export_output.py | 4 +- .../estimator/export/export_output_test.py | 29 ++++++++ .../python/estimator/export/export_test.py | 67 ++++++++++++++++++- 4 files changed, 100 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py index 37a98cf4815..a1ecd794df6 100644 --- a/tensorflow/python/estimator/export/export.py +++ b/tensorflow/python/estimator/export/export.py @@ -23,6 +23,8 @@ import collections import os import time +import six + from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor @@ -56,7 +58,7 @@ class ServingInputReceiver(collections.namedtuple('ServingInputReceiver', if not isinstance(features, dict): features = {_SINGLE_FEATURE_DEFAULT_NAME: features} for name, tensor in features.items(): - if not isinstance(name, str): + if not isinstance(name, six.string_types): raise ValueError('feature keys must be strings: {}.'.format(name)) if not (isinstance(tensor, ops.Tensor) or isinstance(tensor, sparse_tensor.SparseTensor)): @@ -68,7 +70,7 @@ class ServingInputReceiver(collections.namedtuple('ServingInputReceiver', if not isinstance(receiver_tensors, dict): receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors} for name, tensor in receiver_tensors.items(): - if not isinstance(name, str): + if not isinstance(name, six.string_types): raise ValueError( 'receiver_tensors keys must be strings: {}.'.format(name)) if not isinstance(tensor, ops.Tensor): diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py index 69be0f687c1..49bcd06d504 100644 --- a/tensorflow/python/estimator/export/export_output.py +++ b/tensorflow/python/estimator/export/export_output.py @@ -20,6 +20,8 @@ from __future__ import print_function import abc +import six + from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -171,7 +173,7 @@ class PredictOutput(ExportOutput): 'Prediction outputs must be given as a dict of string to Tensor; ' 'got {}'.format(outputs)) for key, value in outputs.items(): - if not isinstance(key, str): + if not isinstance(key, six.string_types): raise ValueError( 'Prediction output key must be a string; got {}.'.format(key)) if not isinstance(value, ops.Tensor): diff --git a/tensorflow/python/estimator/export/export_output_test.py b/tensorflow/python/estimator/export/export_output_test.py index 27a088e551c..035a9a143e6 100644 --- a/tensorflow/python/estimator/export/export_output_test.py +++ b/tensorflow/python/estimator/export/export_output_test.py @@ -22,7 +22,9 @@ from tensorflow.core.framework import tensor_shape_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.estimator.export import export_output as export_output_lib +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants @@ -197,6 +199,33 @@ class ExportOutputTest(test.TestCase): signature_constants.CLASSIFY_METHOD_NAME) self.assertEqual(actual_signature_def, expected_signature_def) + def test_predict_output_constructor(self): + """Tests that no errors are raised when input is expected.""" + outputs = { + "output0": constant_op.constant([0]), + u"output1": constant_op.constant([1]), + } + export_output_lib.PredictOutput(outputs) + + def test_predict_output_outputs_invalid(self): + with self.assertRaisesRegexp( + ValueError, + "Prediction outputs must be given as a dict of string to Tensor"): + export_output_lib.PredictOutput(constant_op.constant([0])) + + with self.assertRaisesRegexp( + ValueError, + "Prediction output key must be a string"): + export_output_lib.PredictOutput({1: constant_op.constant([0])}) + + with self.assertRaisesRegexp( + ValueError, + "Prediction output value must be a Tensor"): + export_output_lib.PredictOutput({ + "prediction1": sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1]), + }) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py index fdd924f2e1c..7946bd88ba0 100644 --- a/tensorflow/python/estimator/export/export_test.py +++ b/tensorflow/python/estimator/export/export_test.py @@ -28,13 +28,11 @@ from tensorflow.core.example import example_pb2 from tensorflow.python.estimator.export import export from tensorflow.python.estimator.export import export_output from tensorflow.python.framework import constant_op -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops -from tensorflow.python.ops import array_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants @@ -43,6 +41,69 @@ from tensorflow.python.saved_model import signature_def_utils class ExportTest(test_util.TensorFlowTestCase): + def test_serving_input_receiver_constructor(self): + """Tests that no errors are raised when input is expected.""" + features = { + "feature0": constant_op.constant([0]), + u"feature1": constant_op.constant([1]), + "feature2": sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1]), + } + receiver_tensors = { + "example0": array_ops.placeholder(dtypes.string, name="example0"), + u"example1": array_ops.placeholder(dtypes.string, name="example1"), + } + export.ServingInputReceiver(features, receiver_tensors) + + def test_serving_input_receiver_features_invalid(self): + receiver_tensors = { + "example0": array_ops.placeholder(dtypes.string, name="example0"), + u"example1": array_ops.placeholder(dtypes.string, name="example1"), + } + + with self.assertRaisesRegexp(ValueError, "features must be defined"): + export.ServingInputReceiver( + features=None, + receiver_tensors=receiver_tensors) + + with self.assertRaisesRegexp(ValueError, "feature keys must be strings"): + export.ServingInputReceiver( + features={1: constant_op.constant([1])}, + receiver_tensors=receiver_tensors) + + with self.assertRaisesRegexp( + ValueError, "feature feature1 must be a Tensor or SparseTensor"): + export.ServingInputReceiver( + features={"feature1": [1]}, + receiver_tensors=receiver_tensors) + + def test_serving_input_receiver_receiver_tensors_invalid(self): + features = { + "feature0": constant_op.constant([0]), + u"feature1": constant_op.constant([1]), + "feature2": sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1]), + } + + with self.assertRaisesRegexp( + ValueError, "receiver_tensors must be defined"): + export.ServingInputReceiver( + features=features, + receiver_tensors=None) + + with self.assertRaisesRegexp( + ValueError, "receiver_tensors keys must be strings"): + export.ServingInputReceiver( + features=features, + receiver_tensors={ + 1: array_ops.placeholder(dtypes.string, name="example0")}) + + with self.assertRaisesRegexp( + ValueError, "receiver_tensor example1 must be a Tensor"): + export.ServingInputReceiver( + features=features, + receiver_tensors={"example1": [1]}) + def test_single_feature_single_receiver(self): feature = constant_op.constant(5) receiver_tensor = array_ops.placeholder(dtypes.string) From 074ede9a2371e188d27790f68b7b7ead82d8ef1d Mon Sep 17 00:00:00 2001 From: Andrew Selle Date: Thu, 4 May 2017 08:43:15 -0800 Subject: [PATCH 09/43] Adjust getting started guide to use a training and eval data set. Change: 155089162 --- .../docs_src/get_started/get_started.md | 51 ++++++++++++------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/tensorflow/docs_src/get_started/get_started.md b/tensorflow/docs_src/get_started/get_started.md index b52adc3790a..00cc10cd347 100644 --- a/tensorflow/docs_src/get_started/get_started.md +++ b/tensorflow/docs_src/get_started/get_started.md @@ -372,25 +372,36 @@ features = [tf.contrib.layers.real_valued_column("x", dimension=1)] estimator = tf.contrib.learn.LinearRegressor(feature_columns=features) # TensorFlow provides many helper methods to read and set up data sets. -# Here we use `numpy_input_fn`. We have to tell the function how many batches +# Here we use two data sets: one for training and one for evaluation +# We have to tell the function how many batches # of data (num_epochs) we want and how big each batch should be. -x = np.array([1., 2., 3., 4.]) -y = np.array([0., -1., -2., -3.]) -input_fn = tf.contrib.learn.io.numpy_input_fn({"x":x}, y, batch_size=4, +x_train = np.array([1., 2., 3., 4.]) +y_train = np.array([0., -1., -2., -3.]) +x_eval = np.array([2., 5., 8., 1.]) +y_eval = np.array([-1.01, -4.1, -7, 0.]) +input_fn = tf.contrib.learn.io.numpy_input_fn({"x":x_train}, y_train, + batch_size=4, num_epochs=1000) +eval_input_fn = tf.contrib.learn.io.numpy_input_fn( + {"x":x_eval}, y_eval, batch_size=4, num_epochs=1000) -# We can invoke 1000 training steps by invoking the `fit` method and passing the +# We can invoke 1000 training steps by invoking the method and passing the # training data set. estimator.fit(input_fn=input_fn, steps=1000) -# Here we evaluate how well our model did. In a real example, we would want -# to use a separate validation and testing data set to avoid overfitting. -print(estimator.evaluate(input_fn=input_fn)) +# Here we evaluate how well our model did. +train_loss = estimator.evaluate(input_fn=input_fn) +eval_loss = estimator.evaluate(input_fn=eval_input_fn) +print("train loss: %r"% train_loss) +print("eval loss: %r"% eval_loss) ``` When run, it produces ``` - {'global_step': 1000, 'loss': 1.9650059e-11} + train loss: {'global_step': 1000, 'loss': 4.3049088e-08} + eval loss: {'global_step': 1000, 'loss': 0.0025487561} ``` +Notice how our eval data has a higher loss, but it is still close to zero. +That means we are learning properly. ### A custom model @@ -432,19 +443,25 @@ def model(features, labels, mode): train_op=train) estimator = tf.contrib.learn.Estimator(model_fn=model) -# define our data set -x = np.array([1., 2., 3., 4.]) -y = np.array([0., -1., -2., -3.]) -input_fn = tf.contrib.learn.io.numpy_input_fn({"x": x}, y, 4, num_epochs=1000) +# define our data sets +x_train = np.array([1., 2., 3., 4.]) +y_train = np.array([0., -1., -2., -3.]) +x_eval = np.array([2., 5., 8., 1.]) +y_eval = np.array([-1.01, -4.1, -7, 0.]) +input_fn = tf.contrib.learn.io.numpy_input_fn({"x": x_train}, y_train, 4, num_epochs=1000) # train estimator.fit(input_fn=input_fn, steps=1000) -# evaluate our model -print(estimator.evaluate(input_fn=input_fn, steps=10)) +# Here we evaluate how well our model did. +train_loss = estimator.evaluate(input_fn=input_fn) +eval_loss = estimator.evaluate(input_fn=eval_input_fn) +print("train loss: %r"% train_loss) +print("eval loss: %r"% eval_loss) ``` When run, it produces -```python -{'loss': 5.9819476e-11, 'global_step': 1000} +``` +train loss: {'global_step': 1000, 'loss': 4.9380226e-11} +eval loss: {'global_step': 1000, 'loss': 0.01010081} ``` Notice how the contents of the custom `model()` function are very similar From 3bee923c93f9624ce3abf8d55173be66a7755545 Mon Sep 17 00:00:00 2001 From: Patrick Nguyen Date: Thu, 4 May 2017 08:48:51 -0800 Subject: [PATCH 10/43] Use $opt defined in the loop rather than raw string. Fixes #9651. Change: 155089799 --- configure | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configure b/configure index dce59586ab5..d7dde98292a 100755 --- a/configure +++ b/configure @@ -357,7 +357,7 @@ fi # Append CC optimization flags to bazel.rc for opt in $CC_OPT_FLAGS; do - write_to_bazelrc 'build:opt --cxxopt=$opt --copt=$opt' + write_to_bazelrc "build:opt --cxxopt=$opt --copt=$opt" done # Run the gen_git_source to create links where bazel can track dependencies for From 65044bc25981e4e060ad5c34d9a520a0561775c3 Mon Sep 17 00:00:00 2001 From: Yao Zhang Date: Thu, 4 May 2017 08:48:52 -0800 Subject: [PATCH 11/43] Add an option to not convert layout if GEMM is used internally in Conv2D, Conv2DBackpropFilter, and Conv2DBackpropInput, because in such cases, NHWC is usually faster than NCHW. The cost of enabling this option is the overhead of more non-cancellable layout conversion nodes. We added auto tuning to choose a better option by estimating the overhead using the number of added layout conversion nodes. Don't Convert the layout for Sum, because reduction along dimension 0, 2, 3 (in NCHW) is about 10x slower than along 0, 1, 2 (in NHWC). Change: 155089805 --- tensorflow/core/grappler/op_types.cc | 10 + tensorflow/core/grappler/op_types.h | 2 + tensorflow/core/grappler/optimizers/BUILD | 17 ++ .../grappler/optimizers/layout_optimizer.cc | 190 +++++++++++++++--- .../grappler/optimizers/layout_optimizer.h | 6 + .../optimizers/layout_optimizer_test.cc | 147 ++++++++++++++ 6 files changed, 341 insertions(+), 31 deletions(-) create mode 100644 tensorflow/core/grappler/optimizers/layout_optimizer_test.cc diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index bafbcc200c4..64bdd910773 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -18,6 +18,11 @@ limitations under the License. namespace tensorflow { namespace grappler { +bool IsConcat(const NodeDef& node) { + const auto op = node.op(); + return op == "Concat" || op == "ConcatV2"; +} + bool IsDequeueOp(const NodeDef& node) { static const std::set dequeue_ops = { "QueueDequeueManyV2", "QueueDequeueMany", "QueueDequeueV2", @@ -30,6 +35,11 @@ bool IsPlaceholder(const NodeDef& node) { return op == "Placeholder" || op == "PlaceholderV2"; } +bool IsTranspose(const NodeDef& node) { + const auto op = node.op(); + return op == "Transpose"; +} + bool IsVariable(const NodeDef& node) { const auto op = node.op(); return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" || diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 2f58835628d..4f2bb2bc056 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -21,8 +21,10 @@ limitations under the License. namespace tensorflow { namespace grappler { +bool IsConcat(const NodeDef& node); bool IsDequeueOp(const NodeDef& node); bool IsPlaceholder(const NodeDef& node); +bool IsTranspose(const NodeDef& node); bool IsVariable(const NodeDef& node); } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index e3b36c84123..5f30dfbaa26 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -205,11 +205,28 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:devices", "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/clusters:cluster", ], ) +cc_test( + name = "layout_optimizer_test", + srcs = ["layout_optimizer_test.cc"], + deps = [ + ":layout_optimizer", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", + ], +) + cc_library( name = "meta_optimizer", srcs = ["meta_optimizer.cc"], diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index 9570ec17d05..5fec89b6987 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/devices.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/optimizers/layout_optimizer.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/strings/numbers.h" @@ -68,8 +69,7 @@ std::set GetOpsFormatAgnostic() { "Slice", "SquaredDifference", "Squeeze", - "Sub", - "Sum"}; + "Sub"}; return ops_format_agnostic; } @@ -110,9 +110,9 @@ class NodeProcessor { } protected: - bool IsDimsN(NodeDef* node, int n) const { - if (node->attr().find("_output_shapes") != node->attr().end()) { - auto shape = node->attr().at("_output_shapes").list().shape(0); + bool IsDimsN(const NodeDef& node, int n) const { + if (node.attr().find("_output_shapes") != node.attr().end()) { + auto shape = node.attr().at("_output_shapes").list().shape(0); if (shape.dim_size() == n) { return true; } @@ -120,7 +120,7 @@ class NodeProcessor { return false; } - bool IsDimsFour(NodeDef* node) const { return IsDimsN(node, 4); } + bool IsDimsFour(const NodeDef& node) const { return IsDimsN(node, 4); } bool IsNHWC() const { if (node_->attr().find("data_format") != node_->attr().end()) { @@ -145,7 +145,7 @@ class NodeProcessor { } virtual bool ShouldProcess() const { - return IsNHWC() && IsDimsFour(node_) && HasOutputs(); + return IsNHWC() && IsDimsFour(*node_) && HasOutputs(); } void UpdateAttrDataFormat() { @@ -268,6 +268,8 @@ class NodeProcessor { for (const auto& output : outputs) { string node_name_NCHWToNHWC = strings::StrCat( kTransposeNCHWToNHWC, "-", node_->name(), "-", output->name()); + // TODO (yaozhang): handle the rare case where node A is connected to more + // than one input of node B. auto it = std::find_if(output->mutable_input()->begin(), output->mutable_input()->end(), [this](const string& input) { @@ -341,7 +343,7 @@ class BiasAddGradProcessor : public NodeProcessor { bool ShouldProcess() const override { auto input = node_map_->GetNode(node_->input(0)); if (input) { - if ((IsNHWC() && IsDimsFour(input)) || IsNodeNCHWToNHWC(input->name())) { + if ((IsNHWC() && IsDimsFour(*input)) || IsNodeNCHWToNHWC(input->name())) { return true; } } @@ -351,13 +353,89 @@ class BiasAddGradProcessor : public NodeProcessor { Status AddLayoutTransposeToOutputs() override { return Status::OK(); } }; -class Conv2DBackpropFilterProcessor : public NodeProcessor { +class Conv2DProcessor : public NodeProcessor { public: - Conv2DBackpropFilterProcessor(GraphDef* graph, NodeDef* node, - NodeMap* node_map) - : NodeProcessor(graph, node, node_map) {} + Conv2DProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map, + bool no_gemm) + : NodeProcessor(graph, node, node_map), no_gemm_(no_gemm) {} protected: + bool ShouldProcess() const override { + return IsNHWC() && IsDimsFour(*node_) && HasOutputs() && + (!IsGemmUsed() || no_gemm_); + } + + TensorShapeProto GetShape(const string& input_name) const { + string node_name; + int output_pos; + node_name = ParseNodeName(input_name, &output_pos); + NodeDef* node = node_map_->GetNode(node_name); + if (node->attr().find("_output_shapes") != node->attr().end()) { + return node->attr().at("_output_shapes").list().shape(output_pos); + } + TensorShapeProto shape; + return shape; + } + + bool IsStrideOne() const { + if (node_->attr().find("strides") != node_->attr().end()) { + auto list = node_->attr().at("strides").list(); + return list.i(1) == 1 && list.i(2) == 1; + } + return false; + } + + bool IsValidPadding() const { + if (node_->attr().find("padding") != node_->attr().end()) { + auto padding = node_->attr().at("padding").s(); + return padding == "VALID"; + } + return false; + } + + // The logic inside this function is based on the internal implementation of + // Conv2D, Conv2DBackpropInput, and Conv2DBackpropFilter ops, and thus + // needs to be updated accordingly if the internal implementation changes. + bool IsGemmUsed(const TensorShapeProto& filter_shape, + const TensorShapeProto& input_shape) const { + if (filter_shape.dim_size() == 4) { + if (filter_shape.dim(0).size() == 1 && filter_shape.dim(1).size() == 1 && + IsStrideOne()) { + return true; + } + } + if (input_shape.dim_size() == 4 && filter_shape.dim_size() == 4) { + if (input_shape.dim(1).size() == filter_shape.dim(0).size() == 1 && + input_shape.dim(2).size() == filter_shape.dim(1).size() && + IsValidPadding()) { + return true; + } + } + return false; + } + + virtual bool IsGemmUsed() const { + auto filter_shape = GetShape(node_->input(1)); + auto input_shape = GetShape(node_->input(0)); + return IsGemmUsed(filter_shape, input_shape); + } + + bool no_gemm_; +}; + +class Conv2DBackpropFilterProcessor : public Conv2DProcessor { + public: + Conv2DBackpropFilterProcessor(GraphDef* graph, NodeDef* node, + NodeMap* node_map, bool no_gemm) + : Conv2DProcessor(graph, node, node_map, no_gemm) {} + + protected: + bool IsGemmUsed() const override { + auto filter_shape = GetShape(node_->name()); + auto input_shape = GetShape(node_->input(0)); + return Conv2DProcessor::IsGemmUsed(filter_shape, input_shape); + } + std::vector GetInputPos() const override { std::vector input_pos = {0, 2}; return input_pos; @@ -370,17 +448,24 @@ class Conv2DBackpropFilterProcessor : public NodeProcessor { void UpdateAttrShape() override {} }; -class Conv2DBackpropInputProcessor : public NodeProcessor { +class Conv2DBackpropInputProcessor : public Conv2DProcessor { public: Conv2DBackpropInputProcessor(GraphDef* graph, NodeDef* node, - NodeMap* node_map) - : NodeProcessor(graph, node, node_map) {} + NodeMap* node_map, bool no_gemm) + : Conv2DProcessor(graph, node, node_map, no_gemm) {} protected: + bool IsGemmUsed() const override { + auto filter_shape = GetShape(node_->input(1)); + auto input_shape = GetShape(node_->name()); + return Conv2DProcessor::IsGemmUsed(filter_shape, input_shape); + } + std::vector GetInputPos() const override { std::vector input_pos = {2}; return input_pos; } + Status CustomizedProcessing() override { NodeDef* node = node_map_->GetNode(node_->input(0)); return UpdateAttrValue(node); @@ -418,7 +503,7 @@ class AgnosticNodeProcessor : public NodeProcessor { protected: bool ShouldProcess() const override { - return IsDimsFour(node_) && HasOutputs() && IsNodeAfterNCHWToNHWC(); + return IsDimsFour(*node_) && HasOutputs() && IsNodeAfterNCHWToNHWC(); } bool IsNodeAfterNCHWToNHWC() const { @@ -467,7 +552,7 @@ class BinaryOpProcessor : public AgnosticNodeProcessor { protected: bool ShouldProcess() const override { - return IsDimsFour(node_) && HasOutputs() && IsNodeAfterNCHWToNHWC() && + return IsDimsFour(*node_) && HasOutputs() && IsNodeAfterNCHWToNHWC() && (Is4DOperateWithND(4) || Is4DOperateWithScalar() || Is4DOperateWithVector()); } @@ -484,10 +569,10 @@ class BinaryOpProcessor : public AgnosticNodeProcessor { auto input0 = node_map_->GetNode(node_->input(0)); auto input1 = node_map_->GetNode(node_->input(1)); if (input0 && input1) { - return (IsDimsFour(input0) || IsNodeNCHWToNHWC(input0->name())) && + return (IsDimsFour(*input0) || IsNodeNCHWToNHWC(input0->name())) && ((n == 4) - ? (IsDimsFour(input1) || IsNodeNCHWToNHWC(input1->name())) - : IsDimsN(input1, n)); + ? (IsDimsFour(*input1) || IsNodeNCHWToNHWC(input1->name())) + : IsDimsN(*input1, n)); } return false; } @@ -571,7 +656,7 @@ class ConcatProcessor : public AgnosticNodeProcessor { protected: bool ShouldProcess() const override { - return IsDimsFour(node_) && HasOutputs() && IsNodeAfterNCHWToNHWC() && + return IsDimsFour(*node_) && HasOutputs() && IsNodeAfterNCHWToNHWC() && IsAlongDimC(); } @@ -739,7 +824,7 @@ class SqueezeProcessor : public AgnosticNodeProcessor { protected: bool ShouldProcess() const override { - return IsDimsN(node_, 2) && HasOutputs() && IsNodeAfterNCHWToNHWC() && + return IsDimsN(*node_, 2) && HasOutputs() && IsNodeAfterNCHWToNHWC() && IsInputConvertible() && IsAlongDimHW(); } @@ -790,7 +875,7 @@ class SumProcessor : public AgnosticNodeProcessor { bool ShouldProcess() const override { auto input0 = node_map_->GetNode(node_->input(0)); return HasOutputs() && IsNodeAfterNCHWToNHWC() && - (IsDimsFour(input0) || IsNodeNCHWToNHWC(input0->name())) && + (IsDimsFour(*input0) || IsNodeNCHWToNHWC(input0->name())) && IsAlongDimNHW(); } @@ -825,10 +910,21 @@ class SumProcessor : public AgnosticNodeProcessor { } }; +struct TuningConfig { + // If true, do not use the NHWC GEMM implementation. When filter size is + // one or filter size is equal to input image size, + // the NHWC implementation of Conv2D, Conv2DBackpropInput, and + // Conv2DBackpropFilter will use a specialized GEMM implementation, which is + // usually faster than the NCHW implementation. The downside is that this + // might result in more non-cancellable layout conversion nodes (implemented + // by the Tranpose op). + bool no_gemm; +}; + class DataLayoutOptimizer { public: - explicit DataLayoutOptimizer(GraphDef* graph) - : graph_(graph), node_map_(graph_) {} + explicit DataLayoutOptimizer(GraphDef* graph, TuningConfig config) + : graph_(graph), node_map_(graph_), config_(config) {} Status Optimize() { LOG(INFO) << "Number of nodes for original graph: " << graph_->node_size(); @@ -908,12 +1004,15 @@ class DataLayoutOptimizer { } else if (node->op().compare("BiasAddGrad") == 0) { node_processor.reset( new BiasAddGradProcessor(graph_, node, &node_map_)); + } else if (node->op().compare("Conv2D") == 0) { + node_processor.reset( + new Conv2DProcessor(graph_, node, &node_map_, config_.no_gemm)); } else if (node->op().compare("Conv2DBackpropFilter") == 0) { - node_processor.reset( - new Conv2DBackpropFilterProcessor(graph_, node, &node_map_)); + node_processor.reset(new Conv2DBackpropFilterProcessor( + graph_, node, &node_map_, config_.no_gemm)); } else if (node->op().compare("Conv2DBackpropInput") == 0) { - node_processor.reset( - new Conv2DBackpropInputProcessor(graph_, node, &node_map_)); + node_processor.reset(new Conv2DBackpropInputProcessor( + graph_, node, &node_map_, config_.no_gemm)); } else if (node->op().compare("FusedBatchNormGrad") == 0) { node_processor.reset( new FusedBatchNormGradProcessor(graph_, node, &node_map_)); @@ -1025,17 +1124,46 @@ class DataLayoutOptimizer { GraphDef* graph_; NodeMap node_map_; + TuningConfig config_; }; +int GetNumTranspose(const GraphDef& graph) { + int number = 0; + for (const auto& node : graph.node()) { + if (IsTranspose(node)) { + number++; + } + } + LOG(INFO) << "Number of Transpose nodes: " << number; + return number; +} + Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* output) { - if (GetNumAvailableGPUs() < 1) { + if (num_gpus_ == 0) { + num_gpus_ = GetNumAvailableGPUs(); + } + if (num_gpus_ < 1) { // LayoutOptimizer is currently only tuned for GPU. return Status::OK(); } + *output = item.graph; - DataLayoutOptimizer layout_optimizer(output); + TuningConfig config; + config.no_gemm = false; + DataLayoutOptimizer layout_optimizer(output, config); auto status = layout_optimizer.Optimize(); + + // This is based on an empirical observation that if the introduced Transpose + // nodes is more than 30, not using GEMM implementation would result in better + // performance. + if (status.ok() && GetNumTranspose(*output) > 30) { + *output = item.graph; + config.no_gemm = true; + DataLayoutOptimizer layout_optimizer(output, config); + status = layout_optimizer.Optimize(); + } + if (!status.ok()) { *output = item.graph; } diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.h b/tensorflow/core/grappler/optimizers/layout_optimizer.h index 66dec17a35c..1bd6f9544b1 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.h +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.h @@ -29,11 +29,17 @@ class LayoutOptimizer : public GraphOptimizer { string name() const override { return "layout"; }; + // This is for testing only. + void set_num_gpus(int num_gpus) { num_gpus_ = num_gpus; }; + Status Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* output) override; void Feedback(Cluster* cluster, const GrapplerItem& item, const GraphDef& optimize_output, double result) override; + + private: + int num_gpus_ = 0; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc new file mode 100644 index 00000000000..be38ca1a69e --- /dev/null +++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc @@ -0,0 +1,147 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/layout_optimizer.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +void AddOutputShape(Node* node, const TensorShape& shape) { + std::vector output_shapes; + TensorShapeProto shape_proto; + shape.AsProto(&shape_proto); + output_shapes.push_back(shape_proto); + node->AddAttr("_output_shapes", output_shapes); +} + +class LayoutOptimizerTest : public ::testing::Test { + protected: + Output SimpleConv(tensorflow::Scope* s, int input_size, int filter_size, + const string& padding) { + int batch_size = 128; + int input_height = input_size; + int input_width = input_size; + int input_depth = 3; + int filter_count = 2; + int stride = 1; + TensorShape input_shape( + {batch_size, input_height, input_width, input_depth}); + Tensor input_data(DT_FLOAT, input_shape); + test::FillIota(&input_data, 1.0f); + Output input = + ops::Const(s->WithOpName("Input"), Input::Initializer(input_data)); + AddOutputShape(input.node(), input_shape); + + TensorShape filter_shape( + {filter_size, filter_size, input_depth, filter_count}); + Tensor filter_data(DT_FLOAT, filter_shape); + test::FillIota(&filter_data, 1.0f); + Output filter = + ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data)); + AddOutputShape(filter.node(), filter_shape); + + Output conv = ops::Conv2D(s->WithOpName("Conv2D"), input, filter, + {1, stride, stride, 1}, padding); + AddOutputShape(conv.node(), input_shape); + return conv; + } +}; + +TEST_F(LayoutOptimizerTest, FilterSizeIsOne) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv(&s, 2, 1, "SAME"); + Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv}); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + optimizer.set_num_gpus(1); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + NodeMap node_map(&output); + EXPECT_FALSE( + node_map.GetNode("LayoutOptimizerTransposeNHWCToNCHW-Conv2D-Input")); +} + +TEST_F(LayoutOptimizerTest, FilterSizeNotOne) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv(&s, 2, 1, "SAME"); + Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv}); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + optimizer.set_num_gpus(1); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + NodeMap node_map(&output); + EXPECT_FALSE( + node_map.GetNode("LayoutOptimizerTransposeNHWCToNCHW-Conv2D-Input")); +} + +TEST_F(LayoutOptimizerTest, EqualSizeWithValidPadding) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv(&s, 2, 2, "VALID"); + Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv}); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + optimizer.set_num_gpus(1); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + NodeMap node_map(&output); + EXPECT_FALSE( + node_map.GetNode("LayoutOptimizerTransposeNHWCToNCHW-Conv2D-Input")); +} + +TEST_F(LayoutOptimizerTest, EqualSizeWithSamePadding) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv(&s, 2, 2, "SAME"); + Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv}); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + optimizer.set_num_gpus(1); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + NodeMap node_map(&output); + EXPECT_TRUE( + node_map.GetNode("LayoutOptimizerTransposeNHWCToNCHW-Conv2D-Input")); +} + +TEST_F(LayoutOptimizerTest, NotEqualSizeWithValidPadding) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv(&s, 2, 3, "VALID"); + Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv}); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + optimizer.set_num_gpus(1); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + NodeMap node_map(&output); + EXPECT_TRUE( + node_map.GetNode("LayoutOptimizerTransposeNHWCToNCHW-Conv2D-Input")); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow From 641c9824d4c08b5d7c6ae4c3f26b0607f0dea619 Mon Sep 17 00:00:00 2001 From: Mustafa Ispir Date: Thu, 4 May 2017 08:56:01 -0800 Subject: [PATCH 12/43] Make contrib real_valued_column cross compatible with core feature_column builders. Change: 155090692 --- tensorflow/contrib/layers/BUILD | 1 + .../layers/python/layers/feature_column.py | 28 +++++++++++++++++-- .../python/layers/feature_column_ops_test.py | 17 +++++++++++ 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index aba8eabe10c..fe661a56250 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -108,6 +108,7 @@ tf_custom_op_py_library( "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", + "//tensorflow/python/feature_column", "@six_archive//:six", ], ) diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index d6d5bf2294f..04fe2370d1d 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -136,8 +136,10 @@ from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.layers.python.ops import bucketization_op from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op from tensorflow.contrib.layers.python.ops import sparse_ops as contrib_sparse_ops +from tensorflow.python.feature_column import feature_column as fc_core from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor as sparse_tensor_py +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -1497,9 +1499,12 @@ def _real_valued_var_len_column(column_name, is_sparse) -class _RealValuedColumn(_FeatureColumn, collections.namedtuple( - "_RealValuedColumn", - ["column_name", "dimension", "default_value", "dtype", "normalizer"])): +class _RealValuedColumn( + _FeatureColumn, + fc_core._DenseColumn, # pylint: disable=protected-access + collections.namedtuple( + "_RealValuedColumn", + ["column_name", "dimension", "default_value", "dtype", "normalizer"])): """Represents a real valued feature column also known as continuous features. Instances of this class are immutable. The dictionary returned by InputBuilder @@ -1569,6 +1574,23 @@ class _RealValuedColumn(_FeatureColumn, collections.namedtuple( def _to_dense_tensor(self, input_tensor): return input_tensor + @property + def _variable_shape(self): + return tensor_shape.TensorShape((self.dimension)) + + def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): + del weight_collections + del trainable + return inputs.get(self) + + def _transform_feature(self, inputs): + return math_ops.to_float( + self._normalized_input_tensor(inputs.get(self.name))) + + @property + def _parse_example_config(self): + return self.config + def real_valued_column(column_name, dimension=1, diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py index 632836fee44..b2dad0162e9 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py @@ -27,6 +27,7 @@ from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.layers.python.layers import feature_column_ops from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 +from tensorflow.python.feature_column import feature_column as fc_core from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -610,6 +611,10 @@ class CreateInputLayersForDNNsTest(test.TestCase): [real_valued]) with self.test_session(): self.assertAllClose(output.eval(), features["price"].eval()) + # Verify cross compatibility: Core builder output should equal to contrib. + self.assertAllClose(output.eval(), + fc_core.make_input_layer(features, + [real_valued]).eval()) def testRealValuedColumnWithMultiDimensions(self): real_valued = feature_column.real_valued_column("price", 2) @@ -620,6 +625,10 @@ class CreateInputLayersForDNNsTest(test.TestCase): [real_valued]) with self.test_session(): self.assertAllClose(output.eval(), features["price"].eval()) + # Verify cross compatibility: Core builder output should equal to contrib. + self.assertAllClose(output.eval(), + fc_core.make_input_layer(features, + [real_valued]).eval()) def testRealValuedColumnSparse(self): sparse_real_valued = feature_column._real_valued_var_len_column( @@ -640,6 +649,10 @@ class CreateInputLayersForDNNsTest(test.TestCase): [real_valued]) with self.test_session(): self.assertAllClose(output.eval(), features["price"].eval() - 2) + # Verify cross compatibility: Core builder output should equal to contrib. + self.assertAllClose(output.eval(), + fc_core.make_input_layer(features, + [real_valued]).eval()) def testRealValuedColumnWithMultiDimensionsAndNormalizer(self): real_valued = feature_column.real_valued_column( @@ -651,6 +664,10 @@ class CreateInputLayersForDNNsTest(test.TestCase): [real_valued]) with self.test_session(): self.assertAllClose(output.eval(), features["price"].eval() - 2) + # Verify cross compatibility: Core builder output should equal to contrib. + self.assertAllClose(output.eval(), + fc_core.make_input_layer(features, + [real_valued]).eval()) def testBucketizedColumnWithNormalizerSucceedsForDNN(self): bucket = feature_column.bucketized_column( From a83997223feac44e9a94c87c7c13e2452a658b73 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 4 May 2017 09:06:58 -0800 Subject: [PATCH 13/43] Update build rules for contrib/opt. Change: 155092161 --- tensorflow/contrib/opt/BUILD | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index 2173e13b91f..1843b6968ea 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -29,7 +29,9 @@ py_library( "//tensorflow/python:platform", "//tensorflow/python:state_ops", "//tensorflow/python:training", + "//tensorflow/python:util", "//tensorflow/python:variables", + "//third_party/py/scipy", "@six_archive//:six", ], ) From 0cd6405a02fbac7b84679f2a631408dc641c6d4a Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Thu, 4 May 2017 09:22:18 -0800 Subject: [PATCH 14/43] RNNCells' trainable_weights and non_trainable_weights parameters return valid values. Change: 155094112 --- .../python/kernel_tests/core_rnn_cell_test.py | 63 +++++++++++++++---- .../rnn/python/ops/core_rnn_cell_impl.py | 48 -------------- tensorflow/contrib/rnn/python/ops/rnn_cell.py | 3 - tensorflow/python/ops/rnn_cell_impl.py | 40 ++++++++++-- 4 files changed, 85 insertions(+), 69 deletions(-) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index 15afac98237..f4589e3d9e1 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -74,7 +74,41 @@ class RNNCellTest(test.TestCase): "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) m = array_ops.zeros([1, 2]) - g, _ = core_rnn_cell_impl.BasicRNNCell(2)(x, m) + cell = core_rnn_cell_impl.BasicRNNCell(2) + g, _ = cell(x, m) + self.assertEqual( + ["root/basic_rnn_cell/%s:0" + % core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/basic_rnn_cell/%s:0" + % core_rnn_cell_impl._BIAS_VARIABLE_NAME], + [v.name for v in cell.trainable_variables]) + self.assertFalse(cell.non_trainable_variables) + sess.run([variables_lib.global_variables_initializer()]) + res = sess.run( + [g], {x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]])}) + self.assertEqual(res[0].shape, (1, 2)) + + def testBasicRNNCellNotTrainable(self): + with self.test_session() as sess: + def not_trainable_getter(getter, *args, **kwargs): + kwargs["trainable"] = False + return getter(*args, **kwargs) + + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5), + custom_getter=not_trainable_getter): + x = array_ops.zeros([1, 2]) + m = array_ops.zeros([1, 2]) + cell = core_rnn_cell_impl.BasicRNNCell(2) + g, _ = cell(x, m) + self.assertFalse(cell.trainable_variables) + self.assertEqual( + ["root/basic_rnn_cell/%s:0" + % core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/basic_rnn_cell/%s:0" + % core_rnn_cell_impl._BIAS_VARIABLE_NAME], + [v.name for v in cell.non_trainable_variables]) sess.run([variables_lib.global_variables_initializer()]) res = sess.run( [g], {x.name: np.array([[1., 1.]]), @@ -114,10 +148,23 @@ class RNNCellTest(test.TestCase): "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) m = array_ops.zeros([1, 8]) - g, out_m = core_rnn_cell_impl.MultiRNNCell( + cell = core_rnn_cell_impl.MultiRNNCell( [core_rnn_cell_impl.BasicLSTMCell( 2, state_is_tuple=False) for _ in range(2)], - state_is_tuple=False)(x, m) + state_is_tuple=False) + g, out_m = cell(x, m) + expected_variable_names = [ + "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" + % core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" + % core_rnn_cell_impl._BIAS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" + % core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" + % core_rnn_cell_impl._BIAS_VARIABLE_NAME] + self.assertEqual( + expected_variable_names, [v.name for v in cell.trainable_variables]) + self.assertFalse(cell.non_trainable_variables) sess.run([variables_lib.global_variables_initializer()]) res = sess.run( [g, out_m], @@ -125,15 +172,7 @@ class RNNCellTest(test.TestCase): m.name: 0.1 * np.ones([1, 8])}) self.assertEqual(len(res), 2) variables = variables_lib.global_variables() - self.assertEqual(4, len(variables)) - self.assertEquals(variables[0].op.name, - "root/multi_rnn_cell/cell_0/basic_lstm_cell/weights") - self.assertEquals(variables[1].op.name, - "root/multi_rnn_cell/cell_0/basic_lstm_cell/biases") - self.assertEquals(variables[2].op.name, - "root/multi_rnn_cell/cell_1/basic_lstm_cell/weights") - self.assertEquals(variables[3].op.name, - "root/multi_rnn_cell/cell_1/basic_lstm_cell/biases") + self.assertEqual(expected_variable_names, [v.name for v in variables]) # The numbers in results were not calculated, this is just a smoke test. self.assertAllClose(res[0], [[0.24024698, 0.24024698]]) expected_mem = np.array([[ diff --git a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py index 884b51926eb..eba2c0d2acb 100644 --- a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py +++ b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py @@ -27,7 +27,6 @@ from __future__ import division from __future__ import print_function import collections -import contextlib import hashlib import math import numbers @@ -57,53 +56,6 @@ _BIAS_VARIABLE_NAME = "biases" _WEIGHTS_VARIABLE_NAME = "weights" -@contextlib.contextmanager -def _checked_scope(cell, scope, reuse=None, **kwargs): - if reuse is not None: - kwargs["reuse"] = reuse - with vs.variable_scope(scope, **kwargs) as checking_scope: - scope_name = checking_scope.name - if hasattr(cell, "_scope"): - cell_scope = cell._scope # pylint: disable=protected-access - if cell_scope.name != checking_scope.name: - raise ValueError( - "Attempt to reuse RNNCell %s with a different variable scope than " - "its first use. First use of cell was with scope '%s', this " - "attempt is with scope '%s'. Please create a new instance of the " - "cell if you would like it to use a different set of weights. " - "If before you were using: MultiRNNCell([%s(...)] * num_layers), " - "change to: MultiRNNCell([%s(...) for _ in range(num_layers)]). " - "If before you were using the same cell instance as both the " - "forward and reverse cell of a bidirectional RNN, simply create " - "two instances (one for forward, one for reverse). " - "In May 2017, we will start transitioning this cell's behavior " - "to use existing stored weights, if any, when it is called " - "with scope=None (which can lead to silent model degradation, so " - "this error will remain until then.)" - % (cell, cell_scope.name, scope_name, type(cell).__name__, - type(cell).__name__)) - else: - weights_found = False - try: - with vs.variable_scope(checking_scope, reuse=True): - vs.get_variable(_WEIGHTS_VARIABLE_NAME) - weights_found = True - except ValueError: - pass - if weights_found and reuse is None: - raise ValueError( - "Attempt to have a second RNNCell use the weights of a variable " - "scope that already has weights: '%s'; and the cell was not " - "constructed as %s(..., reuse=True). " - "To share the weights of an RNNCell, simply " - "reuse it in your second calculation, or create a new one with " - "the argument reuse=True." % (scope_name, type(cell).__name__)) - - # Everything is OK. Update the cell's scope and yield it. - cell._scope = checking_scope # pylint: disable=protected-access - yield checking_scope - - class BasicRNNCell(RNNCell): """The most basic RNN cell.""" diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index acba77f0e13..ad23e532b10 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -39,9 +39,6 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest -_checked_scope = core_rnn_cell_impl._checked_scope # pylint: disable=protected-access - - def _get_concat_variable(name, shape, dtype, num_shards): """Get a sharded variable concatenated into one tensor.""" sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards) diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index 4810e97b367..c7ac742b5d9 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -28,6 +28,8 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.layers import base as base_layer from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.ops import variables as tf_variables from tensorflow.python.util import nest @@ -75,11 +77,13 @@ def _zero_state_tensors(state_size, batch_size, dtype): return zeros -class _RNNCell(base_layer.Layer): # pylint: disable=protected-access +class _RNNCell(base_layer.Layer): """Abstract object representing an RNN cell. - Every `RNNCell` must have the properties below and implement `__call__` with - the following signature. + Every `RNNCell` must have the properties below and implement `call` with + the signature `(output, next_state) = call(input, state)`. The optional + third input argument, `scope`, is allowed for backwards compatibility + purposes; but should be left off for new subclasses. This definition of cell differs from the definition used in the literature. In the literature, 'cell' refers to an object with a single scalar output. @@ -90,8 +94,9 @@ class _RNNCell(base_layer.Layer): # pylint: disable=protected-access This operation results in an output matrix with `self.output_size` columns. If `self.state_size` is an integer, this operation also results in a new state matrix with `self.state_size` columns. If `self.state_size` is a - tuple of integers, then it results in a tuple of `len(state_size)` state - matrices, each with a column size corresponding to values in `state_size`. + (possibly nested tuple of) TensorShape object(s), then it should return a + matching structure of Tensors having shape `[batch_size].concatenate(s)` + for each `s` in `self.batch_size`. """ def __call__(self, inputs, state, scope=None): @@ -112,7 +117,25 @@ class _RNNCell(base_layer.Layer): # pylint: disable=protected-access - New state: Either a single `2-D` tensor, or a tuple of tensors matching the arity and shapes of `state`. """ - return super(_RNNCell, self).__call__(inputs, state, scope=scope) + if scope is not None: + with vs.variable_scope(scope, + custom_getter=self._rnn_get_variable) as scope: + return super(_RNNCell, self).__call__(inputs, state, scope=scope) + else: + with vs.variable_scope(vs.get_variable_scope(), + custom_getter=self._rnn_get_variable): + return super(_RNNCell, self).__call__(inputs, state) + + def _rnn_get_variable(self, getter, *args, **kwargs): + variable = getter(*args, **kwargs) + trainable = (variable in tf_variables.trainable_variables() or + (isinstance(variable, tf_variables.PartitionedVariable) and + list(variable)[0] in tf_variables.trainable_variables())) + if trainable and variable not in self._trainable_weights: + self._trainable_weights.append(variable) + elif not trainable and variable not in self._non_trainable_weights: + self._non_trainable_weights.append(variable) + return variable @property def state_size(self): @@ -128,6 +151,11 @@ class _RNNCell(base_layer.Layer): # pylint: disable=protected-access """Integer or TensorShape: size of outputs produced by this cell.""" raise NotImplementedError("Abstract method") + def build(self, _): + # This tells the parent Layer object that it's OK to call + # self.add_variable() inside the call() method. + pass + def zero_state(self, batch_size, dtype): """Return zero-filled state tensor(s). From 94e7325bbf2d07e0e0044f37c6249b4b3173d9b4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 4 May 2017 09:22:47 -0800 Subject: [PATCH 15/43] Making the pip smoke test warning clearer. Change: 155094164 --- tensorflow/tools/pip_package/pip_smoke_test.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py index 0438ce68469..fa61a19b39f 100644 --- a/tensorflow/tools/pip_package/pip_smoke_test.py +++ b/tensorflow/tools/pip_package/pip_smoke_test.py @@ -122,7 +122,10 @@ def main(): affected_tests_list = affected_tests.split("\n")[:-2] print("\n".join(affected_tests_list)) - raise RuntimeError("One or more dependencies are not in the pip package.") + raise RuntimeError("""One or more dependencies are not in the pip package. +Please either blacklist the dependencies in +tensorflow/tensorflow/tensorflow/tools/pip_package/pip_smoke_test.py +or add them to tensorflow/tensorflow/tensorflow/tools/pip_package/BUILD.""") else: print("TEST PASSED") From 61048d872698cbfb5462e850823c97c7f733b35d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 4 May 2017 09:42:53 -0800 Subject: [PATCH 16/43] Automated rollback of change 154868460 Change: 155096835 --- tensorflow/core/kernels/crop_and_resize_op.cc | 543 ++++++++---------- tensorflow/core/kernels/crop_and_resize_op.h | 8 +- .../core/kernels/crop_and_resize_op_gpu.cu.cc | 2 +- .../core/kernels/crop_and_resize_op_test.cc | 6 +- 4 files changed, 240 insertions(+), 319 deletions(-) diff --git a/tensorflow/core/kernels/crop_and_resize_op.cc b/tensorflow/core/kernels/crop_and_resize_op.cc index 1c7afcf8663..746fe63e2a0 100644 --- a/tensorflow/core/kernels/crop_and_resize_op.cc +++ b/tensorflow/core/kernels/crop_and_resize_op.cc @@ -19,9 +19,6 @@ limitations under the License. #include "tensorflow/core/kernels/crop_and_resize_op.h" -#include -#include - #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -29,13 +26,10 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/bounds_check.h" -#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/types.h" #if GOOGLE_CUDA -#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/platform/stream_executor.h" #endif // GOOGLE_CUDA @@ -43,67 +37,41 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -using Callback = std::function; -namespace { - -static inline Status ParseAndCheckBoxSizes(const Tensor& boxes, - const Tensor& box_index, - int* num_boxes) { - if (boxes.NumElements() == 0 && box_index.NumElements() == 0) { +static inline void ParseAndCheckBoxSizes(OpKernelContext* context, + const Tensor& boxes, + const Tensor& box_ind, + int* num_boxes) { + if (boxes.NumElements() == 0 && box_ind.NumElements() == 0) { *num_boxes = 0; - return Status::OK(); + return; } // The shape of 'boxes' is [num_boxes, 4]. - if (boxes.dims() != 2) { - return errors::InvalidArgument("boxes must be 2-D", - boxes.shape().DebugString()); - } + OP_REQUIRES(context, boxes.dims() == 2, + errors::InvalidArgument("boxes must be 2-D", + boxes.shape().DebugString())); *num_boxes = boxes.dim_size(0); - if (boxes.dim_size(1) != 4) { - return errors::InvalidArgument("boxes must have 4 columns"); - } - // The shape of 'box_index' is [num_boxes]. - if (box_index.dims() != 1) { - return errors::InvalidArgument("box_index must be 1-D", - box_index.shape().DebugString()); - } - if (box_index.dim_size(0) != *num_boxes) { - return errors::InvalidArgument("box_index has incompatible shape"); - } - return Status::OK(); + OP_REQUIRES(context, boxes.dim_size(1) == 4, + errors::InvalidArgument("boxes must have 4 columns")); + + // The shape of 'box_ind' is [num_boxes]. + OP_REQUIRES(context, box_ind.dims() == 1, + errors::InvalidArgument("box_ind must be 1-D", + box_ind.shape().DebugString())); + OP_REQUIRES(context, box_ind.dim_size(0) == *num_boxes, + errors::InvalidArgument("box_ind has incompatible shape")); } -// Conditionally calls the compute callback if all values in box_index are in -// [0, batch_size) then calls done. +// Verifies that all values in box_ind are in [0, batch). template -inline void RunIfBoxIndexIsValid( - OpKernelContext* context, typename TTypes::ConstTensor box_index, - int batch_size, Callback compute, Callback done); - -// Specialization of CheckValidBoxIndex for a CPUDevice. -template <> -inline void RunIfBoxIndexIsValid( - OpKernelContext* context, typename TTypes::ConstTensor box_index, - int batch_size, Callback compute, Callback done) { - const int num_boxes = box_index.dimension(0); - for (int b = 0; b < num_boxes; ++b) { - OP_REQUIRES_ASYNC( - context, FastBoundsCheck(box_index(b), batch_size), - errors::OutOfRange("box_index has values outside [0, batch_size)"), - done); - } - compute(); - done(); -} - -} // namespace +inline void CheckValidBoxInd( + OpKernelContext* context, + typename TTypes::ConstTensor box_ind_data, int batch); template -class CropAndResizeOp : public AsyncOpKernel { +class CropAndResizeOp : public OpKernel { public: - explicit CropAndResizeOp(OpKernelConstruction* context) - : AsyncOpKernel(context) { + explicit CropAndResizeOp(OpKernelConstruction* context) : OpKernel(context) { string method; OP_REQUIRES_OK(context, context->GetAttr("method", &method)); OP_REQUIRES(context, method == "bilinear", @@ -112,77 +80,69 @@ class CropAndResizeOp : public AsyncOpKernel { &extrapolation_value_)); } - void ComputeAsync(OpKernelContext* context, DoneCallback done) override { - // The shape of 'image' is [batch_size, image_height, image_width, - // channels]. + void Compute(OpKernelContext* context) override { + // The shape of 'image' is [batch, image_height, image_width, channels]. const Tensor& image = context->input(0); - // The shape of 'boxes' is [num_boxes, 4]. - const Tensor& boxes = context->input(1); - // The shape of 'box_index' is [num_boxes]. - const Tensor& box_index = context->input(2); - // The shape of 'crop_size' is [2]. - const Tensor& crop_size = context->input(3); + OP_REQUIRES(context, image.dims() == 4, + errors::InvalidArgument("input image must be 4-D", + image.shape().DebugString())); - // Validate inputs dimensions. - OP_REQUIRES_ASYNC(context, image.dims() == 4, - errors::InvalidArgument("input image must be 4-D", - image.shape().DebugString()), - done); - const int batch_size = image.dim_size(0); + const int batch = image.dim_size(0); const int image_height = image.dim_size(1); const int image_width = image.dim_size(2); const int depth = image.dim_size(3); - OP_REQUIRES_ASYNC( - context, image_height > 0 && image_width > 0, - errors::InvalidArgument("image dimensions must be positive"), done); + OP_REQUIRES(context, image_height > 0 && image_width > 0, + errors::InvalidArgument("image dimensions must be positive")); + + // The shape of 'boxes' is [num_boxes, 4]. + const Tensor& boxes = context->input(1); + + // The shape of 'box_ind' is [num_boxes]. + const Tensor& box_ind = context->input(2); + int num_boxes = 0; - OP_REQUIRES_OK_ASYNC( - context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done); + ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes); - OP_REQUIRES_ASYNC(context, crop_size.dims() == 1, - errors::InvalidArgument("crop_size must be 1-D", - crop_size.shape().DebugString()), - done); - OP_REQUIRES_ASYNC( - context, crop_size.dim_size(0) == 2, - errors::InvalidArgument("crop_size must have two elements", - crop_size.shape().DebugString()), - done); + // The shape of 'crop_size' is [2]. + const Tensor& crop_size = context->input(3); + + OP_REQUIRES(context, crop_size.dims() == 1, + errors::InvalidArgument("crop_size must be 1-D", + crop_size.shape().DebugString())); + OP_REQUIRES(context, crop_size.dim_size(0) == 2, + errors::InvalidArgument("crop_size must have two elements", + crop_size.shape().DebugString())); - // Copy and validate crop sizes. auto crop_size_vec = crop_size.vec(); const int crop_height = internal::SubtleMustCopy(crop_size_vec(0)); const int crop_width = internal::SubtleMustCopy(crop_size_vec(1)); - OP_REQUIRES_ASYNC( - context, crop_height > 0 && crop_width > 0, - errors::InvalidArgument("crop dimensions must be positive"), done); + OP_REQUIRES(context, crop_height > 0 && crop_width > 0, + errors::InvalidArgument("crop dimensions must be positive")); // Allocate output tensor. Tensor* output = nullptr; - OP_REQUIRES_OK_ASYNC( + OP_REQUIRES_OK( context, context->allocate_output( 0, TensorShape({num_boxes, crop_height, crop_width, depth}), - &output), - done); + &output)); - auto compute_callback = [this, context, output]() { - const Tensor& image = context->input(0); - const Tensor& boxes = context->input(1); - const Tensor& box_index = context->input(2); - const bool status = functor::CropAndResize()( - context->eigen_device(), image.tensor(), - boxes.tensor(), box_index.tensor(), - extrapolation_value_, output->tensor()); - if (!status) { - context->SetStatus( - errors::Internal("Failed launch CropAndResizeKernel.")); - } - }; + typename TTypes::ConstTensor image_data = image.tensor(); + typename TTypes::ConstTensor boxes_data = + boxes.tensor(); + typename TTypes::ConstTensor box_ind_data = + box_ind.tensor(); + typename TTypes::Tensor crops_data = output->tensor(); - RunIfBoxIndexIsValid(context, box_index.tensor(), - batch_size, std::move(compute_callback), - std::move(done)); + CheckValidBoxInd(context, box_ind_data, batch); + + bool status = functor::CropAndResize()( + context->eigen_device(), image_data, boxes_data, box_ind_data, + extrapolation_value_, crops_data); + if (!status) { + context->SetStatus( + errors::Internal("Failed launch CropAndResizeKernel.")); + } } private: @@ -195,10 +155,10 @@ template struct CropAndResize { bool operator()(const CPUDevice& d, typename TTypes::ConstTensor image, typename TTypes::ConstTensor boxes, - typename TTypes::ConstTensor box_index, + typename TTypes::ConstTensor box_ind, float extrapolation_value, typename TTypes::Tensor crops) { - const int batch_size = image.dimension(0); + const int batch = image.dimension(0); const int image_height = image.dimension(1); const int image_width = image.dimension(2); @@ -213,8 +173,8 @@ struct CropAndResize { const float y2 = boxes(b, 2); const float x2 = boxes(b, 3); - const int32 b_in = box_index(b); - if (!FastBoundsCheck(b_in, batch_size)) { + const int32 b_in = box_ind(b); + if (b_in < 0 || b_in >= batch) { continue; } @@ -275,94 +235,89 @@ struct CropAndResize { return true; } }; - } // namespace functor template -class CropAndResizeGradImageOp : public AsyncOpKernel { +class CropAndResizeGradImageOp : public OpKernel { public: explicit CropAndResizeGradImageOp(OpKernelConstruction* context) - : AsyncOpKernel(context) { + : OpKernel(context) { string method; OP_REQUIRES_OK(context, context->GetAttr("method", &method)); OP_REQUIRES(context, method == "bilinear", errors::InvalidArgument("method must be 'bilinear'", method)); } - void ComputeAsync(OpKernelContext* context, DoneCallback done) override { + void Compute(OpKernelContext* context) override { // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth]. const Tensor& grads = context->input(0); - // The shape of 'boxes' is [num_boxes, 4]. - const Tensor& boxes = context->input(1); - // The shape of 'box_index' is [num_boxes]. - const Tensor& box_index = context->input(2); - // The shape of 'image_size' is [4]. - const Tensor& image_size = context->input(3); - // Validate input shapes. - OP_REQUIRES_ASYNC(context, grads.dims() == 4, - errors::InvalidArgument("grads image must be 4-D", - grads.shape().DebugString()), - done); + OP_REQUIRES(context, grads.dims() == 4, + errors::InvalidArgument("grads image must be 4-D", + grads.shape().DebugString())); const int crop_height = grads.dim_size(1); const int crop_width = grads.dim_size(2); - OP_REQUIRES_ASYNC( - context, crop_height > 0 && crop_width > 0, - errors::InvalidArgument("grads dimensions must be positive"), done); - int num_boxes = 0; - OP_REQUIRES_OK_ASYNC( - context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done); - OP_REQUIRES_ASYNC( - context, grads.dim_size(0) == num_boxes, - errors::InvalidArgument("boxes and grads have incompatible shape"), - done); + OP_REQUIRES(context, crop_height > 0 && crop_width > 0, + errors::InvalidArgument("grads dimensions must be positive")); + + // The shape of 'boxes' is [num_boxes, 4]. + const Tensor& boxes = context->input(1); + + // The shape of 'box_ind' is [num_boxes]. + const Tensor& box_ind = context->input(2); + + int num_boxes = 0; + ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes); + + OP_REQUIRES( + context, grads.dim_size(0) == num_boxes, + errors::InvalidArgument("boxes and grads have incompatible shape")); + + // The shape of 'image_size' is [4]. + const Tensor& image_size = context->input(3); + OP_REQUIRES(context, image_size.dims() == 1, + errors::InvalidArgument("image_size must be 1-D", + image_size.shape().DebugString())); + OP_REQUIRES(context, image_size.dim_size(0) == 4, + errors::InvalidArgument("image_size must have 4 elements", + image_size.shape().DebugString())); - OP_REQUIRES_ASYNC(context, image_size.dims() == 1, - errors::InvalidArgument("image_size must be 1-D", - image_size.shape().DebugString()), - done); - OP_REQUIRES_ASYNC(context, image_size.dim_size(0) == 4, - errors::InvalidArgument("image_size must have 4 elements", - image_size.shape().DebugString()), - done); auto image_size_vec = image_size.vec(); - const int batch_size = internal::SubtleMustCopy(image_size_vec(0)); + const int batch = internal::SubtleMustCopy(image_size_vec(0)); const int image_height = internal::SubtleMustCopy(image_size_vec(1)); const int image_width = internal::SubtleMustCopy(image_size_vec(2)); const int depth = internal::SubtleMustCopy(image_size_vec(3)); - OP_REQUIRES_ASYNC( - context, image_height > 0 && image_width > 0, - errors::InvalidArgument("image dimensions must be positive"), done); - OP_REQUIRES_ASYNC( + + OP_REQUIRES(context, image_height > 0 && image_width > 0, + errors::InvalidArgument("image dimensions must be positive")); + OP_REQUIRES( context, grads.dim_size(3) == depth, - errors::InvalidArgument("image_size and grads are incompatible"), done); + errors::InvalidArgument("image_size and grads are incompatible")); // Allocate output tensor. Tensor* output = nullptr; - OP_REQUIRES_OK_ASYNC( - context, - context->allocate_output( - 0, TensorShape({batch_size, image_height, image_width, depth}), - &output), - done); + OP_REQUIRES_OK( + context, context->allocate_output( + 0, TensorShape({batch, image_height, image_width, depth}), + &output)); - auto compute_callback = [context, output]() { - const Tensor& grads = context->input(0); - const Tensor& boxes = context->input(1); - const Tensor& box_index = context->input(2); - const bool status = functor::CropAndResizeBackpropImage()( - context->eigen_device(), grads.tensor(), - boxes.tensor(), box_index.tensor(), - output->tensor()); - if (!status) { - context->SetStatus(errors::Internal( - "Failed launch CropAndResizeBackpropImage kernel.")); - } - }; + typename TTypes::ConstTensor grads_data = + grads.tensor(); + typename TTypes::ConstTensor boxes_data = + boxes.tensor(); + typename TTypes::ConstTensor box_ind_data = + box_ind.tensor(); + typename TTypes::Tensor output_data = output->tensor(); - RunIfBoxIndexIsValid(context, box_index.tensor(), - batch_size, std::move(compute_callback), - std::move(done)); + CheckValidBoxInd(context, box_ind_data, batch); + + bool status = functor::CropAndResizeBackpropImage()( + context->eigen_device(), grads_data, boxes_data, box_ind_data, + output_data); + if (!status) { + context->SetStatus( + errors::Internal("Failed launch CropAndResizeBackpropImageKernel.")); + } } }; @@ -373,9 +328,9 @@ struct CropAndResizeBackpropImage { bool operator()(const CPUDevice& d, typename TTypes::ConstTensor grads, typename TTypes::ConstTensor boxes, - typename TTypes::ConstTensor box_index, + typename TTypes::ConstTensor box_ind, typename TTypes::Tensor grads_image) { - const int batch_size = grads_image.dimension(0); + const int batch = grads_image.dimension(0); const int image_height = grads_image.dimension(1); const int image_width = grads_image.dimension(2); @@ -392,8 +347,8 @@ struct CropAndResizeBackpropImage { const float y2 = boxes(b, 2); const float x2 = boxes(b, 3); - const int32 b_in = box_index(b); - if (!FastBoundsCheck(b_in, batch_size)) { + const int32 b_in = box_ind(b); + if (b_in < 0 || b_in >= batch) { continue; } @@ -444,90 +399,83 @@ struct CropAndResizeBackpropImage { return true; } }; - } // namespace functor template -class CropAndResizeGradBoxesOp : public AsyncOpKernel { +class CropAndResizeGradBoxesOp : public OpKernel { public: explicit CropAndResizeGradBoxesOp(OpKernelConstruction* context) - : AsyncOpKernel(context) { + : OpKernel(context) { string method; OP_REQUIRES_OK(context, context->GetAttr("method", &method)); OP_REQUIRES(context, method == "bilinear", errors::InvalidArgument("method must be 'bilinear'", method)); } - void ComputeAsync(OpKernelContext* context, DoneCallback done) override { + void Compute(OpKernelContext* context) override { // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth]. const Tensor& grads = context->input(0); - // The shape of 'boxes' is [num_boxes, 4]. - const Tensor& boxes = context->input(2); - // The shape of 'box_index' is [num_boxes]. - const Tensor& box_index = context->input(3); - // The shape of 'image' is [batch_size, image_height, image_width, depth]. - const Tensor& image = context->input(1); - // Validate input shapes. - OP_REQUIRES_ASYNC(context, grads.dims() == 4, - errors::InvalidArgument("grads image must be 4-D", - grads.shape().DebugString()), - done); + OP_REQUIRES(context, grads.dims() == 4, + errors::InvalidArgument("grads image must be 4-D", + grads.shape().DebugString())); + const int crop_height = grads.dim_size(1); const int crop_width = grads.dim_size(2); const int depth = grads.dim_size(3); - OP_REQUIRES_ASYNC( - context, crop_height > 0 && crop_width > 0, - errors::InvalidArgument("grads dimensions must be positive"), done); + OP_REQUIRES(context, crop_height > 0 && crop_width > 0, + errors::InvalidArgument("grads dimensions must be positive")); - OP_REQUIRES_ASYNC(context, image.dims() == 4, - errors::InvalidArgument("input image must be 4-D", - image.shape().DebugString()), - done); - const int batch_size = image.dim_size(0); + // The shape of 'image' is [batch, image_height, image_width, depth]. + const Tensor& image = context->input(1); + OP_REQUIRES(context, image.dims() == 4, + errors::InvalidArgument("input image must be 4-D", + image.shape().DebugString())); + + const int batch = image.dim_size(0); const int image_height = image.dim_size(1); const int image_width = image.dim_size(2); - OP_REQUIRES_ASYNC( - context, image_height > 0 && image_width > 0, - errors::InvalidArgument("image dimensions must be positive"), done); - OP_REQUIRES_ASYNC(context, image.dim_size(3) == depth, - errors::InvalidArgument("image, grads depth differ"), - done); + OP_REQUIRES(context, image_height > 0 && image_width > 0, + errors::InvalidArgument("image dimensions must be positive")); + OP_REQUIRES(context, image.dim_size(3) == depth, + errors::InvalidArgument("image, grads depth differ")); + + // The shape of 'boxes' is [num_boxes, 4]. + const Tensor& boxes = context->input(2); + + // The shape of 'box_ind' is [num_boxes]. + const Tensor& box_ind = context->input(3); int num_boxes = 0; - OP_REQUIRES_OK_ASYNC( - context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done); + ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes); - OP_REQUIRES_ASYNC( + OP_REQUIRES( context, grads.dim_size(0) == num_boxes, - errors::InvalidArgument("boxes and grads have incompatible shape"), - done); + errors::InvalidArgument("boxes and grads have incompatible shape")); // Allocate output tensor. Tensor* output = nullptr; - OP_REQUIRES_OK_ASYNC( - context, - context->allocate_output(0, TensorShape({num_boxes, 4}), &output), - done); + OP_REQUIRES_OK(context, context->allocate_output( + 0, TensorShape({num_boxes, 4}), &output)); - auto compute_callback = [context, output]() { - const Tensor& grads = context->input(0); - const Tensor& image = context->input(1); - const Tensor& boxes = context->input(2); - const Tensor& box_index = context->input(3); - const bool status = functor::CropAndResizeBackpropBoxes()( - context->eigen_device(), grads.tensor(), - image.tensor(), boxes.tensor(), - box_index.tensor(), output->tensor()); - if (!status) { - context->SetStatus(errors::Internal( - "Failed launch CropAndResizeBackpropBoxes kernel.")); - } - }; + typename TTypes::ConstTensor grads_data = + grads.tensor(); + typename TTypes::ConstTensor image_data = image.tensor(); + typename TTypes::ConstTensor boxes_data = + boxes.tensor(); + typename TTypes::ConstTensor box_ind_data = + box_ind.tensor(); + typename TTypes::Tensor output_data = output->tensor(); - RunIfBoxIndexIsValid(context, box_index.tensor(), - batch_size, std::move(compute_callback), - std::move(done)); + CheckValidBoxInd(context, box_ind_data, batch); + + bool status = functor::CropAndResizeBackpropBoxes()( + context->eigen_device(), grads_data, image_data, boxes_data, + box_ind_data, output_data); + if (!status) { + context->SetStatus( + errors::Internal("Failed launch CropAndResizeBackpropBoxesKernel.")); + } } }; @@ -539,9 +487,9 @@ struct CropAndResizeBackpropBoxes { typename TTypes::ConstTensor grads, typename TTypes::ConstTensor image, typename TTypes::ConstTensor boxes, - typename TTypes::ConstTensor box_index, + typename TTypes::ConstTensor box_ind, typename TTypes::Tensor grads_boxes) { - const int batch_size = image.dimension(0); + const int batch = image.dimension(0); const int image_height = image.dimension(1); const int image_width = image.dimension(2); @@ -558,8 +506,8 @@ struct CropAndResizeBackpropBoxes { const float y2 = boxes(b, 2); const float x2 = boxes(b, 3); - const int32 b_in = box_index(b); - if (!FastBoundsCheck(b_in, batch_size)) { + const int32 b_in = box_ind(b); + if (b_in < 0 || b_in >= batch) { continue; } @@ -641,19 +589,30 @@ struct CropAndResizeBackpropBoxes { return true; } }; - } // namespace functor -#define REGISTER_KERNEL(T) \ - REGISTER_KERNEL_BUILDER(Name("CropAndResize") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .HostMemory("crop_size"), \ - CropAndResizeOp); \ - \ - REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T"), \ +// Specialization of CheckValidBoxInd for a CPUDevice. +template <> +inline void CheckValidBoxInd( + OpKernelContext* context, typename TTypes::ConstTensor box_ind, + int batch) { + const int num_boxes = box_ind.dimension(0); + for (int b = 0; b < num_boxes; ++b) { + OP_REQUIRES(context, box_ind(b) >= 0 && box_ind(b) < batch, + errors::OutOfRange("box_ind has values outside [0, batch)")); + } +} + +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("CropAndResize") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .HostMemory("crop_size"), \ + CropAndResizeOp); \ + \ + REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ CropAndResizeGradBoxesOp); TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL); @@ -675,86 +634,50 @@ TF_CALL_double(REGISTER_KERNEL); #if GOOGLE_CUDA -// Forward declaration of the CheckValidBoxIndexHelper specialization for GPU. +// Forward declaration of the CheckValidBoxIndHelper specialization for GPU. namespace functor { template <> -void CheckValidBoxIndexHelper::operator()( - const GPUDevice& d, typename TTypes::ConstTensor box_index, - int batch_size, typename TTypes::Tensor isvalid); -extern template struct CheckValidBoxIndexHelper; +void CheckValidBoxIndHelper::operator()( + const GPUDevice& d, typename TTypes::ConstTensor box_ind, + int batch, typename TTypes::Tensor isvalid); +extern template struct CheckValidBoxIndHelper; } // namespace functor -namespace { - -// Specialization of CheckValidBoxIndex for a GPUDevice. +// Specialization of CheckValidBoxInd for a GPUDevice. template <> -inline void RunIfBoxIndexIsValid( - OpKernelContext* context, typename TTypes::ConstTensor box_index, - int batch_size, Callback compute, Callback done) { - const int num_boxes = box_index.dimension(0); +inline void CheckValidBoxInd( + OpKernelContext* context, typename TTypes::ConstTensor box_ind, + int batch) { + const int num_boxes = box_ind.dimension(0); if (num_boxes == 0) { - compute(); - done(); return; } + Tensor isvalid_tensor; + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum::value, + TensorShape({}), &isvalid_tensor)); - Tensor isvalid_dev_tensor; - OP_REQUIRES_OK_ASYNC( - context, - context->allocate_temp(DataTypeToEnum::value, TensorShape({}), - &isvalid_dev_tensor), - done); - typename TTypes::Tensor isvalid_dev = - isvalid_dev_tensor.tensor(); + typename TTypes::Tensor isvalid = isvalid_tensor.tensor(); - // Run the actual box check on the device. - functor::CheckValidBoxIndexHelper()( - context->eigen_device(), box_index, batch_size, isvalid_dev); + functor::CheckValidBoxIndHelper()( + context->eigen_device(), box_ind, batch, isvalid); - // Copy the result back to the host. auto* stream = context->op_device_context()->stream(); - OP_REQUIRES_ASYNC(context, stream, - errors::Internal("No GPU stream available."), done); - Tensor isvalid_host_tensor; - // Use pinned host memory on the host to avoid unnecessary - // synchronization. - AllocatorAttributes alloc_attr; - alloc_attr.set_on_host(true); - alloc_attr.set_gpu_compatible(true); - OP_REQUIRES_OK_ASYNC( - context, - context->allocate_temp(DataTypeToEnum::value, TensorShape({}), - &isvalid_host_tensor, alloc_attr), - done); - typename TTypes::Tensor isvalid_host = - isvalid_host_tensor.tensor(); + OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); - perftools::gputools::DeviceMemoryBase wrapped(isvalid_dev.data(), - sizeof(bool)); - const bool status = stream - ->ThenMemcpy(isvalid_host.data() /* destination */, - wrapped /* source */, sizeof(bool)) - .ok(); - OP_REQUIRES_ASYNC( - context, status, - errors::Internal("Failed to launch copy of isvalid from device to host."), - done); + bool isvalid_host = false; + perftools::gputools::DeviceMemoryBase isvalid_gpu(isvalid.data(), + sizeof(bool)); + stream->ThenMemcpy(&isvalid_host, isvalid_gpu, sizeof(bool)); + stream->BlockHostUntilDone(); - auto wrapped_callback = [context, isvalid_host, compute, done]() { - OP_REQUIRES_ASYNC( - context, isvalid_host(), - errors::OutOfRange("box_index has values outside [0, batch_size)"), - done); - compute(); - done(); - }; + OP_REQUIRES(context, stream->ok(), + errors::Internal("cudaMemcpy from device to host failed")); - context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute( - stream, wrapped_callback); + OP_REQUIRES(context, isvalid_host, + errors::OutOfRange("box_ind has values outside [0, batch)")); } -} // namespace - #define REGISTER_KERNEL(T) \ REGISTER_KERNEL_BUILDER(Name("CropAndResize") \ .Device(DEVICE_GPU) \ diff --git a/tensorflow/core/kernels/crop_and_resize_op.h b/tensorflow/core/kernels/crop_and_resize_op.h index 460dbad22b4..22df1bdd56b 100644 --- a/tensorflow/core/kernels/crop_and_resize_op.h +++ b/tensorflow/core/kernels/crop_and_resize_op.h @@ -53,12 +53,12 @@ struct CropAndResizeBackpropBoxes { }; template -struct CheckValidBoxIndexHelper { - // Checks if all values in box_index are in [0, batch). +struct CheckValidBoxIndHelper { + // Checks if all values in box_ind are in [0, batch). void operator()(const Device& d, - typename TTypes::ConstTensor box_index, int batch, + typename TTypes::ConstTensor box_ind, int batch, typename TTypes::Tensor isvalid) { - isvalid.device(d) = ((box_index >= 0) && (box_index < batch)).all(); + isvalid.device(d) = ((box_ind >= 0) && (box_ind < batch)).all(); } }; diff --git a/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc b/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc index c1235fda892..254475db465 100644 --- a/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc +++ b/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc @@ -440,7 +440,7 @@ TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS); #undef DEFINE_GPU_SPECS -template struct CheckValidBoxIndexHelper; +template struct CheckValidBoxIndHelper; } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/crop_and_resize_op_test.cc b/tensorflow/core/kernels/crop_and_resize_op_test.cc index d6139dae966..3a7f180598e 100644 --- a/tensorflow/core/kernels/crop_and_resize_op_test.cc +++ b/tensorflow/core/kernels/crop_and_resize_op_test.cc @@ -251,7 +251,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndexShape) { Status s = RunOpKernel(); ASSERT_FALSE(s.ok()); EXPECT_TRUE( - StringPiece(s.ToString()).contains("box_index has incompatible shape")) + StringPiece(s.ToString()).contains("box_ind has incompatible shape")) << s; } @@ -264,10 +264,8 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndex) { Status s = RunOpKernel(); ASSERT_FALSE(s.ok()); EXPECT_TRUE(StringPiece(s.ToString()) - .contains("box_index has values outside [0, batch_size)")) + .contains("box_ind has values outside [0, batch)")) << s; } -// TODO(zhengxq, rmlarsen): Add a benchmark. - } // namespace tensorflow From 2b75d151c92660e9aa8d378feda20059a1592e5c Mon Sep 17 00:00:00 2001 From: Yao Zhang Date: Thu, 4 May 2017 09:55:12 -0800 Subject: [PATCH 17/43] Fix an error. Change: 155098574 --- tensorflow/core/grappler/optimizers/layout_optimizer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index 5fec89b6987..e37c4a5b36a 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -405,7 +405,7 @@ class Conv2DProcessor : public NodeProcessor { } } if (input_shape.dim_size() == 4 && filter_shape.dim_size() == 4) { - if (input_shape.dim(1).size() == filter_shape.dim(0).size() == 1 && + if (input_shape.dim(1).size() == filter_shape.dim(0).size() && input_shape.dim(2).size() == filter_shape.dim(1).size() && IsValidPadding()) { return true; From 1f3e689b98365556741ccaba71868ffe1ab6364d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 4 May 2017 10:20:38 -0800 Subject: [PATCH 18/43] Don't repeatedly serialize the function library in MasterSession. Change: 155102455 --- tensorflow/core/distributed_runtime/master_session.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index f7b422b70e3..5257aea1e3a 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -162,7 +162,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { // Partitions the graph into subgraphs and registers them on // workers. Status RegisterPartitions(const PartitionOptions& popts, - const FunctionDefLibrary& func_def_lib); + const FunctionLibraryDefinition& flib_def); // Runs one step of all partitions. Status RunPartitions(const MasterEnv* env, int64 step_id, @@ -273,7 +273,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { }; Status MasterSession::ReffedClientGraph::RegisterPartitions( - const PartitionOptions& popts, const FunctionDefLibrary& func_def_lib) { + const PartitionOptions& popts, const FunctionLibraryDefinition& flib_def) { { // Ensure register once. mu_.lock(); if (!init_started_) { @@ -292,7 +292,8 @@ Status MasterSession::ReffedClientGraph::RegisterPartitions( graph_defs_for_publishing.push_back(&name_def.second); } stats_publisher_->PublishGraphProto(graph_defs_for_publishing); - s = DoRegisterPartitions(popts, func_def_lib, std::move(graph_defs)); + s = DoRegisterPartitions(popts, flib_def.ToProto(), + std::move(graph_defs)); } mu_.lock(); init_result_ = s; @@ -1214,7 +1215,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) { } TF_RETURN_IF_ERROR( - rcg->RegisterPartitions(popts, rcg->client_graph()->flib_def->ToProto())); + rcg->RegisterPartitions(popts, *rcg->client_graph()->flib_def)); return Status::OK(); } From 50b836addfed6b49fc823987e9301f1b6eeef90c Mon Sep 17 00:00:00 2001 From: Zongheng Yang Date: Thu, 4 May 2017 11:12:52 -0800 Subject: [PATCH 19/43] Github #9633: enforce an omitted shape check in sparse_add() op. Previously, the code made the assumption that the two operands have matching shapes, but did not enforce the equality. This could lead to invalid memory access in some cases. Change: 155109464 --- .../kernels/sparse_tensor_dense_add_op.cc | 27 +++++++++++++------ tensorflow/python/ops/sparse_ops.py | 2 ++ 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/tensorflow/core/kernels/sparse_tensor_dense_add_op.cc b/tensorflow/core/kernels/sparse_tensor_dense_add_op.cc index b5093d59fc0..48f38872e25 100644 --- a/tensorflow/core/kernels/sparse_tensor_dense_add_op.cc +++ b/tensorflow/core/kernels/sparse_tensor_dense_add_op.cc @@ -47,16 +47,26 @@ class SparseTensorDenseAddOp : public OpKernel { "Input a_indices should be a matrix but received shape: ", a_indices_t->shape().DebugString())); OP_REQUIRES( - ctx, TensorShapeUtils::IsVector(a_values_t->shape()) && - TensorShapeUtils::IsVector(a_shape_t->shape()), + ctx, + TensorShapeUtils::IsVector(a_values_t->shape()) && + TensorShapeUtils::IsVector(a_shape_t->shape()), errors::InvalidArgument("Inputs a_values and a_shape should be vectors " "but received shapes: ", a_values_t->shape().DebugString(), " and ", a_shape_t->shape().DebugString())); - OP_REQUIRES(ctx, a_shape_t->NumElements() == b->dims(), - errors::InvalidArgument( - "Two operands have different dimensions; received: ", - a_shape_t->NumElements(), " and ", b->dims())); + OP_REQUIRES( + ctx, a_shape_t->NumElements() == b->dims(), + errors::InvalidArgument("Two operands have different ranks; received: ", + a_shape_t->NumElements(), " and ", b->dims())); + const auto a_shape_flat = a_shape_t->flat(); + for (int i = 0; i < b->dims(); ++i) { + OP_REQUIRES( + ctx, a_shape_flat(i) == b->dim_size(i), + errors::InvalidArgument( + "Dimension ", i, + " does not equal (no broadcasting is supported): sparse side ", + a_shape_flat(i), " vs dense side ", b->dim_size(i))); + } Tensor *out_t; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, b->shape(), &out_t)); @@ -82,8 +92,9 @@ class SparseTensorDenseAddOp : public OpKernel { NDIMS_CASE(4); NDIMS_CASE(5); default: - OP_REQUIRES(ctx, false, errors::InvalidArgument( - "Only tensors with ranks between 1 and 5 " + OP_REQUIRES( + ctx, false, + errors::InvalidArgument("Only tensors with ranks between 1 and 5 " "are currently supported. Tensor rank: ", ndims)); #undef NDIMS_CASE diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index 0140a27aaa7..d6cb7c5be49 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -241,6 +241,8 @@ def sparse_add(a, b, thresh=0): of arguments does not matter. Use vanilla `tf.add()` for adding two dense `Tensor`s. + The shapes of the two operands must match: broadcasting is not supported. + The indices of any input `SparseTensor` are assumed ordered in standard lexicographic order. If this is not the case, before this step run `SparseReorder` to restore index ordering. From f6f26abe3416cc63f43d3ed18cb98eefae966e24 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Thu, 4 May 2017 11:58:52 -0800 Subject: [PATCH 20/43] Don't declare the ReadyNodeManager class in an anonymous namespace since some compilers don't like that. Change: 155115034 --- tensorflow/core/grappler/costs/virtual_scheduler.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h index b7785c94e04..5d437dff50e 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.h +++ b/tensorflow/core/grappler/costs/virtual_scheduler.h @@ -26,7 +26,6 @@ limitations under the License. namespace tensorflow { namespace grappler { -namespace { struct NodeState { std::vector inputs; std::vector outputs; @@ -86,7 +85,6 @@ class FIFOManager : public ReadyNodeManager { private: std::list nodes_; }; -} // namespace // The virtual scheduler emulates execution of nodes in a graph, considering // dependencies, device, etc. From e847d1b41fc44d9e892c8fef446b94a709dc8cf3 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 4 May 2017 11:59:45 -0800 Subject: [PATCH 21/43] Automated rollback of change 155092161 Change: 155115134 --- tensorflow/contrib/opt/BUILD | 2 -- 1 file changed, 2 deletions(-) diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index 1843b6968ea..2173e13b91f 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -29,9 +29,7 @@ py_library( "//tensorflow/python:platform", "//tensorflow/python:state_ops", "//tensorflow/python:training", - "//tensorflow/python:util", "//tensorflow/python:variables", - "//third_party/py/scipy", "@six_archive//:six", ], ) From 28c03f5d5a946e3050f7eed0cd7f3f64e5fa768d Mon Sep 17 00:00:00 2001 From: Asim Shankar Date: Thu, 4 May 2017 12:19:59 -0800 Subject: [PATCH 22/43] Docs: Stop advertising GPU support for Mac OS X. Starting with the next release (1.2) GPU support for Mac will be dropped in release binaries since the configuration of NVIDIA GPUs on a Mac is somewhat esoteric (and we currently do not have the bandwidth to debug test failures on that platform). While at it, change to version 1.1.0 from 1.1.0-rc2 Change: 155117808 --- tensorflow/docs_src/install/install_c.md | 2 +- tensorflow/docs_src/install/install_go.md | 2 +- tensorflow/docs_src/install/install_java.md | 18 +++++++-------- tensorflow/go/README.md | 25 +++++---------------- 4 files changed, 16 insertions(+), 31 deletions(-) diff --git a/tensorflow/docs_src/install/install_c.md b/tensorflow/docs_src/install/install_c.md index c1581efb4f3..c1c7b665460 100644 --- a/tensorflow/docs_src/install/install_c.md +++ b/tensorflow/docs_src/install/install_c.md @@ -35,7 +35,7 @@ enable TensorFlow for C: OS="linux" # Change to "darwin" for Mac OS TARGET_DIRECTORY="/usr/local" curl -L \ - "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.1.0-rc2.tar.gz" | + "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.1.0.tar.gz" | sudo tar -C $TARGET_DIRECTORY -xz The `tar` command extracts the TensorFlow C library into the `lib` diff --git a/tensorflow/docs_src/install/install_go.md b/tensorflow/docs_src/install/install_go.md index dd713e4786e..c9abaf2acaf 100644 --- a/tensorflow/docs_src/install/install_go.md +++ b/tensorflow/docs_src/install/install_go.md @@ -35,7 +35,7 @@ steps to install this library and enable TensorFlow for Go: TF_TYPE="cpu" # Change to "gpu" for GPU support TARGET_DIRECTORY='/usr/local' curl -L \ - "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.1.0-rc2.tar.gz" | + "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.1.0.tar.gz" | sudo tar -C $TARGET_DIRECTORY -xz The `tar` command extracts the TensorFlow C library into the `lib` diff --git a/tensorflow/docs_src/install/install_java.md b/tensorflow/docs_src/install/install_java.md index 65cfe375d57..55d9c2c08f3 100644 --- a/tensorflow/docs_src/install/install_java.md +++ b/tensorflow/docs_src/install/install_java.md @@ -34,7 +34,7 @@ following to the project's `pom.xml` to use the TensorFlow Java APIs: org.tensorflow tensorflow - 1.1.0-rc2 + 1.1.0 ``` @@ -63,7 +63,7 @@ As an example, these steps will create a Maven project that uses TensorFlow: org.tensorflow tensorflow - 1.1.0-rc2 + 1.1.0 @@ -122,7 +122,7 @@ refer to the simpler instructions above instead. Take the following steps to install TensorFlow for Java on Linux or Mac OS: 1. Download - [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.1.0-rc2.jar), + [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.1.0.jar), which is the TensorFlow Java Archive (JAR). 2. Decide whether you will run TensorFlow for Java on CPU(s) only or with @@ -141,7 +141,7 @@ Take the following steps to install TensorFlow for Java on Linux or Mac OS: OS=$(uname -s | tr '[:upper:]' '[:lower:]') mkdir -p ./jni curl -L \ - "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.1.0-rc2.tar.gz" | + "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.1.0.tar.gz" | tar -xz -C ./jni ### Install on Windows @@ -149,10 +149,10 @@ Take the following steps to install TensorFlow for Java on Linux or Mac OS: Take the following steps to install TensorFlow for Java on Windows: 1. Download - [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.1.0-rc2.jar), + [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.1.0.jar), which is the TensorFlow Java Archive (JAR). 2. Download the following Java Native Interface (JNI) file appropriate for - [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.1.0-rc2.zip). + [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.1.0.zip). 3. Extract this .zip file. @@ -200,7 +200,7 @@ must be part of your `classpath`. For example, you can include the downloaded `.jar` in your `classpath` by using the `-cp` compilation flag as follows: -
javac -cp libtensorflow-1.1.0-rc2.jar HelloTF.java
+
javac -cp libtensorflow-1.1.0.jar HelloTF.java
### Running @@ -214,11 +214,11 @@ two files are available to the JVM: For example, the following command line executes the `HelloTF` program on Linux and Mac OS X: -
java -cp libtensorflow-1.1.0-rc2.jar:. -Djava.library.path=./jni HelloTF
+
java -cp libtensorflow-1.1.0.jar:. -Djava.library.path=./jni HelloTF
And the following comand line executes the `HelloTF` program on Windows: -
java -cp libtensorflow-1.1.0-rc2.jar;. -Djava.library.path=jni HelloTF
+
java -cp libtensorflow-1.1.0.jar;. -Djava.library.path=jni HelloTF
If the program prints Hello from version, you've successfully installed TensorFlow for Java and are ready to use the API. If the program diff --git a/tensorflow/go/README.md b/tensorflow/go/README.md index e32c21ca720..a1b4255292b 100644 --- a/tensorflow/go/README.md +++ b/tensorflow/go/README.md @@ -9,24 +9,22 @@ Construct and execute TensorFlow graphs in Go. > (`github.com/tensorflow/tensorflow/tensorflow/go`). ## Quickstart - 1. Download and extract the TensorFlow C library, preferably into `/usr/local`. GPU-enabled versions require CUDA 8.0 and cuDNN 5.1. For other versions, the TensorFlow C library will have to be built from source (see below). - Linux: - [CPU-only](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.0.0.tar.gz), - [GPU-enabled](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-1.0.0.tar.gz) + [CPU-only](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-1.1.0.tar.gz), + [GPU-enabled](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-1.1.0.tar.gz) - OS X - [CPU-only](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-1.0.0.tar.gz), - [GPU-enabled](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-darwin-x86_64-1.0.0.tar.gz) + [CPU-only](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-1.1.0.tar.gz), The following shell snippet downloads and extracts into `/usr/local`: ```sh TF_TYPE="cpu" # Set to "gpu" for GPU support curl -L \ - "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.0.0.tar.gz" | + "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.1.0.tar.gz" | sudo tar -C /usr/local -xz ``` @@ -41,20 +39,7 @@ Construct and execute TensorFlow graphs in Go. ### Installing into locations other than `/usr/local` -The TensorFlow C library (`libtensorflow.so`) needs to be available at build -time (e.g., `go build`) and run time (`go test` or executing binaries). If the -library has not been extracted into `/usr/local`, then it needs to be made -available through the `LIBRARY_PATH` environment variable at build time and the -`LD_LIBRARY_PATH` environment variable (`DYLD_LIBRARY_PATH` on OS X) at run -time. - -For example, if the TensorFlow C library was extracted into `/dir`, then: - -```sh -export LIBRARY_PATH=/dir/lib -export LD_LIBRARY_PATH=/dir/lib # For Linux -export DYLD_LIBRARY_PATH=/dir/lib # For OS X -``` +Refer to [Installing TensorFlow for Go](https://www.tensorflow.org/install/install_go) ## Building the TensorFlow C library from source From e46a12bc9fbcea1fef224daa47eb9f1cf9e56472 Mon Sep 17 00:00:00 2001 From: Asim Shankar Date: Thu, 4 May 2017 12:20:14 -0800 Subject: [PATCH 23/43] Docs: Fix broken link in contrib.layers.initializers Change: 155117831 --- tensorflow/contrib/layers/python/layers/initializers.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tensorflow/contrib/layers/python/layers/initializers.py b/tensorflow/contrib/layers/python/layers/initializers.py index 4359d0c63e3..811e7fa7aa3 100644 --- a/tensorflow/contrib/layers/python/layers/initializers.py +++ b/tensorflow/contrib/layers/python/layers/initializers.py @@ -46,8 +46,7 @@ def xavier_initializer(uniform=True, seed=None, dtype=dtypes.float32): Args: uniform: Whether to use uniform or normal distributed random initialization. seed: A Python integer. Used to create random seeds. See - [`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed) - for behavior. + @{tf.set_random_seed} for behavior. dtype: The data type. Only floating point types are supported. Returns: @@ -97,8 +96,7 @@ def variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False, mode: String. 'FAN_IN', 'FAN_OUT', 'FAN_AVG'. uniform: Whether to use uniform or normal distributed random initialization. seed: A Python integer. Used to create random seeds. See - [`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed) - for behavior. + @{tf.set_random_seed} for behavior. dtype: The data type. Only floating point types are supported. Returns: From dd140f79e06a81c52cd8fc9ec6cda975a78a401f Mon Sep 17 00:00:00 2001 From: Yutaka Leon Date: Thu, 4 May 2017 12:31:30 -0800 Subject: [PATCH 24/43] Organize the lookup table ops into it's own lookup_ops.cc file instead of data_flow_ops.cc Change: 155119120 --- .../python/layers/feature_column_ops_test.py | 106 +-- .../estimators/dynamic_rnn_estimator_test.py | 6 +- .../python/learn/estimators/estimator.py | 4 +- .../python/learn/estimators/head_test.py | 10 +- .../state_saving_rnn_estimator_test.py | 6 +- .../learn/python/learn/graph_actions.py | 11 +- .../learn/python/learn/utils/export.py | 18 +- tensorflow/contrib/lookup/BUILD | 2 +- tensorflow/contrib/lookup/lookup_ops_test.py | 64 +- .../contrib/slim/python/slim/learning.py | 4 +- tensorflow/core/BUILD | 3 + tensorflow/core/kernels/BUILD | 10 +- tensorflow/core/ops/data_flow_ops.cc | 598 ---------------- tensorflow/core/ops/lookup_ops.cc | 666 ++++++++++++++++++ tensorflow/python/BUILD | 30 +- tensorflow/python/estimator/estimator_test.py | 9 +- tensorflow/python/feature_column/BUILD | 2 +- .../feature_column/feature_column_test.py | 8 +- .../python/feature_column/lookup_ops.py | 54 +- tensorflow/python/ops/data_flow_ops.py | 42 -- tensorflow/python/ops/lookup_ops.py | 77 ++ tensorflow/python/ops/standard_ops.py | 1 + tensorflow/python/saved_model/main_op_impl.py | 4 +- .../python/training/monitored_session.py | 4 +- .../python/training/saver_test_utils.py | 14 +- tensorflow/python/training/supervisor.py | 8 +- 26 files changed, 950 insertions(+), 811 deletions(-) create mode 100644 tensorflow/core/ops/lookup_ops.cc create mode 100644 tensorflow/python/ops/lookup_ops.py diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py index b2dad0162e9..a09cc53571b 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py @@ -33,9 +33,9 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops -from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope @@ -224,7 +224,7 @@ class TransformerTest(test.TestCase): self.assertEqual(len(output), 1) self.assertIn(keys_sparse, output) with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertEqual(output[keys_sparse].values.dtype, dtypes.int64) self.assertAllEqual(output[keys_sparse].values.eval(), [1, 2, 0]) self.assertAllEqual(output[keys_sparse].indices.eval(), @@ -242,7 +242,7 @@ class TransformerTest(test.TestCase): output = feature_column_ops._Transformer(features).transform(keys_sparse) with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() # While the input is a dense Tensor, the output should be a SparseTensor. self.assertIsInstance(output, sparse_tensor.SparseTensor) self.assertEqual(output.dtype, dtypes.int64) @@ -311,7 +311,7 @@ class TransformerTest(test.TestCase): self.assertIn(weighted_ids, output) with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual(output[weighted_ids][0].dense_shape.eval(), ids_tensor.dense_shape.eval()) self.assertAllEqual(output[weighted_ids][0].indices.eval(), @@ -341,7 +341,7 @@ class TransformerTest(test.TestCase): self.assertEqual(len(output), 1) self.assertIn(vocab_sparse, output) with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64) self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0]) self.assertAllEqual(output[vocab_sparse].indices.eval(), @@ -363,7 +363,7 @@ class TransformerTest(test.TestCase): self.assertEqual(len(output), 1) self.assertIn(vocab_sparse, output) with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64) self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1]) self.assertAllEqual(output[vocab_sparse].indices.eval(), @@ -387,7 +387,7 @@ class TransformerTest(test.TestCase): self.assertEqual(len(output), 1) self.assertIn(vocab_sparse, output) with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64) self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0]) self.assertAllEqual(output[vocab_sparse].indices.eval(), @@ -409,7 +409,7 @@ class TransformerTest(test.TestCase): self.assertEqual(len(output), 1) self.assertIn(vocab_sparse, output) with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64) self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1]) self.assertAllEqual(output[vocab_sparse].indices.eval(), @@ -601,7 +601,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): one_hot_column, embedding_column, real_valued_column]) with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual(output.eval().shape, [3, 2 + 4 + 10]) def testRealValuedColumn(self): @@ -714,7 +714,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): [one_hot_column]) with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual([[0, 0, 10., 0], [0, 20., 0, 0], [30., 0, 40., 0]], output.eval()) @@ -732,7 +732,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]], output.eval()) @@ -750,7 +750,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]], output.eval()) @@ -784,7 +784,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): [one_hot_sparse]) with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual([3, 10], output.eval().shape) def testEmbeddingColumnSucceedsForDNN(self): @@ -891,7 +891,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): [embeded_sparse]) with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual(output.eval().shape, [2, 10]) def testEmbeddingColumnWithIntegerWeightedSparseColumnSucceedsForDNN(self): @@ -914,7 +914,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): [embeded_sparse]) with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual(output.eval().shape, [2, 10]) def testEmbeddingColumnWithCrossedColumnSucceedsForDNN(self): @@ -965,7 +965,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): with self.assertRaisesRegexp( ValueError, "Error creating input layer for column: ids_weighted_by_weights"): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() feature_column_ops.input_from_feature_columns(features, [weighted_ids]) def testCrossedColumnFailsForDNN(self): @@ -1072,7 +1072,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): [embeded_sparse]) with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() # score: (sum of weights) self.assertAllEqual(output.eval(), [[10.], [50.], [0.]]) @@ -1310,7 +1310,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() model_input = sess.run(model_input_tensor) expected_input_shape = np.array([4, 3, 4]) @@ -1344,7 +1344,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() model_input = sess.run(model_input_tensor) expected_input_shape = np.array([4, 3, hash_buckets]) @@ -1374,7 +1374,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() model_input = sess.run(model_input_tensor) self.assertAllEqual(expected_input_shape, model_input.shape) @@ -1403,7 +1403,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() model_input = sess.run(model_input_tensor) self.assertAllEqual(expected_input_shape, model_input.shape) @@ -1433,7 +1433,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): embedding_weights) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() model_input, gradients = sess.run([model_input_tensor, gradient_tensor]) expected_input_shape = [4, 3, embedding_dimension] @@ -1500,7 +1500,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() model_input = sess.run(model_input_tensor) expected_input_shape = [ @@ -1581,7 +1581,7 @@ class WeightedSumTest(test.TestCase): features, [weighted_ids], num_outputs=5) with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual(logits.eval().shape, [2, 5]) def testWeightedSparseColumnWithDenseInputTensor(self): @@ -1597,7 +1597,7 @@ class WeightedSumTest(test.TestCase): with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual(logits.eval().shape, [2, 5]) def testCrossedColumn(self): @@ -1651,7 +1651,7 @@ class WeightedSumTest(test.TestCase): features, [movies], num_outputs=1)) with self.test_session() as sess: variables_lib.initialize_all_variables().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[movies][0] self.assertEqual(weights.get_shape(), (3, 1)) @@ -1726,7 +1726,7 @@ class WeightedSumTest(test.TestCase): features, [age, language], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllClose(output.eval(), [[0.], [0.]]) @@ -1766,7 +1766,7 @@ class WeightedSumTest(test.TestCase): self.assertEqual(len(variables), 1) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllClose(output.eval(), [[0.], [0.]]) @@ -1830,7 +1830,7 @@ class WeightedSumTest(test.TestCase): features, [weighted_language], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllClose(output.eval(), [[0.], [0.]]) @@ -1858,7 +1858,7 @@ class WeightedSumTest(test.TestCase): features, [language], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() # score: 0.1 + language_weight['hindi'] + language_weight['english'] sess.run(bias.assign([0.1])) @@ -1881,7 +1881,7 @@ class WeightedSumTest(test.TestCase): features, [movies], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[movies][0] self.assertEqual(weights.get_shape(), (15, 1)) @@ -1915,7 +1915,7 @@ class WeightedSumTest(test.TestCase): features, [country_language], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[country_language][0] sess.run(weights.assign(weights + 0.4)) @@ -1939,7 +1939,7 @@ class WeightedSumTest(test.TestCase): features, [language_language], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[language_language][0] sess.run(weights.assign(weights + 0.4)) @@ -1972,7 +1972,7 @@ class WeightedSumTest(test.TestCase): features, [country_language], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[country_language][0] sess.run(weights.assign(weights + 0.4)) @@ -2013,7 +2013,7 @@ class WeightedSumTest(test.TestCase): scope=scope)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertEqual(2, len(column_to_variable[country])) self.assertEqual(3, len(column_to_variable[language])) @@ -2050,7 +2050,7 @@ class WeightedSumTest(test.TestCase): features, [country, age, incomes], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() incomes_weights = column_to_variable[incomes][0] sess.run(incomes_weights.assign([[0.1], [0.2], [0.3]])) @@ -2086,7 +2086,7 @@ class WeightedSumTest(test.TestCase): features, [country, age, height, incomes], num_outputs=5)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() height_weights = column_to_variable[height][0] sess.run( @@ -2116,7 +2116,7 @@ class WeightedSumTest(test.TestCase): features, [bucket], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() sess.run(column_to_variable[bucket][0].assign([[0.1], [0.2], [0.3], [0.4]])) @@ -2144,7 +2144,7 @@ class WeightedSumTest(test.TestCase): features, [bucket, country], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() # dimension = 2, bucket_size = 4, num_classes = 1 sess.run(column_to_variable[bucket][0].assign( @@ -2173,7 +2173,7 @@ class WeightedSumTest(test.TestCase): features, [bucket, country], num_outputs=5)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() # dimension = 2, bucket_size = 4, num_classes = 5 sess.run(column_to_variable[bucket][0].assign( @@ -2209,7 +2209,7 @@ class WeightedSumTest(test.TestCase): features, [country_price], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[country_price][0] sess.run(weights.assign(weights + 0.4)) @@ -2248,7 +2248,7 @@ class WeightedSumTest(test.TestCase): features, [country_language_price], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[country_language_price][0] sess.run(weights.assign(weights + 0.4)) @@ -2272,7 +2272,7 @@ class WeightedSumTest(test.TestCase): features, [product], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() product_weights = column_to_variable[product][0] sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]])) self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]]) @@ -2287,7 +2287,7 @@ class WeightedSumTest(test.TestCase): features, [product], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() product_weights = column_to_variable[product][0] sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]])) self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]]) @@ -2302,7 +2302,7 @@ class WeightedSumTest(test.TestCase): features, [product], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() product_weights = column_to_variable[product][0] sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]])) self.assertAllClose(output.eval(), [[0.6], [0.7]]) @@ -2323,7 +2323,7 @@ class WeightedSumTest(test.TestCase): features, [product], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() product_weights = column_to_variable[product][0] sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]])) self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]]) @@ -2335,7 +2335,7 @@ class WeightedSumTest(test.TestCase): features, [feature_column.real_valued_column("age")], num_outputs=3) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() sess.run(bias.assign([0.1, 0.2, 0.3])) self.assertAllClose(output.eval(), [[0.1, 0.2, 0.3], [0.1, 0.2, 0.3], [0.1, 0.2, 0.3], [0.1, 0.2, 0.3]]) @@ -2349,7 +2349,7 @@ class WeightedSumTest(test.TestCase): features, [column], num_outputs=3)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[column][0] self.assertEqual(weights.get_shape(), (1, 3)) sess.run(weights.assign([[0.01, 0.03, 0.05]])) @@ -2373,7 +2373,7 @@ class WeightedSumTest(test.TestCase): features, [column], num_outputs=3)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[column][0] self.assertEqual(weights.get_shape(), (5, 3)) sess.run( @@ -2399,7 +2399,7 @@ class WeightedSumTest(test.TestCase): features, [column], num_outputs=3)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[column][0] self.assertEqual(weights.get_shape(), (5, 3)) @@ -2439,7 +2439,7 @@ class WeightedSumTest(test.TestCase): features, [column], num_outputs=3)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[column][0] self.assertEqual(weights.get_shape(), (5, 3)) @@ -2468,7 +2468,7 @@ class WeightedSumTest(test.TestCase): features, [column], num_outputs=3)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() weights = column_to_variable[column][0] self.assertEqual(weights.get_shape(), (5, 3)) @@ -2533,7 +2533,7 @@ class ParseExampleTest(test.TestCase): self.assertIn(bucket, output) self.assertIn(wire_cast, output) with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual(output[bucket].eval(), [[2, 3, 0]]) self.assertAllEqual(output[wire_cast].indices.eval(), [[0, 0], [0, 1]]) self.assertAllEqual(output[wire_cast].values.eval(), [2, 0]) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py index 43b3d2a78fc..58072500d10 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py @@ -38,8 +38,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops -from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables @@ -157,7 +157,7 @@ class DynamicRnnEstimatorTest(test.TestCase): self.context_feature_columns) with self.test_session() as sess: sess.run(variables.global_variables_initializer()) - sess.run(data_flow_ops.tables_initializer()) + sess.run(lookup_ops.tables_initializer()) sequence_input_val = sess.run(sequence_input) expected_shape = np.array([ 3, # expected batch size @@ -178,7 +178,7 @@ class DynamicRnnEstimatorTest(test.TestCase): # Obtain values of activations and final state. with session.Session() as sess: sess.run(variables.global_variables_initializer()) - sess.run(data_flow_ops.tables_initializer()) + sess.run(lookup_ops.tables_initializer()) activations, final_state = sess.run([activations_t, final_state_t]) expected_activations_shape = np.array([3, 2, self.NUM_LABEL_COLUMNS]) diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 74a6da20d4e..36f843ba8e7 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -57,7 +57,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import resources from tensorflow.python.ops import variables from tensorflow.python.platform import gfile @@ -1292,7 +1292,7 @@ class Estimator(BaseEstimator): init_op = control_flow_ops.group( variables.local_variables_initializer(), resources.initialize_resources(resources.shared_resources()), - data_flow_ops.tables_initializer()) + lookup_ops.tables_initializer()) # Perform the export builder = saved_model_builder.SavedModelBuilder(export_dir) diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index 207a189a94d..d5777088de7 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -32,7 +32,7 @@ from tensorflow.core.framework import summary_pb2 from tensorflow.python.client import session from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor -from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.ops.losses import losses as losses_lib from tensorflow.python.platform import test @@ -1214,7 +1214,7 @@ class MultiClassHeadTest(test.TestCase): train_op_fn=head_lib.no_op_train_fn, logits=((1., 0., 0.), (0., 0., 1.),)) with session.Session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual( [0, 2], model_fn_ops.predictions["classes"].eval()) @@ -1266,7 +1266,7 @@ class MultiClassHeadTest(test.TestCase): train_op_fn=head_lib.no_op_train_fn, logits=((1., 0., 0.), (0., 0., 1.),)) with session.Session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual( [b"key0", b"key2"], model_fn_ops.predictions["classes"].eval()) @@ -1301,7 +1301,7 @@ class MultiClassHeadTest(test.TestCase): train_op_fn=head_lib.no_op_train_fn, logits=((1., 0., 0.),)) with session.Session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertIsNone(model_fn_ops.train_op) _assert_no_variables(self) _assert_summary_tags(self, ["loss"]) @@ -1327,7 +1327,7 @@ class MultiClassHeadTest(test.TestCase): train_op_fn=head_lib.no_op_train_fn, logits=((0., 0., 1.),)) with session.Session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertIsNone(model_fn_ops.train_op) _assert_no_variables(self) _assert_summary_tags(self, ["loss"]) diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py index f5bd03429c6..feea6c5fed3 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py @@ -35,8 +35,8 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops -from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import init_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables @@ -55,7 +55,7 @@ class PrepareInputsForRnnTest(test.TestCase): with self.test_session() as sess: sess.run(variables.global_variables_initializer()) - sess.run(data_flow_ops.initialize_all_tables()) + sess.run(lookup_ops.tables_initializer()) features_val = sess.run(features_by_time) self.assertAllEqual(expected, features_val) @@ -316,7 +316,7 @@ class StateSavingRnnEstimatorTest(test.TestCase): with self.test_session() as sess: sess.run(variables.global_variables_initializer()) - sess.run(data_flow_ops.initialize_all_tables()) + sess.run(lookup_ops.tables_initializer()) actual_sequence, actual_context = sess.run( [sequence, context]) assert_equal(expected_sequence, actual_sequence) diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py index 4b7867f2d00..98365c05f66 100644 --- a/tensorflow/contrib/learn/python/learn/graph_actions.py +++ b/tensorflow/contrib/learn/python/learn/graph_actions.py @@ -37,8 +37,8 @@ from tensorflow.python.client import session as tf_session from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import resources from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging @@ -429,11 +429,14 @@ def _get_ready_op(): def _get_local_init_op(): + """Returns the local init ops to initialize tables and local variables.""" local_init_op = _get_first_op_from_collection( ops.GraphKeys.LOCAL_INIT_OP) if local_init_op is None: - op_list = [variables.local_variables_initializer(), - data_flow_ops.tables_initializer()] + op_list = [ + variables.local_variables_initializer(), + lookup_ops.tables_initializer() + ] if op_list: local_init_op = control_flow_ops.group(*op_list) ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op) @@ -680,7 +683,7 @@ def run_feeds_iter(output_dict, feed_dicts, restore_checkpoint_path=None): else: session.run(variables.global_variables_initializer()) session.run(variables.local_variables_initializer()) - session.run(data_flow_ops.tables_initializer()) + session.run(lookup_ops.tables_initializer()) coord = coordinator.Coordinator() threads = None try: diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py index b53be292830..36a1f5f60cd 100644 --- a/tensorflow/contrib/learn/python/learn/utils/export.py +++ b/tensorflow/contrib/learn/python/learn/utils/export.py @@ -28,7 +28,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import saver as tf_saver @@ -67,17 +67,17 @@ def _export_graph(graph, saver, checkpoint_path, export_dir, with graph.as_default(): with tf_session.Session('') as session: variables.local_variables_initializer() - data_flow_ops.tables_initializer() + lookup_ops.tables_initializer() saver.restore(session, checkpoint_path) export = exporter.Exporter(saver) - export.init(init_op=control_flow_ops.group( - variables.local_variables_initializer(), - data_flow_ops.tables_initializer()), - default_graph_signature=default_graph_signature, - named_graph_signatures=named_graph_signatures, - assets_collection=ops.get_collection( - ops.GraphKeys.ASSET_FILEPATHS)) + export.init( + init_op=control_flow_ops.group( + variables.local_variables_initializer(), + lookup_ops.tables_initializer()), + default_graph_signature=default_graph_signature, + named_graph_signatures=named_graph_signatures, + assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)) return export.export(export_dir, contrib_variables.get_global_step(), session, exports_to_keep=exports_to_keep) diff --git a/tensorflow/contrib/lookup/BUILD b/tensorflow/contrib/lookup/BUILD index 5966c86dfb9..bbbd3403526 100644 --- a/tensorflow/contrib/lookup/BUILD +++ b/tensorflow/contrib/lookup/BUILD @@ -30,11 +30,11 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", - "//tensorflow/python:data_flow_ops", "//tensorflow/python:errors", "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:lookup_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", "//tensorflow/python:variables", diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 0ec40a63f26..5ec169b6db4 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -31,7 +31,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops -from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import saver @@ -125,7 +125,7 @@ class HashTableOpTest(test.TestCase): table3 = lookup.HashTable( lookup.KeyValueTensorInitializer(keys, values), default_val) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual(3, table1.size().eval()) self.assertAllEqual(3, table2.size().eval()) self.assertAllEqual(3, table3.size().eval()) @@ -1184,7 +1184,7 @@ class IndexTableFromFile(test.TestCase): ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) def test_int32_index_table_from_file(self): @@ -1198,7 +1198,7 @@ class IndexTableFromFile(test.TestCase): constant_op.constant((1, -1000, 11), dtype=dtypes.int32)) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) def test_int64_index_table_from_file(self): @@ -1212,7 +1212,7 @@ class IndexTableFromFile(test.TestCase): constant_op.constant((1, -1000, 11), dtype=dtypes.int64)) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) def test_index_table_from_file_with_default_value(self): @@ -1224,7 +1224,7 @@ class IndexTableFromFile(test.TestCase): ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, default_value), ids.eval()) def test_index_table_from_file_with_oov_buckets(self): @@ -1236,7 +1236,7 @@ class IndexTableFromFile(test.TestCase): constant_op.constant(["salad", "surgery", "tarkus", "toccata"])) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual( ( 1, # From vocabulary file. @@ -1259,7 +1259,7 @@ class IndexTableFromFile(test.TestCase): ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, -1, -1), ids.eval()) self.assertEqual(2, table.size().eval()) @@ -1286,7 +1286,7 @@ class IndexTableFromFile(test.TestCase): ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, -1), ids.eval()) self.assertEqual(3, table.size().eval()) @@ -1345,7 +1345,7 @@ class IndexTableFromTensor(test.TestCase): ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus"))) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) def test_int32_index_table_from_tensor_with_tensor_init(self): @@ -1356,7 +1356,7 @@ class IndexTableFromTensor(test.TestCase): constant_op.constant((1, -1000, 11), dtype=dtypes.int32)) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) def test_int64_index_table_from_tensor_with_tensor_init(self): @@ -1367,7 +1367,7 @@ class IndexTableFromTensor(test.TestCase): constant_op.constant((1, -1000, 11), dtype=dtypes.int64)) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) def test_index_table_from_tensor_with_default_value(self): @@ -1378,7 +1378,7 @@ class IndexTableFromTensor(test.TestCase): ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, default_value), ids.eval()) def test_index_table_from_tensor_missing_mapping(self): @@ -1394,7 +1394,7 @@ class IndexTableFromTensor(test.TestCase): self.assertRaises(errors_impl.OpError, ids.eval) with self.assertRaisesRegexp( errors_impl.OpError, "keys and values cannot be empty"): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() def test_index_table_from_tensor_with_invalid_hashers(self): with self.test_session(): @@ -1422,7 +1422,7 @@ class StringToIndexTest(test.TestCase): indices = lookup.string_to_index(feats, mapping=mapping_strings) self.assertRaises(errors_impl.OpError, indices.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, -1), indices.eval()) @@ -1433,7 +1433,7 @@ class StringToIndexTest(test.TestCase): _ = lookup.string_to_index(feats, mapping=mapping_strings) self.assertRaises(errors_impl.OpError, - data_flow_ops.tables_initializer().run) + lookup_ops.tables_initializer().run) def test_string_to_index_with_default_value(self): default_value = -42 @@ -1444,7 +1444,7 @@ class StringToIndexTest(test.TestCase): feats, mapping=mapping_strings, default_value=default_value) self.assertRaises(errors_impl.OpError, indices.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, default_value), indices.eval()) @@ -1463,7 +1463,7 @@ class IndexToStringTableFromFileTest(test.TestCase): vocabulary_file=vocabulary_file) features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64)) self.assertRaises(errors_impl.OpError, features.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), features.eval()) @@ -1475,7 +1475,7 @@ class IndexToStringTableFromFileTest(test.TestCase): vocabulary_file=vocabulary_file, default_value=default_value) features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) self.assertRaises(errors_impl.OpError, features.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"salad", b"surgery", default_value), features.eval()) @@ -1489,7 +1489,7 @@ class IndexToStringTableFromFileTest(test.TestCase): default_value=default_value) features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) self.assertRaises(errors_impl.OpError, features.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"salad", default_value, default_value), features.eval()) @@ -1501,7 +1501,7 @@ class IndexToStringTableFromFileTest(test.TestCase): features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) self.assertRaises(errors_impl.OpError, features.eval) - init = data_flow_ops.tables_initializer() + init = lookup_ops.tables_initializer() self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "Invalid vocab_size", init.run) @@ -1513,7 +1513,7 @@ class IndexToStringTableFromFileTest(test.TestCase): features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) self.assertRaises(errors_impl.OpError, features.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"salad", b"surgery", b"UNK"), features.eval()) @@ -1528,7 +1528,7 @@ class IndexToStringTableFromTensorTest(test.TestCase): indices = constant_op.constant([0, 1, 2, 3], dtypes.int64) features = table.lookup(indices) self.assertRaises(errors_impl.OpError, features.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), features.eval()) @@ -1540,7 +1540,7 @@ class IndexToStringTableFromTensorTest(test.TestCase): mapping=mapping_strings) indices = constant_op.constant([0, 1, 4], dtypes.int64) features = table.lookup(indices) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"hello", b"hello", b"UNK"), features.eval()) def test_index_to_string_with_default_value(self): @@ -1553,7 +1553,7 @@ class IndexToStringTableFromTensorTest(test.TestCase): features = table.lookup(indices) self.assertRaises(errors_impl.OpError, features.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"salad", b"surgery", default_value), features.eval()) @@ -1567,7 +1567,7 @@ class IndexToStringTest(test.TestCase): feats = lookup.index_to_string(indices, mapping=mapping_strings) self.assertRaises(errors_impl.OpError, feats.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), feats.eval()) @@ -1577,11 +1577,11 @@ class IndexToStringTest(test.TestCase): mapping_strings = constant_op.constant(["hello", "hello"]) indices = constant_op.constant([0, 1, 4], dtypes.int64) feats = lookup.index_to_string(indices, mapping=mapping_strings) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"hello", b"hello", b"UNK"), feats.eval()) self.assertRaises(errors_impl.OpError, - data_flow_ops.tables_initializer().run) + lookup_ops.tables_initializer().run) def test_index_to_string_with_default_value(self): default_value = b"NONE" @@ -1592,7 +1592,7 @@ class IndexToStringTest(test.TestCase): indices, mapping=mapping_strings, default_value=default_value) self.assertRaises(errors_impl.OpError, feats.eval) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() self.assertAllEqual((b"salad", b"surgery", default_value), feats.eval()) @@ -1755,7 +1755,7 @@ class InitializeTableFromFileOpTest(test.TestCase): default_value, shared_name=shared_name) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() input_string = constant_op.constant(["brain", "salad", "tank"]) @@ -2081,7 +2081,7 @@ class IdTableWithHashBucketsTest(test.TestCase): hasher_spec=lookup.StrongHashSpec((1, 2)), name="table2") - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() input_string = constant_op.constant( ["fruit", "brain", "salad", "surgery", "UNK"]) @@ -2167,7 +2167,7 @@ class IdTableWithHashBucketsTest(test.TestCase): default_value2), oov_buckets) - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() input_string_1 = constant_op.constant( ["brain", "salad", "surgery", "UNK"]) diff --git a/tensorflow/contrib/slim/python/slim/learning.py b/tensorflow/contrib/slim/python/slim/learning.py index 5ced8a4f089..b70d612f55b 100644 --- a/tensorflow/contrib/slim/python/slim/learning.py +++ b/tensorflow/contrib/slim/python/slim/learning.py @@ -261,7 +261,7 @@ from tensorflow.python.framework import ops from tensorflow.python.lib.io import file_io from tensorflow.python.ops import clip_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import tf_logging as logging @@ -657,7 +657,7 @@ def train(train_op, if local_init_op == _USE_DEFAULT: local_init_op = control_flow_ops.group( tf_variables.local_variables_initializer(), - data_flow_ops.tables_initializer()) + lookup_ops.tables_initializer()) if sync_optimizer is not None and isinstance( sync_optimizer, sync_replicas_optimizer.SyncReplicasOptimizer): diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 119bc0f8997..435618ace7a 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -506,6 +506,7 @@ tf_gen_op_libs( "image_ops", "io_ops", "linalg_ops", + "lookup_ops", "logging_ops", "math_ops", "nn_ops", @@ -582,6 +583,7 @@ cc_library( ":image_ops_op_lib", ":io_ops_op_lib", ":linalg_ops_op_lib", + ":lookup_ops_op_lib", ":logging_ops_op_lib", ":math_ops_op_lib", ":nn_ops_op_lib", @@ -708,6 +710,7 @@ cc_library( "//tensorflow/core/kernels:image", "//tensorflow/core/kernels:io", "//tensorflow/core/kernels:linalg", + "//tensorflow/core/kernels:lookup", "//tensorflow/core/kernels:logging", "//tensorflow/core/kernels:math", "//tensorflow/core/kernels:multinomial_op", diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 0847d1279b8..02ab30a04fa 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1327,6 +1327,14 @@ cc_library( ], ) +cc_library( + name = "lookup", + deps = [ + ":lookup_table_init_op", + ":lookup_table_op", + ], +) + DATA_FLOW_DEPS = [ ":bounds_check", ":concat_lib", @@ -1450,10 +1458,10 @@ LOOKUP_DEPS = [ ":initializable_lookup_table", ":lookup_util", "//tensorflow/core:core_cpu", - "//tensorflow/core:data_flow_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:lookup_ops_op_lib", ] tf_kernel_library( diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index f35a1bb6489..032ede6459c 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -1876,604 +1876,6 @@ size: The number of incomplete elements (i.e. those with some of their value // -------------------------------------------------------------------------- -REGISTER_OP("LookupTableFind") - .Input("table_handle: Ref(string)") - .Input("keys: Tin") - .Input("default_value: Tout") - .Output("values: Tout") - .Attr("Tin: type") - .Attr("Tout: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); - DimensionHandle unused_dim; - TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); - - // Default value must be scalar or vector. - ShapeHandle unused; - TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused)); - c->set_output(0, c->UnknownShape()); - return Status::OK(); - }) - .Doc(R"doc( -Looks up keys in a table, outputs the corresponding values. - -The tensor `keys` must of the same type as the keys of the table. -The output `values` is of the type of the table values. - -The scalar `default_value` is the value output for keys not present in the -table. It must also be of the same type as the table values. - -table_handle: Handle to the table. -keys: Any shape. Keys to look up. -values: Same shape as `keys`. Values found in the table, or `default_values` - for missing keys. -)doc"); - -REGISTER_OP("LookupTableFindV2") - .Input("table_handle: resource") - .Input("keys: Tin") - .Input("default_value: Tout") - .Output("values: Tout") - .Attr("Tin: type") - .Attr("Tout: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); - - // Default value must be scalar or vector. - ShapeHandle unused; - TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused)); - c->set_output(0, c->UnknownShape()); - return Status::OK(); - }) - .Doc(R"doc( -Looks up keys in a table, outputs the corresponding values. - -The tensor `keys` must of the same type as the keys of the table. -The output `values` is of the type of the table values. - -The scalar `default_value` is the value output for keys not present in the -table. It must also be of the same type as the table values. - -table_handle: Handle to the table. -keys: Any shape. Keys to look up. -values: Same shape as `keys`. Values found in the table, or `default_values` - for missing keys. -)doc"); - -REGISTER_OP("LookupTableInsert") - .Input("table_handle: Ref(string)") - .Input("keys: Tin") - .Input("values: Tout") - .Attr("Tin: type") - .Attr("Tout: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); - DimensionHandle unused_dim; - TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); - - // TODO(ebrevdo): Validate keys and values shape. - return Status::OK(); - }) - .Doc(R"doc( -Updates the table to associates keys with values. - -The tensor `keys` must be of the same type as the keys of the table. -The tensor `values` must be of the type of the table values. - -table_handle: Handle to the table. -keys: Any shape. Keys to look up. -values: Values to associate with keys. -)doc"); - -REGISTER_OP("LookupTableInsertV2") - .Input("table_handle: resource") - .Input("keys: Tin") - .Input("values: Tout") - .Attr("Tin: type") - .Attr("Tout: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); - - // TODO: Validate keys and values shape. - return Status::OK(); - }) - .Doc(R"doc( -Updates the table to associates keys with values. - -The tensor `keys` must be of the same type as the keys of the table. -The tensor `values` must be of the type of the table values. - -table_handle: Handle to the table. -keys: Any shape. Keys to look up. -values: Values to associate with keys. -)doc"); - -REGISTER_OP("LookupTableSize") - .Input("table_handle: Ref(string)") - .Output("size: int64") - .SetShapeFn(TwoElementVectorInputsAndScalarOutputs) - .Doc(R"doc( -Computes the number of elements in the given table. - -table_handle: Handle to the table. -size: Scalar that contains number of elements in the table. -)doc"); - -REGISTER_OP("LookupTableSizeV2") - .Input("table_handle: resource") - .Output("size: int64") - .SetShapeFn(ScalarAndTwoElementVectorInputsAndScalarOutputs) - .Doc(R"doc( -Computes the number of elements in the given table. - -table_handle: Handle to the table. -size: Scalar that contains number of elements in the table. -)doc"); - -REGISTER_OP("LookupTableExport") - .Input("table_handle: Ref(string)") - .Output("keys: Tkeys") - .Output("values: Tvalues") - .Attr("Tkeys: type") - .Attr("Tvalues: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); - DimensionHandle unused_dim; - TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); - - ShapeHandle values = c->UnknownShape(); - TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values)); - ShapeHandle keys = c->Vector(c->Dim(values, 0)); - c->set_output(0, keys); - c->set_output(1, values); - return Status::OK(); - }) - .Doc(R"doc( -Outputs all keys and values in the table. - -table_handle: Handle to the table. -keys: Vector of all keys present in the table. -values: Tensor of all values in the table. Indexed in parallel with `keys`. -)doc"); - -REGISTER_OP("LookupTableExportV2") - .Input("table_handle: resource") - .Output("keys: Tkeys") - .Output("values: Tvalues") - .Attr("Tkeys: type") - .Attr("Tvalues: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); - - ShapeHandle values = c->UnknownShape(); - TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values)); - ShapeHandle keys = c->Vector(c->Dim(values, 0)); - c->set_output(0, keys); - c->set_output(1, values); - return Status::OK(); - }) - .Doc(R"doc( -Outputs all keys and values in the table. - -table_handle: Handle to the table. -keys: Vector of all keys present in the table. -values: Tensor of all values in the table. Indexed in parallel with `keys`. -)doc"); - -REGISTER_OP("LookupTableImport") - .Input("table_handle: Ref(string)") - .Input("keys: Tin") - .Input("values: Tout") - .Attr("Tin: type") - .Attr("Tout: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); - DimensionHandle unused_dim; - TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); - - // TODO(ebrevdo): Validate keys and values shape. - return Status::OK(); - }) - .Doc(R"doc( -Replaces the contents of the table with the specified keys and values. - -The tensor `keys` must be of the same type as the keys of the table. -The tensor `values` must be of the type of the table values. - -table_handle: Handle to the table. -keys: Any shape. Keys to look up. -values: Values to associate with keys. -)doc"); - -REGISTER_OP("LookupTableImportV2") - .Input("table_handle: resource") - .Input("keys: Tin") - .Input("values: Tout") - .Attr("Tin: type") - .Attr("Tout: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); - - // TODO: Validate keys and values shape. - return Status::OK(); - }) - .Doc(R"doc( -Replaces the contents of the table with the specified keys and values. - -The tensor `keys` must be of the same type as the keys of the table. -The tensor `values` must be of the type of the table values. - -table_handle: Handle to the table. -keys: Any shape. Keys to look up. -values: Values to associate with keys. -)doc"); - -REGISTER_OP("HashTable") - .Output("table_handle: Ref(string)") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("use_node_name_sharing: bool = false") - .Attr("key_dtype: type") - .Attr("value_dtype: type") - .SetIsStateful() - .SetShapeFn(TwoElementOutput) - .Doc(R"doc( -Creates a non-initialized hash table. - -This op creates a hash table, specifying the type of its keys and values. -Before using the table you will have to initialize it. After initialization the -table will be immutable. - -table_handle: Handle to a table. -container: If non-empty, this table is placed in the given container. - Otherwise, a default container is used. -shared_name: If non-empty, this table is shared under the given name across - multiple sessions. -use_node_name_sharing: If true and shared_name is empty, the table is shared - using the node name. -key_dtype: Type of the table keys. -value_dtype: Type of the table values. -)doc"); - -REGISTER_OP("HashTableV2") - .Output("table_handle: resource") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("use_node_name_sharing: bool = false") - .Attr("key_dtype: type") - .Attr("value_dtype: type") - .SetIsStateful() - .SetShapeFn(ScalarOutput) - .Doc(R"doc( -Creates a non-initialized hash table. - -This op creates a hash table, specifying the type of its keys and values. -Before using the table you will have to initialize it. After initialization the -table will be immutable. - -table_handle: Handle to a table. -container: If non-empty, this table is placed in the given container. - Otherwise, a default container is used. -shared_name: If non-empty, this table is shared under the given name across - multiple sessions. -use_node_name_sharing: If true and shared_name is empty, the table is shared - using the node name. -key_dtype: Type of the table keys. -value_dtype: Type of the table values. -)doc"); - -REGISTER_OP("MutableHashTable") - .Output("table_handle: Ref(string)") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("use_node_name_sharing: bool = false") - .Attr("key_dtype: type") - .Attr("value_dtype: type") - .SetIsStateful() - .SetShapeFn(TwoElementOutput) - .Doc(R"doc( -Creates an empty hash table. - -This op creates a mutable hash table, specifying the type of its keys and -values. Each value must be a scalar. Data can be inserted into the table using -the insert operations. It does not support the initialization operation. - -table_handle: Handle to a table. -container: If non-empty, this table is placed in the given container. - Otherwise, a default container is used. -shared_name: If non-empty, this table is shared under the given name across - multiple sessions. -use_node_name_sharing: If true and shared_name is empty, the table is shared - using the node name. -key_dtype: Type of the table keys. -value_dtype: Type of the table values. -)doc"); - -REGISTER_OP("MutableHashTableV2") - .Output("table_handle: resource") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("use_node_name_sharing: bool = false") - .Attr("key_dtype: type") - .Attr("value_dtype: type") - .SetIsStateful() - .SetShapeFn(ScalarOutput) - .Doc(R"doc( -Creates an empty hash table. - -This op creates a mutable hash table, specifying the type of its keys and -values. Each value must be a scalar. Data can be inserted into the table using -the insert operations. It does not support the initialization operation. - -table_handle: Handle to a table. -container: If non-empty, this table is placed in the given container. - Otherwise, a default container is used. -shared_name: If non-empty, this table is shared under the given name across - multiple sessions. -use_node_name_sharing: If true and shared_name is empty, the table is shared - using the node name. -key_dtype: Type of the table keys. -value_dtype: Type of the table values. -)doc"); - -REGISTER_OP("MutableHashTableOfTensors") - .Output("table_handle: Ref(string)") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("use_node_name_sharing: bool = false") - .Attr("key_dtype: type") - .Attr("value_dtype: type") - .Attr("value_shape: shape = {}") - .SetIsStateful() - .SetShapeFn(TwoElementOutput) - .Doc(R"doc( -Creates an empty hash table. - -This op creates a mutable hash table, specifying the type of its keys and -values. Each value must be a vector. Data can be inserted into the table using -the insert operations. It does not support the initialization operation. - -table_handle: Handle to a table. -container: If non-empty, this table is placed in the given container. - Otherwise, a default container is used. -shared_name: If non-empty, this table is shared under the given name across - multiple sessions. -key_dtype: Type of the table keys. -value_dtype: Type of the table values. -)doc"); - -REGISTER_OP("MutableHashTableOfTensorsV2") - .Output("table_handle: resource") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("use_node_name_sharing: bool = false") - .Attr("key_dtype: type") - .Attr("value_dtype: type") - .Attr("value_shape: shape = {}") - .SetIsStateful() - .SetShapeFn(ScalarOutput) - .Doc(R"doc( -Creates an empty hash table. - -This op creates a mutable hash table, specifying the type of its keys and -values. Each value must be a vector. Data can be inserted into the table using -the insert operations. It does not support the initialization operation. - -table_handle: Handle to a table. -container: If non-empty, this table is placed in the given container. - Otherwise, a default container is used. -shared_name: If non-empty, this table is shared under the given name across - multiple sessions. -key_dtype: Type of the table keys. -value_dtype: Type of the table values. -)doc"); - -REGISTER_OP("MutableDenseHashTable") - .Input("empty_key: key_dtype") - .Output("table_handle: Ref(string)") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("use_node_name_sharing: bool = false") - .Attr("key_dtype: type") - .Attr("value_dtype: type") - .Attr("value_shape: shape = {}") - .Attr("initial_num_buckets: int = 131072") // 2^17 - .Attr("max_load_factor: float = 0.8") - .SetIsStateful() - .SetShapeFn(TwoElementOutput) - .Doc(R"doc( -Creates an empty hash table that uses tensors as the backing store. It uses -"open addressing" with quadratic reprobing to resolve collisions. - -This op creates a mutable hash table, specifying the type of its keys and -values. Each value must be a scalar. Data can be inserted into the table using -the insert operations. It does not support the initialization operation. - -empty_key: The key used to represent empty key buckets internally. Must not - be used in insert or lookup operations. -table_handle: Handle to a table. -container: If non-empty, this table is placed in the given container. - Otherwise, a default container is used. -shared_name: If non-empty, this table is shared under the given name across - multiple sessions. -key_dtype: Type of the table keys. -value_dtype: Type of the table values. -value_shape: The shape of each value. -initial_num_buckets: The initial number of hash table buckets. Must be a power - to 2. -max_load_factor: The maximum ratio between number of entries and number of - buckets before growing the table. Must be between 0 and 1. -)doc"); - -REGISTER_OP("MutableDenseHashTableV2") - .Input("empty_key: key_dtype") - .Output("table_handle: resource") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("use_node_name_sharing: bool = false") - .Attr("key_dtype: type") - .Attr("value_dtype: type") - .Attr("value_shape: shape = {}") - .Attr("initial_num_buckets: int = 131072") // 2^17 - .Attr("max_load_factor: float = 0.8") - .SetIsStateful() - .SetShapeFn(ScalarOutput) - .Doc(R"doc( -Creates an empty hash table that uses tensors as the backing store. It uses -"open addressing" with quadratic reprobing to resolve collisions. - -This op creates a mutable hash table, specifying the type of its keys and -values. Each value must be a scalar. Data can be inserted into the table using -the insert operations. It does not support the initialization operation. - -empty_key: The key used to represent empty key buckets internally. Must not - be used in insert or lookup operations. -table_handle: Handle to a table. -container: If non-empty, this table is placed in the given container. - Otherwise, a default container is used. -shared_name: If non-empty, this table is shared under the given name across - multiple sessions. -key_dtype: Type of the table keys. -value_dtype: Type of the table values. -value_shape: The shape of each value. -initial_num_buckets: The initial number of hash table buckets. Must be a power - to 2. -max_load_factor: The maximum ratio between number of entries and number of - buckets before growing the table. Must be between 0 and 1. -)doc"); - -REGISTER_OP("InitializeTable") - .Input("table_handle: Ref(string)") - .Input("keys: Tkey") - .Input("values: Tval") - .Attr("Tkey: type") - .Attr("Tval: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); - DimensionHandle unused_dim; - TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); - - ShapeHandle keys; - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys)); - TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys)); - return Status::OK(); - }) - .Doc(R"doc( -Table initializer that takes two tensors for keys and values respectively. - -table_handle: Handle to a table which will be initialized. -keys: Keys of type Tkey. -values: Values of type Tval. -)doc"); - -REGISTER_OP("InitializeTableV2") - .Input("table_handle: resource") - .Input("keys: Tkey") - .Input("values: Tval") - .Attr("Tkey: type") - .Attr("Tval: type") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); - - ShapeHandle keys; - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys)); - TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys)); - return Status::OK(); - }) - .Doc(R"doc( -Table initializer that takes two tensors for keys and values respectively. - -table_handle: Handle to a table which will be initialized. -keys: Keys of type Tkey. -values: Values of type Tval. -)doc"); - -REGISTER_OP("InitializeTableFromTextFile") - .Input("table_handle: Ref(string)") - .Input("filename: string") - .Attr("key_index: int >= -2") - .Attr("value_index: int >= -2") - .Attr("vocab_size: int >= -1 = -1") - .Attr("delimiter: string = '\t'") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); - DimensionHandle unused_dim; - TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); - - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle)); - return Status::OK(); - }) - .Doc(R"doc( -Initializes a table from a text file. - -It inserts one key-value pair into the table for each line of the file. -The key and value is extracted from the whole line content, elements from the -split line based on `delimiter` or the line number (starting from zero). -Where to extract the key and value from a line is specified by `key_index` and -`value_index`. - -- A value of -1 means use the line number(starting from zero), expects `int64`. -- A value of -2 means use the whole line content, expects `string`. -- A value >= 0 means use the index (starting at zero) of the split line based - on `delimiter`. - -table_handle: Handle to a table which will be initialized. -filename: Filename of a vocabulary text file. -key_index: Column index in a line to get the table `key` values from. -value_index: Column index that represents information of a line to get the table - `value` values from. -vocab_size: Number of elements of the file, use -1 if unknown. -delimiter: Delimiter to separate fields in a line. -)doc"); - -REGISTER_OP("InitializeTableFromTextFileV2") - .Input("table_handle: resource") - .Input("filename: string") - .Attr("key_index: int >= -2") - .Attr("value_index: int >= -2") - .Attr("vocab_size: int >= -1 = -1") - .Attr("delimiter: string = '\t'") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); - - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle)); - return Status::OK(); - }) - .Doc(R"doc( -Initializes a table from a text file. - -It inserts one key-value pair into the table for each line of the file. -The key and value is extracted from the whole line content, elements from the -split line based on `delimiter` or the line number (starting from zero). -Where to extract the key and value from a line is specified by `key_index` and -`value_index`. - -- A value of -1 means use the line number(starting from zero), expects `int64`. -- A value of -2 means use the whole line content, expects `string`. -- A value >= 0 means use the index (starting at zero) of the split line based - on `delimiter`. - -table_handle: Handle to a table which will be initialized. -filename: Filename of a vocabulary text file. -key_index: Column index in a line to get the table `key` values from. -value_index: Column index that represents information of a line to get the table - `value` values from. -vocab_size: Number of elements of the file, use -1 if unknown. -delimiter: Delimiter to separate fields in a line. -)doc"); - REGISTER_OP("GetSessionHandle") .Input("value: T") .Output("handle: string") diff --git a/tensorflow/core/ops/lookup_ops.cc b/tensorflow/core/ops/lookup_ops.cc new file mode 100644 index 00000000000..498a65690d0 --- /dev/null +++ b/tensorflow/core/ops/lookup_ops.cc @@ -0,0 +1,666 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + +// -------------------------------------------------------------------------- + +namespace { +Status TwoElementVectorInputsAndScalarOutputs(InferenceContext* c) { + ShapeHandle handle; + DimensionHandle unused_handle; + for (int i = 0; i < c->num_inputs(); ++i) { + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle)); + } + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->Scalar()); + } + return Status::OK(); +} + +Status ScalarAndTwoElementVectorInputsAndScalarOutputs(InferenceContext* c) { + ShapeHandle handle; + DimensionHandle unused_handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + for (int i = 1; i < c->num_inputs(); ++i) { + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle)); + } + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->Scalar()); + } + return Status::OK(); +} + +Status TwoElementOutput(InferenceContext* c) { + c->set_output(0, c->Vector(2)); + return Status::OK(); +} + +Status ScalarOutput(InferenceContext* c) { + c->set_output(0, c->Scalar()); + return Status::OK(); +} +} // namespace + +REGISTER_OP("LookupTableFind") + .Input("table_handle: Ref(string)") + .Input("keys: Tin") + .Input("default_value: Tout") + .Output("values: Tout") + .Attr("Tin: type") + .Attr("Tout: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + + // Default value must be scalar or vector. + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused)); + c->set_output(0, c->UnknownShape()); + return Status::OK(); + }) + .Doc(R"doc( +Looks up keys in a table, outputs the corresponding values. + +The tensor `keys` must of the same type as the keys of the table. +The output `values` is of the type of the table values. + +The scalar `default_value` is the value output for keys not present in the +table. It must also be of the same type as the table values. + +table_handle: Handle to the table. +keys: Any shape. Keys to look up. +values: Same shape as `keys`. Values found in the table, or `default_values` + for missing keys. +)doc"); + +REGISTER_OP("LookupTableFindV2") + .Input("table_handle: resource") + .Input("keys: Tin") + .Input("default_value: Tout") + .Output("values: Tout") + .Attr("Tin: type") + .Attr("Tout: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + + // Default value must be scalar or vector. + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused)); + c->set_output(0, c->UnknownShape()); + return Status::OK(); + }) + .Doc(R"doc( +Looks up keys in a table, outputs the corresponding values. + +The tensor `keys` must of the same type as the keys of the table. +The output `values` is of the type of the table values. + +The scalar `default_value` is the value output for keys not present in the +table. It must also be of the same type as the table values. + +table_handle: Handle to the table. +keys: Any shape. Keys to look up. +values: Same shape as `keys`. Values found in the table, or `default_values` + for missing keys. +)doc"); + +REGISTER_OP("LookupTableInsert") + .Input("table_handle: Ref(string)") + .Input("keys: Tin") + .Input("values: Tout") + .Attr("Tin: type") + .Attr("Tout: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + + // TODO(ebrevdo): Validate keys and values shape. + return Status::OK(); + }) + .Doc(R"doc( +Updates the table to associates keys with values. + +The tensor `keys` must be of the same type as the keys of the table. +The tensor `values` must be of the type of the table values. + +table_handle: Handle to the table. +keys: Any shape. Keys to look up. +values: Values to associate with keys. +)doc"); + +REGISTER_OP("LookupTableInsertV2") + .Input("table_handle: resource") + .Input("keys: Tin") + .Input("values: Tout") + .Attr("Tin: type") + .Attr("Tout: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + + // TODO: Validate keys and values shape. + return Status::OK(); + }) + .Doc(R"doc( +Updates the table to associates keys with values. + +The tensor `keys` must be of the same type as the keys of the table. +The tensor `values` must be of the type of the table values. + +table_handle: Handle to the table. +keys: Any shape. Keys to look up. +values: Values to associate with keys. +)doc"); + +REGISTER_OP("LookupTableSize") + .Input("table_handle: Ref(string)") + .Output("size: int64") + .SetShapeFn(TwoElementVectorInputsAndScalarOutputs) + .Doc(R"doc( +Computes the number of elements in the given table. + +table_handle: Handle to the table. +size: Scalar that contains number of elements in the table. +)doc"); + +REGISTER_OP("LookupTableSizeV2") + .Input("table_handle: resource") + .Output("size: int64") + .SetShapeFn(ScalarAndTwoElementVectorInputsAndScalarOutputs) + .Doc(R"doc( +Computes the number of elements in the given table. + +table_handle: Handle to the table. +size: Scalar that contains number of elements in the table. +)doc"); + +REGISTER_OP("LookupTableExport") + .Input("table_handle: Ref(string)") + .Output("keys: Tkeys") + .Output("values: Tvalues") + .Attr("Tkeys: type") + .Attr("Tvalues: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + + ShapeHandle values = c->UnknownShape(); + TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values)); + ShapeHandle keys = c->Vector(c->Dim(values, 0)); + c->set_output(0, keys); + c->set_output(1, values); + return Status::OK(); + }) + .Doc(R"doc( +Outputs all keys and values in the table. + +table_handle: Handle to the table. +keys: Vector of all keys present in the table. +values: Tensor of all values in the table. Indexed in parallel with `keys`. +)doc"); + +REGISTER_OP("LookupTableExportV2") + .Input("table_handle: resource") + .Output("keys: Tkeys") + .Output("values: Tvalues") + .Attr("Tkeys: type") + .Attr("Tvalues: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + + ShapeHandle values = c->UnknownShape(); + TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values)); + ShapeHandle keys = c->Vector(c->Dim(values, 0)); + c->set_output(0, keys); + c->set_output(1, values); + return Status::OK(); + }) + .Doc(R"doc( +Outputs all keys and values in the table. + +table_handle: Handle to the table. +keys: Vector of all keys present in the table. +values: Tensor of all values in the table. Indexed in parallel with `keys`. +)doc"); + +REGISTER_OP("LookupTableImport") + .Input("table_handle: Ref(string)") + .Input("keys: Tin") + .Input("values: Tout") + .Attr("Tin: type") + .Attr("Tout: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + + // TODO(ebrevdo): Validate keys and values shape. + return Status::OK(); + }) + .Doc(R"doc( +Replaces the contents of the table with the specified keys and values. + +The tensor `keys` must be of the same type as the keys of the table. +The tensor `values` must be of the type of the table values. + +table_handle: Handle to the table. +keys: Any shape. Keys to look up. +values: Values to associate with keys. +)doc"); + +REGISTER_OP("LookupTableImportV2") + .Input("table_handle: resource") + .Input("keys: Tin") + .Input("values: Tout") + .Attr("Tin: type") + .Attr("Tout: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + + // TODO: Validate keys and values shape. + return Status::OK(); + }) + .Doc(R"doc( +Replaces the contents of the table with the specified keys and values. + +The tensor `keys` must be of the same type as the keys of the table. +The tensor `values` must be of the type of the table values. + +table_handle: Handle to the table. +keys: Any shape. Keys to look up. +values: Values to associate with keys. +)doc"); + +REGISTER_OP("HashTable") + .Output("table_handle: Ref(string)") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .SetIsStateful() + .SetShapeFn(TwoElementOutput) + .Doc(R"doc( +Creates a non-initialized hash table. + +This op creates a hash table, specifying the type of its keys and values. +Before using the table you will have to initialize it. After initialization the +table will be immutable. + +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +use_node_name_sharing: If true and shared_name is empty, the table is shared + using the node name. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +)doc"); + +REGISTER_OP("HashTableV2") + .Output("table_handle: resource") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .SetIsStateful() + .SetShapeFn(ScalarOutput) + .Doc(R"doc( +Creates a non-initialized hash table. + +This op creates a hash table, specifying the type of its keys and values. +Before using the table you will have to initialize it. After initialization the +table will be immutable. + +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +use_node_name_sharing: If true and shared_name is empty, the table is shared + using the node name. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +)doc"); + +REGISTER_OP("MutableHashTable") + .Output("table_handle: Ref(string)") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .SetIsStateful() + .SetShapeFn(TwoElementOutput) + .Doc(R"doc( +Creates an empty hash table. + +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a scalar. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +use_node_name_sharing: If true and shared_name is empty, the table is shared + using the node name. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +)doc"); + +REGISTER_OP("MutableHashTableV2") + .Output("table_handle: resource") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .SetIsStateful() + .SetShapeFn(ScalarOutput) + .Doc(R"doc( +Creates an empty hash table. + +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a scalar. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +use_node_name_sharing: If true and shared_name is empty, the table is shared + using the node name. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +)doc"); + +REGISTER_OP("MutableHashTableOfTensors") + .Output("table_handle: Ref(string)") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .Attr("value_shape: shape = {}") + .SetIsStateful() + .SetShapeFn(TwoElementOutput) + .Doc(R"doc( +Creates an empty hash table. + +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a vector. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +)doc"); + +REGISTER_OP("MutableHashTableOfTensorsV2") + .Output("table_handle: resource") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .Attr("value_shape: shape = {}") + .SetIsStateful() + .SetShapeFn(ScalarOutput) + .Doc(R"doc( +Creates an empty hash table. + +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a vector. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +)doc"); + +REGISTER_OP("MutableDenseHashTable") + .Input("empty_key: key_dtype") + .Output("table_handle: Ref(string)") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .Attr("value_shape: shape = {}") + .Attr("initial_num_buckets: int = 131072") // 2^17 + .Attr("max_load_factor: float = 0.8") + .SetIsStateful() + .SetShapeFn(TwoElementOutput) + .Doc(R"doc( +Creates an empty hash table that uses tensors as the backing store. It uses +"open addressing" with quadratic reprobing to resolve collisions. + +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a scalar. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + +empty_key: The key used to represent empty key buckets internally. Must not + be used in insert or lookup operations. +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +value_shape: The shape of each value. +initial_num_buckets: The initial number of hash table buckets. Must be a power + to 2. +max_load_factor: The maximum ratio between number of entries and number of + buckets before growing the table. Must be between 0 and 1. +)doc"); + +REGISTER_OP("MutableDenseHashTableV2") + .Input("empty_key: key_dtype") + .Output("table_handle: resource") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("use_node_name_sharing: bool = false") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .Attr("value_shape: shape = {}") + .Attr("initial_num_buckets: int = 131072") // 2^17 + .Attr("max_load_factor: float = 0.8") + .SetIsStateful() + .SetShapeFn(ScalarOutput) + .Doc(R"doc( +Creates an empty hash table that uses tensors as the backing store. It uses +"open addressing" with quadratic reprobing to resolve collisions. + +This op creates a mutable hash table, specifying the type of its keys and +values. Each value must be a scalar. Data can be inserted into the table using +the insert operations. It does not support the initialization operation. + +empty_key: The key used to represent empty key buckets internally. Must not + be used in insert or lookup operations. +table_handle: Handle to a table. +container: If non-empty, this table is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this table is shared under the given name across + multiple sessions. +key_dtype: Type of the table keys. +value_dtype: Type of the table values. +value_shape: The shape of each value. +initial_num_buckets: The initial number of hash table buckets. Must be a power + to 2. +max_load_factor: The maximum ratio between number of entries and number of + buckets before growing the table. Must be between 0 and 1. +)doc"); + +REGISTER_OP("InitializeTable") + .Input("table_handle: Ref(string)") + .Input("keys: Tkey") + .Input("values: Tval") + .Attr("Tkey: type") + .Attr("Tval: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + + ShapeHandle keys; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys)); + TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys)); + return Status::OK(); + }) + .Doc(R"doc( +Table initializer that takes two tensors for keys and values respectively. + +table_handle: Handle to a table which will be initialized. +keys: Keys of type Tkey. +values: Values of type Tval. +)doc"); + +REGISTER_OP("InitializeTableV2") + .Input("table_handle: resource") + .Input("keys: Tkey") + .Input("values: Tval") + .Attr("Tkey: type") + .Attr("Tval: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + + ShapeHandle keys; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys)); + TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys)); + return Status::OK(); + }) + .Doc(R"doc( +Table initializer that takes two tensors for keys and values respectively. + +table_handle: Handle to a table which will be initialized. +keys: Keys of type Tkey. +values: Values of type Tval. +)doc"); + +REGISTER_OP("InitializeTableFromTextFile") + .Input("table_handle: Ref(string)") + .Input("filename: string") + .Attr("key_index: int >= -2") + .Attr("value_index: int >= -2") + .Attr("vocab_size: int >= -1 = -1") + .Attr("delimiter: string = '\t'") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle)); + return Status::OK(); + }) + .Doc(R"doc( +Initializes a table from a text file. + +It inserts one key-value pair into the table for each line of the file. +The key and value is extracted from the whole line content, elements from the +split line based on `delimiter` or the line number (starting from zero). +Where to extract the key and value from a line is specified by `key_index` and +`value_index`. + +- A value of -1 means use the line number(starting from zero), expects `int64`. +- A value of -2 means use the whole line content, expects `string`. +- A value >= 0 means use the index (starting at zero) of the split line based + on `delimiter`. + +table_handle: Handle to a table which will be initialized. +filename: Filename of a vocabulary text file. +key_index: Column index in a line to get the table `key` values from. +value_index: Column index that represents information of a line to get the table + `value` values from. +vocab_size: Number of elements of the file, use -1 if unknown. +delimiter: Delimiter to separate fields in a line. +)doc"); + +REGISTER_OP("InitializeTableFromTextFileV2") + .Input("table_handle: resource") + .Input("filename: string") + .Attr("key_index: int >= -2") + .Attr("value_index: int >= -2") + .Attr("vocab_size: int >= -1 = -1") + .Attr("delimiter: string = '\t'") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); + + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle)); + return Status::OK(); + }) + .Doc(R"doc( +Initializes a table from a text file. + +It inserts one key-value pair into the table for each line of the file. +The key and value is extracted from the whole line content, elements from the +split line based on `delimiter` or the line number (starting from zero). +Where to extract the key and value from a line is specified by `key_index` and +`value_index`. + +- A value of -1 means use the line number(starting from zero), expects `int64`. +- A value of -2 means use the whole line content, expects `string`. +- A value >= 0 means use the index (starting at zero) of the split line based + on `delimiter`. + +table_handle: Handle to a table which will be initialized. +filename: Filename of a vocabulary text file. +key_index: Column index in a line to get the table `key` values from. +value_index: Column index that represents information of a line to get the table + `value` values from. +vocab_size: Number of elements of the file, use -1 if unknown. +delimiter: Delimiter to separate fields in a line. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 817d157da29..9fd5ada71ee 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1022,7 +1022,6 @@ tf_gen_op_wrapper_private_py( require_shape_functions = True, visibility = [ "//learning/brain/python/ops:__pkg__", - "//tensorflow/python/feature_column:__pkg__", "//tensorflow/python/kernel_tests:__pkg__", ], ) @@ -1057,6 +1056,16 @@ tf_gen_op_wrapper_private_py( ], ) +tf_gen_op_wrapper_private_py( + name = "lookup_ops_gen", + require_shape_functions = True, + visibility = [ + "//learning/brain/python/ops:__pkg__", + "//tensorflow/python/feature_column:__pkg__", + "//tensorflow/python/kernel_tests:__pkg__", + ], +) + tf_gen_op_wrapper_private_py( name = "math_ops_gen", require_shape_functions = True, @@ -1474,6 +1483,20 @@ py_library( ], ) +py_library( + name = "lookup_ops", + srcs = ["ops/lookup_ops.py"], + srcs_version = "PY2AND3", + deps = [ + ":array_ops", + ":framework", + ":framework_for_generated_wrappers", + ":lookup_ops_gen", + ":math_ops", + "@six_archive//:six", + ], +) + py_library( name = "math_grad", srcs = ["ops/math_grad.py"], @@ -1862,6 +1885,7 @@ py_library( ":io_ops", ":linalg_ops", ":logging_ops", + ":lookup_ops", ":math_grad", ":math_ops", ":numerics", @@ -2269,6 +2293,7 @@ py_library( ":io_ops", ":io_ops_gen", ":lib", + ":lookup_ops", ":math_ops", ":platform", ":protos_all_py", @@ -2991,6 +3016,7 @@ cuda_py_tests( ":framework", ":framework_for_generated_wrappers", ":framework_test_lib", + ":lookup_ops", ":gradients", ":math_ops", ":nn_grad", @@ -3021,7 +3047,7 @@ py_library( srcs = ["training/saver_test_utils.py"], srcs_version = "PY2AND3", deps = [ - ":data_flow_ops_gen", + ":lookup_ops_gen", ":training", ], ) diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index f70c285f049..b8064f0a776 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -42,8 +42,8 @@ from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import init_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import state_ops @@ -1391,9 +1391,10 @@ class EstimatorExportTest(test.TestCase): my_int = variables.Variable(1, name='my_int', collections=[ops.GraphKeys.LOCAL_VARIABLES]) scores = constant_op.constant([3.]) - with ops.control_dependencies( - [variables.local_variables_initializer(), - data_flow_ops.tables_initializer()]): + with ops.control_dependencies([ + variables.local_variables_initializer(), + lookup_ops.tables_initializer() + ]): assign_op = state_ops.assign(my_int, 12345) # local_initSop must be an Operation, not a Tensor. diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD index d7342738457..ac7aef96ac1 100644 --- a/tensorflow/python/feature_column/BUILD +++ b/tensorflow/python/feature_column/BUILD @@ -80,9 +80,9 @@ py_library( deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", - "//tensorflow/python:data_flow_ops_gen", "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:lookup_ops_gen", "//tensorflow/python:math_ops", "//tensorflow/python:string_ops", "//tensorflow/python:training", diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index d85142abcfb..ad67a082dc9 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -31,7 +31,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor -from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib @@ -41,7 +41,7 @@ from tensorflow.python.platform import test def _initialized_session(): sess = session.Session() sess.run(variables_lib.global_variables_initializer()) - sess.run(data_flow_ops.tables_initializer()) + sess.run(lookup_ops.tables_initializer()) return sess @@ -1277,7 +1277,7 @@ class VocabularyCategoricalColumnTest(test.TestCase): # pylint: enable=protected-access with self.assertRaisesRegexp(errors.OpError, 'file_does_not_exist'): with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() def test_invalid_vocabulary_size(self): with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'): @@ -1307,7 +1307,7 @@ class VocabularyCategoricalColumnTest(test.TestCase): # pylint: enable=protected-access with self.assertRaisesRegexp(errors.OpError, 'Invalid vocab_size'): with self.test_session(): - data_flow_ops.tables_initializer().run() + lookup_ops.tables_initializer().run() def test_invalid_num_oov_buckets(self): with self.assertRaisesRegexp(ValueError, 'Invalid num_oov_buckets'): diff --git a/tensorflow/python/feature_column/lookup_ops.py b/tensorflow/python/feature_column/lookup_ops.py index 13a67fa5183..8225b47b204 100644 --- a/tensorflow/python/feature_column/lookup_ops.py +++ b/tensorflow/python/feature_column/lookup_ops.py @@ -27,7 +27,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import gen_data_flow_ops +from tensorflow.python.ops import gen_lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import string_ops from tensorflow.python.training.saver import BaseSaverBuilder @@ -151,7 +151,7 @@ class InitializableLookupTableBase(LookupInterface): with ops.name_scope(name, "%s_Size" % self._name, [self._table_ref]) as scope: # pylint: disable=protected-access - return gen_data_flow_ops._lookup_table_size(self._table_ref, name=scope) + return gen_lookup_ops._lookup_table_size(self._table_ref, name=scope) # pylint: enable=protected-access def lookup(self, keys, name=None): @@ -182,7 +182,7 @@ class InitializableLookupTableBase(LookupInterface): name, "%s_Lookup" % self._name, (self._table_ref, key_tensor, self._default_value)) as scope: # pylint: disable=protected-access - values = gen_data_flow_ops._lookup_table_find( + values = gen_lookup_ops._lookup_table_find( self._table_ref, key_tensor, self._default_value, name=scope) # pylint: enable=protected-access @@ -229,7 +229,7 @@ class HashTable(InitializableLookupTableBase): with ops.name_scope( name, "hash_table", (initializer, default_value)) as scope: # pylint: disable=protected-access - table_ref = gen_data_flow_ops._hash_table( + table_ref = gen_lookup_ops._hash_table( shared_name=shared_name, key_dtype=initializer.key_dtype, value_dtype=initializer.value_dtype, @@ -308,10 +308,8 @@ class KeyValueTensorInitializer(TableInitializerBase): self._name, values=(table.table_ref, self._keys, self._values)) as scope: # pylint: disable=protected-access - init_op = gen_data_flow_ops._initialize_table(table.table_ref, - self._keys, - self._values, - name=scope) + init_op = gen_lookup_ops._initialize_table( + table.table_ref, self._keys, self._values, name=scope) # pylint: enable=protected-access ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) return init_op @@ -477,7 +475,7 @@ class TextFileInitializer(TableInitializerBase): dtypes.string, name="asset_filepath") # pylint: disable=protected-access - init_op = gen_data_flow_ops._initialize_table_from_text_file( + init_op = gen_lookup_ops._initialize_table_from_text_file( table.table_ref, filename, self._key_index, @@ -1333,14 +1331,14 @@ class MutableHashTable(LookupInterface): use_node_name_sharing = checkpoint and shared_name is None # pylint: disable=protected-access if self._default_value.get_shape().ndims == 0: - self._table_ref = gen_data_flow_ops._mutable_hash_table( + self._table_ref = gen_lookup_ops._mutable_hash_table( shared_name=shared_name, use_node_name_sharing=use_node_name_sharing, key_dtype=key_dtype, value_dtype=value_dtype, name=name) else: - self._table_ref = gen_data_flow_ops._mutable_hash_table_of_tensors( + self._table_ref = gen_lookup_ops._mutable_hash_table_of_tensors( shared_name=shared_name, use_node_name_sharing=use_node_name_sharing, key_dtype=key_dtype, @@ -1368,7 +1366,7 @@ class MutableHashTable(LookupInterface): with ops.name_scope(name, "%s_Size" % self._name, [self._table_ref]) as name: # pylint: disable=protected-access - return gen_data_flow_ops._lookup_table_size(self._table_ref, name=name) + return gen_lookup_ops._lookup_table_size(self._table_ref, name=name) def lookup(self, keys, name=None): """Looks up `keys` in a table, outputs the corresponding values. @@ -1394,10 +1392,8 @@ class MutableHashTable(LookupInterface): with ops.name_scope(name, "%s_lookup_table_find" % self._name, (self._table_ref, keys, self._default_value)) as name: # pylint: disable=protected-access - values = gen_data_flow_ops._lookup_table_find(self._table_ref, - keys, - self._default_value, - name=name) + values = gen_lookup_ops._lookup_table_find( + self._table_ref, keys, self._default_value, name=name) values.set_shape(keys.get_shape().concatenate(self._value_shape)) return values @@ -1423,7 +1419,7 @@ class MutableHashTable(LookupInterface): with ops.name_scope(name, "%s_lookup_table_insert" % self._name, [self._table_ref, keys, values]) as name: # pylint: disable=protected-access - op = gen_data_flow_ops._lookup_table_insert( + op = gen_lookup_ops._lookup_table_insert( self._table_ref, keys, values, name=name) return op @@ -1440,11 +1436,8 @@ class MutableHashTable(LookupInterface): with ops.name_scope(name, "%s_lookup_table_export_values" % self._name, [self._table_ref]) as name: # pylint: disable=protected-access - exported_keys, exported_values = gen_data_flow_ops._lookup_table_export( - self._table_ref, - self._key_dtype, - self._value_dtype, - name=name) + exported_keys, exported_values = gen_lookup_ops._lookup_table_export( + self._table_ref, self._key_dtype, self._value_dtype, name=name) exported_values.set_shape(exported_keys.get_shape().concatenate( self._value_shape)) @@ -1464,7 +1457,7 @@ class MutableHashTable(LookupInterface): def restore(self, restored_tensors, unused_restored_shapes): # pylint: disable=protected-access - return gen_data_flow_ops._lookup_table_import( + return gen_lookup_ops._lookup_table_import( self.op._table_ref, restored_tensors[0], restored_tensors[1]) @@ -1539,7 +1532,7 @@ class MutableDenseHashTable(LookupInterface): use_node_name_sharing = checkpoint and shared_name is None empty_key = ops.convert_to_tensor(empty_key, dtype=key_dtype) # pylint: disable=protected-access - self._table_ref = gen_data_flow_ops._mutable_dense_hash_table( + self._table_ref = gen_lookup_ops._mutable_dense_hash_table( empty_key=empty_key, shared_name=shared_name, use_node_name_sharing=use_node_name_sharing, @@ -1567,7 +1560,7 @@ class MutableDenseHashTable(LookupInterface): with ops.name_scope(name, "%s_Size" % self._name, [self._table_ref]) as name: # pylint: disable=protected-access - return gen_data_flow_ops._lookup_table_size(self._table_ref, name=name) + return gen_lookup_ops._lookup_table_size(self._table_ref, name=name) def lookup(self, keys, name=None): """Looks up `keys` in a table, outputs the corresponding values. @@ -1593,7 +1586,7 @@ class MutableDenseHashTable(LookupInterface): with ops.name_scope(name, "%s_lookup_table_find" % self._name, [self._table_ref, keys]) as name: # pylint: disable=protected-access - values = gen_data_flow_ops._lookup_table_find( + values = gen_lookup_ops._lookup_table_find( self._table_ref, keys, self._default_value, name=name) if keys.get_shape().ndims is not None and keys.get_shape().ndims > 0: @@ -1623,7 +1616,7 @@ class MutableDenseHashTable(LookupInterface): with ops.name_scope(name, "%s_lookup_table_insert" % self._name, [self._table_ref, keys, values]) as name: # pylint: disable=protected-access - op = gen_data_flow_ops._lookup_table_insert( + op = gen_lookup_ops._lookup_table_insert( self._table_ref, keys, values, name=name) return op @@ -1640,7 +1633,7 @@ class MutableDenseHashTable(LookupInterface): with ops.name_scope(name, "%s_lookup_table_export_values" % self._name, [self._table_ref]) as name: # pylint: disable=protected-access - exported_keys, exported_values = gen_data_flow_ops._lookup_table_export( + exported_keys, exported_values = gen_lookup_ops._lookup_table_export( self._table_ref, self._key_dtype, self._value_dtype, name=name) exported_values.set_shape(exported_keys.get_shape().concatenate( @@ -1661,6 +1654,5 @@ class MutableDenseHashTable(LookupInterface): def restore(self, restored_tensors, unused_restored_shapes): # pylint: disable=protected-access - return gen_data_flow_ops._lookup_table_import(self.op._table_ref, - restored_tensors[0], - restored_tensors[1]) + return gen_lookup_ops._lookup_table_import( + self.op._table_ref, restored_tensors[0], restored_tensors[1]) diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py index 95e803e2aa0..9a208613add 100644 --- a/tensorflow/python/ops/data_flow_ops.py +++ b/tensorflow/python/ops/data_flow_ops.py @@ -38,7 +38,6 @@ from tensorflow.python.ops import math_ops # pylint: disable=wildcard-import from tensorflow.python.ops.gen_data_flow_ops import * # pylint: enable=wildcard-import -from tensorflow.python.util.deprecation import deprecated def _as_type_list(dtypes): @@ -1037,47 +1036,6 @@ class Barrier(object): self._barrier_ref, name=name) -@deprecated("2017-03-02", "Use `tf.tables_initializer` instead.") -def initialize_all_tables(name="init_all_tables"): - """Returns an Op that initializes all tables of the default graph. - - Args: - name: Optional name for the initialization op. - - Returns: - An Op that initializes all tables. Note that if there are - not tables the returned Op is a NoOp. - """ - return tables_initializer(name) - - -def tables_initializer(name="init_all_tables"): - """Returns an Op that initializes all tables of the default graph. - - Args: - name: Optional name for the initialization op. - - Returns: - An Op that initializes all tables. Note that if there are - not tables the returned Op is a NoOp. - """ - initializers = ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS) - if initializers: - return control_flow_ops.group(*initializers, name=name) - return control_flow_ops.no_op(name=name) - - -ops.NotDifferentiable("LookupTableFind") -ops.NotDifferentiable("LookupTableInsert") -ops.NotDifferentiable("LookupTableSize") -ops.NotDifferentiable("HashTable") -ops.NotDifferentiable("InitializeTable") -ops.NotDifferentiable("InitializeTableFromTextFile") -ops.NotDifferentiable("MutableDenseHashTable") -ops.NotDifferentiable("MutableHashTable") -ops.NotDifferentiable("MutableHashTableOfTensors") - - class ConditionalAccumulatorBase(object): """A conditional accumulator for aggregating gradients. diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py new file mode 100644 index 00000000000..54dba9e38eb --- /dev/null +++ b/tensorflow/python/ops/lookup_ops.py @@ -0,0 +1,77 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#============================================================================== +"""Data Flow Operations.""" +# pylint: disable=g-bad-name +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.python.ops.gen_lookup_ops import * +# pylint: enable=wildcard-import +from tensorflow.python.util.deprecation import deprecated + + +@deprecated("2017-03-02", "Use `tf.tables_initializer` instead.") +def initialize_all_tables(name="init_all_tables"): + """Returns an Op that initializes all tables of the default graph. + + Args: + name: Optional name for the initialization op. + + Returns: + An Op that initializes all tables. Note that if there are + not tables the returned Op is a NoOp. + """ + return tables_initializer(name) + + +def tables_initializer(name="init_all_tables"): + """Returns an Op that initializes all tables of the default graph. + + Args: + name: Optional name for the initialization op. + + Returns: + An Op that initializes all tables. Note that if there are + not tables the returned Op is a NoOp. + """ + initializers = ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS) + if initializers: + return control_flow_ops.group(*initializers, name=name) + return control_flow_ops.no_op(name=name) + + +ops.NotDifferentiable("LookupTableFind") +ops.NotDifferentiable("LookupTableFindV2") +ops.NotDifferentiable("LookupTableInsert") +ops.NotDifferentiable("LookupTableInsertV2") +ops.NotDifferentiable("LookupTableSize") +ops.NotDifferentiable("LookupTableSizeV2") +ops.NotDifferentiable("HashTable") +ops.NotDifferentiable("HashTableV2") +ops.NotDifferentiable("InitializeTable") +ops.NotDifferentiable("InitializeTableV2") +ops.NotDifferentiable("InitializeTableFromTextFile") +ops.NotDifferentiable("InitializeTableFromTextFileV2") +ops.NotDifferentiable("MutableDenseHashTable") +ops.NotDifferentiable("MutableDenseHashTableV2") +ops.NotDifferentiable("MutableHashTable") +ops.NotDifferentiable("MutableHashTableV2") +ops.NotDifferentiable("MutableHashTableOfTensors") +ops.NotDifferentiable("MutableHashTableOfTensorsV2") diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py index 09e04d4247c..a39d28490cc 100644 --- a/tensorflow/python/ops/standard_ops.py +++ b/tensorflow/python/ops/standard_ops.py @@ -57,6 +57,7 @@ from tensorflow.python.ops.io_ops import * from tensorflow.python.ops.linalg_ops import * from tensorflow.python.ops.logging_ops import Print from tensorflow.python.ops.logging_ops import get_summary_op +from tensorflow.python.ops.lookup_ops import * from tensorflow.python.ops.math_ops import * from tensorflow.python.ops.numerics import * from tensorflow.python.ops.parsing_ops import * diff --git a/tensorflow/python/saved_model/main_op_impl.py b/tensorflow/python/saved_model/main_op_impl.py index 66cf9d4d8af..355fd57bf1d 100644 --- a/tensorflow/python/saved_model/main_op_impl.py +++ b/tensorflow/python/saved_model/main_op_impl.py @@ -20,7 +20,7 @@ from __future__ import print_function from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import data_flow_ops as tf_data_flow_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables @@ -35,7 +35,7 @@ def main_op(): """ init = variables.global_variables_initializer() init_local = variables.local_variables_initializer() - init_tables = tf_data_flow_ops.tables_initializer() + init_tables = lookup_ops.tables_initializer() return control_flow_ops.group(init, init_local, init_tables) diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index 4c81af56adb..fcec3ed97c7 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -26,7 +26,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import resources from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging @@ -238,7 +238,7 @@ class Scaffold(object): @staticmethod def _default_local_init_op(): return control_flow_ops.group(variables.local_variables_initializer(), - data_flow_ops.tables_initializer()) + lookup_ops.tables_initializer()) def MonitoredTrainingSession(master='', # pylint: disable=invalid-name diff --git a/tensorflow/python/training/saver_test_utils.py b/tensorflow/python/training/saver_test_utils.py index 5f31e2aa539..6a73565f82b 100644 --- a/tensorflow/python/training/saver_test_utils.py +++ b/tensorflow/python/training/saver_test_utils.py @@ -20,7 +20,7 @@ from __future__ import print_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as ops_lib -from tensorflow.python.ops import gen_data_flow_ops +from tensorflow.python.ops import gen_lookup_ops from tensorflow.python.training import saver as saver_module @@ -34,7 +34,7 @@ class CheckpointedOp(object): # pylint: disable=protected-access def __init__(self, name, table_ref=None): if table_ref is None: - self.table_ref = gen_data_flow_ops._mutable_hash_table( + self.table_ref = gen_lookup_ops._mutable_hash_table( key_dtype=dtypes.string, value_dtype=dtypes.float32, name=name) else: self.table_ref = table_ref @@ -52,10 +52,10 @@ class CheckpointedOp(object): return self._saveable def insert(self, keys, values): - return gen_data_flow_ops._lookup_table_insert(self.table_ref, keys, values) + return gen_lookup_ops._lookup_table_insert(self.table_ref, keys, values) def lookup(self, keys, default): - return gen_data_flow_ops._lookup_table_find(self.table_ref, keys, default) + return gen_lookup_ops._lookup_table_find(self.table_ref, keys, default) def keys(self): return self._export()[0] @@ -64,8 +64,8 @@ class CheckpointedOp(object): return self._export()[1] def _export(self): - return gen_data_flow_ops._lookup_table_export(self.table_ref, dtypes.string, - dtypes.float32) + return gen_lookup_ops._lookup_table_export(self.table_ref, dtypes.string, + dtypes.float32) class CustomSaveable(saver_module.BaseSaverBuilder.SaveableObject): """A custom saveable for CheckpointedOp.""" @@ -81,6 +81,6 @@ class CheckpointedOp(object): super(CheckpointedOp.CustomSaveable, self).__init__(table, specs, name) def restore(self, restore_tensors, shapes): - return gen_data_flow_ops._lookup_table_import( + return gen_lookup_ops._lookup_table_import( self.op.table_ref, restore_tensors[0], restore_tensors[1]) # pylint: enable=protected-access diff --git a/tensorflow/python/training/supervisor.py b/tensorflow/python/training/supervisor.py index 277c11386dd..230ed1db687 100644 --- a/tensorflow/python/training/supervisor.py +++ b/tensorflow/python/training/supervisor.py @@ -27,7 +27,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary as _summary @@ -426,8 +426,10 @@ class Supervisor(object): local_init_op = self._get_first_op_from_collection( ops.GraphKeys.LOCAL_INIT_OP) if local_init_op is None: - op_list = [variables.local_variables_initializer(), - data_flow_ops.tables_initializer()] + op_list = [ + variables.local_variables_initializer(), + lookup_ops.tables_initializer() + ] if op_list: local_init_op = control_flow_ops.group(*op_list) ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op) From fd69bb292af7f15cd364e36ead7f596a3c484b2c Mon Sep 17 00:00:00 2001 From: Andrew Harp Date: Thu, 4 May 2017 12:48:11 -0800 Subject: [PATCH 25/43] Android demo: Add YUV -> RGB conversion Java implementation as fallback if native implementation is not found. This means that compiling libtensorflow_demo.so will only be strictly necessary for the Detection example (which uses native object tracking). A followup change will add graceful degradation in that case too. Java conversion may be slower depending on the device, but should still be acceptable for demo purposes as the majority of the compute time will still be spent on TF inference passes. Note that this has no effect on the necessity of libtensorflow_inference.so, which provides the actual TF support. However libtensorflow_inference.so may be added to applications via the prebuilt AAR, so no native compilation is necessary. Partially addresses #6385 Change: 155121431 --- .../tensorflow/demo/ClassifierActivity.java | 3 +- .../org/tensorflow/demo/DetectorActivity.java | 3 +- .../org/tensorflow/demo/StylizeActivity.java | 8 +- .../demo/TensorFlowMultiBoxDetector.java | 4 - .../demo/TensorFlowYoloDetector.java | 4 - .../org/tensorflow/demo/env/ImageUtils.java | 88 ++++++++++++++++++- 6 files changed, 89 insertions(+), 21 deletions(-) diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java index b26a2316782..bc391269255 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java @@ -194,13 +194,12 @@ public class ClassifierActivity extends CameraActivity implements OnImageAvailab yuvBytes[0], yuvBytes[1], yuvBytes[2], - rgbBytes, previewWidth, previewHeight, yRowStride, uvRowStride, uvPixelStride, - false); + rgbBytes); image.close(); } catch (final Exception e) { diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java index 206a99f3e3d..cdb6c3fed80 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java @@ -273,13 +273,12 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable yuvBytes[0], yuvBytes[1], yuvBytes[2], - rgbBytes, previewWidth, previewHeight, yRowStride, uvRowStride, uvPixelStride, - false); + rgbBytes); image.close(); } catch (final Exception e) { diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java index 7634be5c020..7afe2bf5412 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java @@ -65,10 +65,6 @@ import org.tensorflow.demo.R; * Artistic Style" (https://arxiv.org/abs/1610.07629) */ public class StylizeActivity extends CameraActivity implements OnImageAvailableListener { - static { - System.loadLibrary("tensorflow_demo"); - } - private static final Logger LOGGER = new Logger(); private static final String MODEL_FILE = "file:///android_asset/stylize_quantized.pb"; @@ -509,17 +505,17 @@ public class StylizeActivity extends CameraActivity implements OnImageAvailableL final int yRowStride = planes[0].getRowStride(); final int uvRowStride = planes[1].getRowStride(); final int uvPixelStride = planes[1].getPixelStride(); + ImageUtils.convertYUV420ToARGB8888( yuvBytes[0], yuvBytes[1], yuvBytes[2], - rgbBytes, previewWidth, previewHeight, yRowStride, uvRowStride, uvPixelStride, - false); + rgbBytes); image.close(); } catch (final Exception e) { diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java index f3e7114335f..1dcf9f55efe 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java @@ -41,10 +41,6 @@ import org.tensorflow.demo.env.Logger; public class TensorFlowMultiBoxDetector implements Classifier { private static final Logger LOGGER = new Logger(); - static { - System.loadLibrary("tensorflow_demo"); - } - // Only return this many results with at least this confidence. private static final int MAX_RESULTS = Integer.MAX_VALUE; diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java index 174723071da..b7e36a2379d 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java @@ -31,10 +31,6 @@ import org.tensorflow.demo.env.SplitTimer; public class TensorFlowYoloDetector implements Classifier { private static final Logger LOGGER = new Logger(); - static { - System.loadLibrary("tensorflow_demo"); - } - // Only return this many results with at least this confidence. private static final int MAX_RESULTS = 5; diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/env/ImageUtils.java b/tensorflow/examples/android/src/org/tensorflow/demo/env/ImageUtils.java index db929e5e087..5f2ff9164cc 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/env/ImageUtils.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/env/ImageUtils.java @@ -27,6 +27,14 @@ import java.io.FileOutputStream; public class ImageUtils { @SuppressWarnings("unused") private static final Logger LOGGER = new Logger(); + + static { + try { + System.loadLibrary("tensorflow_demo"); + } catch (UnsatisfiedLinkError e) { + LOGGER.w("Native library not found, native RGB -> YUV conversion may be unavailable."); + } + } /** * Utility method to compute the allocated size in bytes of a YUV420SP image @@ -83,10 +91,84 @@ public class ImageUtils { } } + // This value is 2 ^ 18 - 1, and is used to clamp the RGB values before their ranges + // are normalized to eight bits. + static final int kMaxChannelValue = 262143; + + // Always prefer the native implementation if available. + private static boolean useNativeConversion = true; + + public static void convertYUV420ToARGB8888( + byte[] yData, + byte[] uData, + byte[] vData, + int width, + int height, + int yRowStride, + int uvRowStride, + int uvPixelStride, + int[] out) { + if (useNativeConversion) { + try { + convertYUV420ToARGB8888( + yData, uData, vData, out, width, height, yRowStride, uvRowStride, uvPixelStride, false); + return; + } catch (UnsatisfiedLinkError e) { + LOGGER.w("Native YUV -> RGB implementation not found, falling back to Java implementation"); + useNativeConversion = false; + } + } + + int i = 0; + for (int y = 0; y < height; y++) { + int pY = yRowStride * y; + int uv_row_start = uvRowStride * (y >> 1); + int pUV = uv_row_start; + int pV = uv_row_start; + + for (int x = 0; x < width; x++) { + int uv_offset = pUV + (x >> 1) * uvPixelStride; + out[i++] = + YUV2RGB( + convertByteToInt(yData, pY + x), + convertByteToInt(uData, uv_offset), + convertByteToInt(vData, uv_offset)); + } + } + } + + private static int convertByteToInt(byte[] arr, int pos) { + return arr[pos] & 0xFF; + } + + private static int YUV2RGB(int nY, int nU, int nV) { + nY -= 16; + nU -= 128; + nV -= 128; + if (nY < 0) nY = 0; + + // This is the floating point equivalent. We do the conversion in integer + // because some Android devices do not have floating point in hardware. + // nR = (int)(1.164 * nY + 2.018 * nU); + // nG = (int)(1.164 * nY - 0.813 * nV - 0.391 * nU); + // nB = (int)(1.164 * nY + 1.596 * nV); + + final int foo = 1192 * nY; + int nR = foo + 1634 * nV; + int nG = foo - 833 * nV - 400 * nU; + int nB = foo + 2066 * nU; + + nR = Math.min(kMaxChannelValue, Math.max(0, nR)); + nG = Math.min(kMaxChannelValue, Math.max(0, nG)); + nB = Math.min(kMaxChannelValue, Math.max(0, nB)); + + return 0xff000000 | ((nR << 6) & 0x00ff0000) | ((nG >> 2) & 0x0000FF00) | ((nB >> 10) & 0xff); + } + /** - * Converts YUV420 semi-planar data to ARGB 8888 data using the supplied width - * and height. The input and output must already be allocated and non-null. - * For efficiency, no error checking is performed. + * Converts YUV420 semi-planar data to ARGB 8888 data using the supplied width and height. The + * input and output must already be allocated and non-null. For efficiency, no error checking is + * performed. * * @param input The array of YUV 4:2:0 input data. * @param output A pre-allocated array for the ARGB 8:8:8:8 output data. From 7e0b20510f25c6fb12ee8c055e32fb575f588abb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 4 May 2017 12:49:13 -0800 Subject: [PATCH 26/43] Add sparse_recall_at_top_k which takes top-k class indices instead of class logits. Change: 155121560 --- .../contrib/metrics/python/ops/metric_ops.py | 84 ++++++ .../metrics/python/ops/metric_ops_test.py | 275 ++++++++++++++++++ tensorflow/python/ops/metrics_impl.py | 69 ++++- 3 files changed, 427 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index d57203c042d..727cdd9597a 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -1338,6 +1338,87 @@ def streaming_sparse_precision_at_top_k(top_k_predictions, name=name_scope) +def sparse_recall_at_top_k(labels, + top_k_predictions, + class_id=None, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + """Computes recall@k of top-k predictions with respect to sparse labels. + + If `class_id` is specified, we calculate recall by considering only the + entries in the batch for which `class_id` is in the label, and computing + the fraction of them for which `class_id` is in the top-k `predictions`. + If `class_id` is not specified, we'll calculate recall as how often on + average a class among the labels of a batch entry is in the top-k + `predictions`. + + `sparse_recall_at_top_k` creates two local variables, `true_positive_at_` + and `false_negative_at_`, that are used to compute the recall_at_k + frequency. This frequency is ultimately returned as `recall_at_`: an + idempotent operation that simply divides `true_positive_at_` by total + (`true_positive_at_` + `false_negative_at_`). + + For estimation of the metric over a stream of data, the function creates an + `update_op` operation that updates these variables and returns the + `recall_at_`. Set operations applied to `top_k` and `labels` calculate the + true positives and false negatives weighted by `weights`. Then `update_op` + increments `true_positive_at_` and `false_negative_at_` using these + values. + + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + labels: `int64` `Tensor` or `SparseTensor` with shape + [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of + target classes for the associated prediction. Commonly, N=1 and `labels` + has shape [batch_size, num_labels]. [D1, ... DN] must match + `top_k_predictions`. Values should be in range [0, num_classes), where + num_classes is the last dimension of `predictions`. Values outside this + range always count towards `false_negative_at_`. + top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where + N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k]. + The final dimension contains the indices of top-k labels. [D1, ... DN] + must match `labels`. + class_id: Integer class ID for which we want binary metrics. This should be + in range [0, num_classes), where num_classes is the last dimension of + `predictions`. If class_id is outside this range, the method returns NAN. + weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of + `labels`. If the latter, it must be broadcastable to `labels` (i.e., all + dimensions must be either `1`, or the same as the corresponding `labels` + dimension). + metrics_collections: An optional list of collections that values should + be added to. + updates_collections: An optional list of collections that updates should + be added to. + name: Name of new update operation, and namespace for other dependent ops. + + Returns: + recall: Scalar `float64` `Tensor` with the value of `true_positives` divided + by the sum of `true_positives` and `false_negatives`. + update_op: `Operation` that increments `true_positives` and + `false_negatives` variables appropriately, and whose value matches + `recall`. + + Raises: + ValueError: If `weights` is not `None` and its shape doesn't match + `predictions`, or if either `metrics_collections` or `updates_collections` + are not a list or tuple. + """ + default_name = _at_k_name('recall', class_id=class_id) + with ops.name_scope(name, default_name, (top_k_predictions, labels, + weights)) as name_scope: + return metrics_impl._sparse_recall_at_top_k( # pylint: disable=protected-access + labels=labels, + predictions_idx=top_k_predictions, + class_id=class_id, + weights=weights, + metrics_collections=metrics_collections, + updates_collections=updates_collections, + name=name_scope) + + def streaming_sparse_average_precision_at_k(predictions, labels, k, @@ -2288,6 +2369,7 @@ def _remove_squeezable_dimensions(predictions, labels, weights): __all__ = [ 'aggregate_metric_map', 'aggregate_metrics', + 'sparse_recall_at_top_k', 'streaming_accuracy', 'streaming_auc', 'streaming_false_negatives', @@ -2310,7 +2392,9 @@ __all__ = [ 'streaming_root_mean_squared_error', 'streaming_sensitivity_at_specificity', 'streaming_sparse_average_precision_at_k', + 'streaming_sparse_average_precision_at_top_k', 'streaming_sparse_precision_at_k', + 'streaming_sparse_precision_at_top_k', 'streaming_sparse_recall_at_k', 'streaming_specificity_at_sensitivity', 'streaming_true_negatives', diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index b960e1310ec..f42e974e238 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -2958,8 +2958,38 @@ class StreamingSparseRecallTest(test.TestCase): self.assertEqual(expected, update.eval()) self.assertEqual(expected, metric.eval()) + def _test_sparse_recall_at_top_k(self, + labels, + top_k_predictions, + expected, + class_id=None, + weights=None): + with ops.Graph().as_default() as g, self.test_session(g): + if weights is not None: + weights = constant_op.constant(weights, dtypes_lib.float32) + metric, update = metric_ops.sparse_recall_at_top_k( + labels=labels, + top_k_predictions=constant_op.constant(top_k_predictions, + dtypes_lib.int32), + class_id=class_id, + weights=weights) + + # Fails without initialized vars. + self.assertRaises(errors_impl.OpError, metric.eval) + self.assertRaises(errors_impl.OpError, update.eval) + variables.variables_initializer(variables.local_variables()).run() + + # Run per-step op and assert expected values. + if math.isnan(expected): + self.assertTrue(math.isnan(update.eval())) + self.assertTrue(math.isnan(metric.eval())) + else: + self.assertEqual(expected, update.eval()) + self.assertEqual(expected, metric.eval()) + def test_one_label_at_k1_nan(self): predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] + top_k_predictions = [[3], [3]] sparse_labels = _binary_2d_label_to_sparse_value( [[0, 0, 0, 1], [0, 0, 1, 0]]) dense_labels = np.array([[3], [2]], dtype=np.int64) @@ -2970,9 +3000,12 @@ class StreamingSparseRecallTest(test.TestCase): for class_id in (-1, 0, 1, 4): self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=NAN, class_id=class_id) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=NAN, class_id=class_id) def test_one_label_at_k1_no_predictions(self): predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] + top_k_predictions = [[3], [3]] sparse_labels = _binary_2d_label_to_sparse_value( [[0, 0, 0, 1], [0, 0, 1, 0]]) dense_labels = np.array([[3], [2]], dtype=np.int64) @@ -2981,9 +3014,12 @@ class StreamingSparseRecallTest(test.TestCase): # Class 2: 0 predictions. self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=0.0, class_id=2) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=0.0, class_id=2) def test_one_label_at_k1(self): predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] + top_k_predictions = [[3], [3]] sparse_labels = _binary_2d_label_to_sparse_value( [[0, 0, 0, 1], [0, 0, 1, 0]]) dense_labels = np.array([[3], [2]], dtype=np.int64) @@ -2992,13 +3028,18 @@ class StreamingSparseRecallTest(test.TestCase): # Class 3: 1 label, 2 predictions, 1 correct. self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=1.0 / 1, class_id=3) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 1, class_id=3) # All classes: 2 labels, 2 predictions, 1 correct. self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=1.0 / 2) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 2) def test_one_label_at_k1_weighted(self): predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] + top_k_predictions = [[3], [3]] sparse_labels = _binary_2d_label_to_sparse_value( [[0, 0, 0, 1], [0, 0, 1, 0]]) dense_labels = np.array([[3], [2]], dtype=np.int64) @@ -3007,6 +3048,8 @@ class StreamingSparseRecallTest(test.TestCase): # Class 3: 1 label, 2 predictions, 1 correct. self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=NAN, class_id=3, weights=(0.0,)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=NAN, class_id=3, weights=(0.0,)) self._test_streaming_sparse_recall_at_k( predictions, labels, @@ -3014,6 +3057,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=1.0 / 1, class_id=3, weights=(1.0,)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 1, + class_id=3, + weights=(1.0,)) self._test_streaming_sparse_recall_at_k( predictions, labels, @@ -3021,6 +3070,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=1.0 / 1, class_id=3, weights=(2.0,)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 1, + class_id=3, + weights=(2.0,)) self._test_streaming_sparse_recall_at_k( predictions, labels, @@ -3028,6 +3083,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=NAN, class_id=3, weights=(0.0, 0.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=NAN, + class_id=3, + weights=(0.0, 0.0)) self._test_streaming_sparse_recall_at_k( predictions, labels, @@ -3035,6 +3096,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=NAN, class_id=3, weights=(0.0, 1.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=NAN, + class_id=3, + weights=(0.0, 1.0)) self._test_streaming_sparse_recall_at_k( predictions, labels, @@ -3042,6 +3109,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=1.0 / 1, class_id=3, weights=(1.0, 0.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 1, + class_id=3, + weights=(1.0, 0.0)) self._test_streaming_sparse_recall_at_k( predictions, labels, @@ -3049,6 +3122,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=1.0 / 1, class_id=3, weights=(1.0, 1.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 1, + class_id=3, + weights=(1.0, 1.0)) self._test_streaming_sparse_recall_at_k( predictions, labels, @@ -3056,6 +3135,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=2.0 / 2, class_id=3, weights=(2.0, 3.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=2.0 / 2, + class_id=3, + weights=(2.0, 3.0)) self._test_streaming_sparse_recall_at_k( predictions, labels, @@ -3063,6 +3148,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=3.0 / 3, class_id=3, weights=(3.0, 2.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=3.0 / 3, + class_id=3, + weights=(3.0, 2.0)) self._test_streaming_sparse_recall_at_k( predictions, labels, @@ -3070,6 +3161,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=0.3 / 0.3, class_id=3, weights=(0.3, 0.6)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=0.3 / 0.3, + class_id=3, + weights=(0.3, 0.6)) self._test_streaming_sparse_recall_at_k( predictions, labels, @@ -3077,32 +3174,70 @@ class StreamingSparseRecallTest(test.TestCase): expected=0.6 / 0.6, class_id=3, weights=(0.6, 0.3)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=0.6 / 0.6, + class_id=3, + weights=(0.6, 0.3)) # All classes: 2 labels, 2 predictions, 1 correct. self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=NAN, weights=(0.0,)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=NAN, weights=(0.0,)) self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=1.0 / 2, weights=(1.0,)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 2, weights=(1.0,)) + self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=1.0 / 2, weights=(2.0,)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 2, weights=(2.0,)) + self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 1, weights=(1.0, 0.0)) + self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=0.0 / 1, weights=(0.0, 1.0)) + self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=1.0 / 2, weights=(1.0, 1.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 2, weights=(1.0, 1.0)) + self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=2.0 / 5, weights=(2.0, 3.0)) + self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=3.0 / 5, weights=(3.0, 2.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=3.0 / 5, weights=(3.0, 2.0)) + self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=0.3 / 0.9, weights=(0.3, 0.6)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=0.3 / 0.9, weights=(0.3, 0.6)) + self._test_streaming_sparse_recall_at_k( predictions, labels, k=1, expected=0.6 / 0.9, weights=(0.6, 0.3)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=0.6 / 0.9, weights=(0.6, 0.3)) def test_three_labels_at_k5_nan(self): predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]] + top_k_predictions = [ + [9, 4, 6, 2, 0], + [5, 7, 2, 9, 6], + ] sparse_labels = _binary_2d_label_to_sparse_value( [[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]]) dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64) @@ -3112,10 +3247,16 @@ class StreamingSparseRecallTest(test.TestCase): for class_id in (0, 3, 4, 6, 9, 10): self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=NAN, class_id=class_id) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=NAN, class_id=class_id) def test_three_labels_at_k5_no_predictions(self): predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]] + top_k_predictions = [ + [9, 4, 6, 2, 0], + [5, 7, 2, 9, 6], + ] sparse_labels = _binary_2d_label_to_sparse_value( [[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]]) dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64) @@ -3124,10 +3265,16 @@ class StreamingSparseRecallTest(test.TestCase): # Class 8: 1 label, no predictions. self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=0.0 / 1, class_id=8) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=0.0 / 1, class_id=8) def test_three_labels_at_k5(self): predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]] + top_k_predictions = [ + [9, 4, 6, 2, 0], + [5, 7, 2, 9, 6], + ] sparse_labels = _binary_2d_label_to_sparse_value( [[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]]) dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64) @@ -3136,23 +3283,35 @@ class StreamingSparseRecallTest(test.TestCase): # Class 2: 2 labels, both correct. self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=2.0 / 2, class_id=2) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=2.0 / 2, class_id=2) # Class 5: 1 label, incorrect. self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=1.0 / 1, class_id=5) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 1, class_id=5) # Class 7: 1 label, incorrect. self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=0.0 / 1, class_id=7) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=0.0 / 1, class_id=7) # All classes: 6 labels, 3 correct. self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=3.0 / 6) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=3.0 / 6) def test_three_labels_at_k5_some_out_of_range(self): """Tests that labels outside the [0, n_classes) count in denominator.""" predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]] + top_k_predictions = [ + [9, 4, 6, 2, 0], + [5, 7, 2, 9, 6], + ] sp_labels = sparse_tensor.SparseTensorValue( indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2], [1, 3]], @@ -3167,6 +3326,11 @@ class StreamingSparseRecallTest(test.TestCase): k=5, expected=2.0 / 2, class_id=2) + self._test_sparse_recall_at_top_k( + sp_labels, + top_k_predictions, + expected=2.0 / 2, + class_id=2) # Class 5: 1 label, incorrect. self._test_streaming_sparse_recall_at_k( @@ -3175,6 +3339,11 @@ class StreamingSparseRecallTest(test.TestCase): k=5, expected=1.0 / 1, class_id=5) + self._test_sparse_recall_at_top_k( + sp_labels, + top_k_predictions, + expected=1.0 / 1, + class_id=5) # Class 7: 1 label, incorrect. self._test_streaming_sparse_recall_at_k( @@ -3183,16 +3352,30 @@ class StreamingSparseRecallTest(test.TestCase): k=5, expected=0.0 / 1, class_id=7) + self._test_sparse_recall_at_top_k( + sp_labels, + top_k_predictions, + expected=0.0 / 1, + class_id=7) # All classes: 8 labels, 3 correct. self._test_streaming_sparse_recall_at_k( predictions=predictions, labels=sp_labels, k=5, expected=3.0 / 8) + self._test_sparse_recall_at_top_k( + sp_labels, top_k_predictions, expected=3.0 / 8) def test_3d_nan(self): predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]], [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6], [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]] + top_k_predictions = [[ + [9, 4, 6, 2, 0], + [5, 7, 2, 9, 6], + ], [ + [5, 7, 2, 9, 6], + [9, 4, 6, 2, 0], + ]] sparse_labels = _binary_3d_label_to_sparse_value( [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]], [[0, 1, 1, 0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 1, 1, 0]]]) @@ -3207,12 +3390,21 @@ class StreamingSparseRecallTest(test.TestCase): for class_id in (0, 3, 4, 6, 9, 10): self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=NAN, class_id=class_id) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=NAN, class_id=class_id) def test_3d_no_predictions(self): predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]], [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6], [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]] + top_k_predictions = [[ + [9, 4, 6, 2, 0], + [5, 7, 2, 9, 6], + ], [ + [5, 7, 2, 9, 6], + [9, 4, 6, 2, 0], + ]] sparse_labels = _binary_3d_label_to_sparse_value( [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]], @@ -3229,12 +3421,21 @@ class StreamingSparseRecallTest(test.TestCase): for class_id in (1, 8): self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=0.0, class_id=class_id) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=0.0, class_id=class_id) def test_3d(self): predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]], [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6], [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]] + top_k_predictions = [[ + [9, 4, 6, 2, 0], + [5, 7, 2, 9, 6], + ], [ + [5, 7, 2, 9, 6], + [9, 4, 6, 2, 0], + ]] labels = _binary_3d_label_to_sparse_value( [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]], @@ -3244,24 +3445,39 @@ class StreamingSparseRecallTest(test.TestCase): # Class 2: 4 labels, all correct. self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=4.0 / 4, class_id=2) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=4.0 / 4, class_id=2) # Class 5: 2 labels, both correct. self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=2.0 / 2, class_id=5) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=2.0 / 2, class_id=5) # Class 7: 2 labels, 1 incorrect. self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=1.0 / 2, class_id=7) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 2, class_id=7) # All classes: 12 labels, 7 correct. self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=7.0 / 12) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=7.0 / 12) def test_3d_ignore_all(self): predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]], [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6], [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]] + top_k_predictions = [[ + [9, 4, 6, 2, 0], + [5, 7, 2, 9, 6], + ], [ + [5, 7, 2, 9, 6], + [9, 4, 6, 2, 0], + ]] labels = _binary_3d_label_to_sparse_value( [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]], @@ -3276,6 +3492,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=NAN, class_id=class_id, weights=[[0], [0]]) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=NAN, + class_id=class_id, + weights=[[0], [0]]) self._test_streaming_sparse_recall_at_k( predictions, labels, @@ -3283,16 +3505,33 @@ class StreamingSparseRecallTest(test.TestCase): expected=NAN, class_id=class_id, weights=[[0, 0], [0, 0]]) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=NAN, + class_id=class_id, + weights=[[0, 0], [0, 0]]) self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=NAN, weights=[[0], [0]]) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=NAN, weights=[[0], [0]]) self._test_streaming_sparse_recall_at_k( predictions, labels, k=5, expected=NAN, weights=[[0, 0], [0, 0]]) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=NAN, weights=[[0, 0], [0, 0]]) def test_3d_ignore_some(self): predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]], [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6], [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]] + top_k_predictions = [[ + [9, 4, 6, 2, 0], + [5, 7, 2, 9, 6], + ], [ + [5, 7, 2, 9, 6], + [9, 4, 6, 2, 0], + ]] labels = _binary_3d_label_to_sparse_value( [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]], @@ -3307,6 +3546,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=2.0 / 2.0, class_id=2, weights=[[1], [0]]) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=2.0 / 2.0, + class_id=2, + weights=[[1], [0]]) # Class 2: 2 labels, both correct. self._test_streaming_sparse_recall_at_k( @@ -3316,6 +3561,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=2.0 / 2.0, class_id=2, weights=[[0], [1]]) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=2.0 / 2.0, + class_id=2, + weights=[[0], [1]]) # Class 7: 1 label, correct. self._test_streaming_sparse_recall_at_k( @@ -3325,6 +3576,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=1.0 / 1.0, class_id=7, weights=[[0], [1]]) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 1.0, + class_id=7, + weights=[[0], [1]]) # Class 7: 1 label, incorrect. self._test_streaming_sparse_recall_at_k( @@ -3334,6 +3591,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=0.0 / 1.0, class_id=7, weights=[[1], [0]]) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=0.0 / 1.0, + class_id=7, + weights=[[1], [0]]) # Class 7: 2 labels, 1 correct. self._test_streaming_sparse_recall_at_k( @@ -3343,6 +3606,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=1.0 / 2.0, class_id=7, weights=[[1, 0], [1, 0]]) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 2.0, + class_id=7, + weights=[[1, 0], [1, 0]]) # Class 7: No labels. self._test_streaming_sparse_recall_at_k( @@ -3352,6 +3621,12 @@ class StreamingSparseRecallTest(test.TestCase): expected=NAN, class_id=7, weights=[[0, 1], [0, 1]]) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=NAN, + class_id=7, + weights=[[0, 1], [0, 1]]) def test_sparse_tensor_value(self): predictions = [[0.1, 0.3, 0.2, 0.4], diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py index 4dc8e702ca3..28ed3af9d73 100644 --- a/tensorflow/python/ops/metrics_impl.py +++ b/tensorflow/python/ops/metrics_impl.py @@ -1924,7 +1924,74 @@ def recall_at_k(labels, labels = _maybe_expand_labels(labels, predictions) _, top_k_idx = nn.top_k(predictions, k) - top_k_idx = math_ops.to_int64(top_k_idx) + return _sparse_recall_at_top_k( + labels=labels, + predictions_idx=top_k_idx, + k=k, + class_id=class_id, + weights=weights, + metrics_collections=metrics_collections, + updates_collections=updates_collections, + name=scope) + + +def _sparse_recall_at_top_k(labels, + predictions_idx, + k=None, + class_id=None, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + """Computes recall@k of top-k predictions with respect to sparse labels. + + Differs from `recall_at_k` in that predictions must be in the form of top `k` + class indices, whereas `recall_at_k` expects logits. Refer to `recall_at_k` + for more details. + + Args: + labels: `int64` `Tensor` or `SparseTensor` with shape + [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies + num_labels=1. N >= 1 and num_labels is the number of target classes for + the associated prediction. Commonly, N=1 and `labels` has shape + [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values + should be in range [0, num_classes), where num_classes is the last + dimension of `predictions`. Values outside this range always count + towards `false_negative_at_`. + predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1. + Commonly, N=1 and predictions has shape [batch size, k]. The final + dimension contains the top `k` predicted class indices. [D1, ... DN] must + match `labels`. + k: Integer, k for @k metric. + class_id: Integer class ID for which we want binary metrics. This should be + in range [0, num_classes), where num_classes is the last dimension of + `predictions`. If class_id is outside this range, the method returns NAN. + weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of + `labels`. If the latter, it must be broadcastable to `labels` (i.e., all + dimensions must be either `1`, or the same as the corresponding `labels` + dimension). + metrics_collections: An optional list of collections that values should + be added to. + updates_collections: An optional list of collections that updates should + be added to. + name: Name of new update operation, and namespace for other dependent ops. + + Returns: + recall: Scalar `float64` `Tensor` with the value of `true_positives` divided + by the sum of `true_positives` and `false_negatives`. + update_op: `Operation` that increments `true_positives` and + `false_negatives` variables appropriately, and whose value matches + `recall`. + + Raises: + ValueError: If `weights` is not `None` and its shape doesn't match + `predictions`, or if either `metrics_collections` or `updates_collections` + are not a list or tuple. + """ + with ops.name_scope(name, + _at_k_name('recall', k, class_id=class_id), + (predictions_idx, labels, weights)) as scope: + top_k_idx = math_ops.to_int64(predictions_idx) tp, tp_update = _streaming_sparse_true_positive_at_k( predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id, weights=weights) From a90d9a425aafbb4152cad8a1b4d5a4ca9090b46e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 4 May 2017 12:50:49 -0800 Subject: [PATCH 27/43] [XLA] Adapt to interface changes in llvm r302060. Change: 155121754 --- tensorflow/compiler/xla/service/cpu/ir_emitter.cc | 3 ++- tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 1c704fd1ee7..1e34de9e4bd 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -201,7 +201,8 @@ void IrEmitter::InitializeIrFunction(const string& function_name, if (&argument == retval) { continue; } - compute_function_->setDoesNotAlias(argument.getArgNo() + 1); + compute_function_->addAttribute(argument.getArgNo() + 1, + llvm::Attribute::NoAlias); } ir_builder_.SetInsertPoint(llvm::BasicBlock::Create( diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 04babcca0c8..e52e55a1a81 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -196,7 +196,7 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( ir_emitter_context_->buffer_assignment().GetTempAllocation()) { kernel->addDereferenceableAttr(temp_buffer_arg_no + 1, allocation->size()); } - kernel->setDoesNotAlias(temp_buffer_arg_no + 1); + kernel->addAttribute(temp_buffer_arg_no + 1, llvm::Attribute::NoAlias); // Add the declaration of this kernel to llvm.nvvm.annotations so that NVPTX // treats it as a CUDA kernel. From c36a71d962cb49ce25d8d2173587738692742bb6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 4 May 2017 12:55:13 -0800 Subject: [PATCH 28/43] Go: Update generated wrapper functions for TensorFlow ops. Change: 155122342 --- tensorflow/go/op/wrappers.go | 3558 +++++++++++++++++----------------- 1 file changed, 1779 insertions(+), 1779 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index c63be8bc5ee..eb4789a1829 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -3522,256 +3522,6 @@ func Stage(scope *Scope, values []tf.Output, optional ...StageAttr) (o *tf.Opera return scope.AddOperation(opspec) } -// Table initializer that takes two tensors for keys and values respectively. -// -// Arguments: -// table_handle: Handle to a table which will be initialized. -// keys: Keys of type Tkey. -// values: Values of type Tval. -// -// Returns the created operation. -func InitializeTableV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "InitializeTableV2", - Input: []tf.Input{ - table_handle, keys, values, - }, - } - return scope.AddOperation(opspec) -} - -// MutableHashTableV2Attr is an optional argument to MutableHashTableV2. -type MutableHashTableV2Attr func(optionalAttr) - -// MutableHashTableV2Container sets the optional container attribute to value. -// -// value: If non-empty, this table is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func MutableHashTableV2Container(value string) MutableHashTableV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// MutableHashTableV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this table is shared under the given name across -// multiple sessions. -// If not specified, defaults to "" -func MutableHashTableV2SharedName(value string) MutableHashTableV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// MutableHashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. -// -// value: If true and shared_name is empty, the table is shared -// using the node name. -// If not specified, defaults to false -func MutableHashTableV2UseNodeNameSharing(value bool) MutableHashTableV2Attr { - return func(m optionalAttr) { - m["use_node_name_sharing"] = value - } -} - -// Creates an empty hash table. -// -// This op creates a mutable hash table, specifying the type of its keys and -// values. Each value must be a scalar. Data can be inserted into the table using -// the insert operations. It does not support the initialization operation. -// -// Arguments: -// key_dtype: Type of the table keys. -// value_dtype: Type of the table values. -// -// Returns Handle to a table. -func MutableHashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableV2Attr) (table_handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MutableHashTableV2", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// HashTableV2Attr is an optional argument to HashTableV2. -type HashTableV2Attr func(optionalAttr) - -// HashTableV2Container sets the optional container attribute to value. -// -// value: If non-empty, this table is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func HashTableV2Container(value string) HashTableV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// HashTableV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this table is shared under the given name across -// multiple sessions. -// If not specified, defaults to "" -func HashTableV2SharedName(value string) HashTableV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// HashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. -// -// value: If true and shared_name is empty, the table is shared -// using the node name. -// If not specified, defaults to false -func HashTableV2UseNodeNameSharing(value bool) HashTableV2Attr { - return func(m optionalAttr) { - m["use_node_name_sharing"] = value - } -} - -// Creates a non-initialized hash table. -// -// This op creates a hash table, specifying the type of its keys and values. -// Before using the table you will have to initialize it. After initialization the -// table will be immutable. -// -// Arguments: -// key_dtype: Type of the table keys. -// value_dtype: Type of the table values. -// -// Returns Handle to a table. -func HashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...HashTableV2Attr) (table_handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "HashTableV2", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Replaces the contents of the table with the specified keys and values. -// -// The tensor `keys` must be of the same type as the keys of the table. -// The tensor `values` must be of the type of the table values. -// -// Arguments: -// table_handle: Handle to the table. -// keys: Any shape. Keys to look up. -// values: Values to associate with keys. -// -// Returns the created operation. -func LookupTableImportV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LookupTableImportV2", - Input: []tf.Input{ - table_handle, keys, values, - }, - } - return scope.AddOperation(opspec) -} - -// Outputs all keys and values in the table. -// -// Arguments: -// table_handle: Handle to the table. -// -// -// -// Returns Vector of all keys present in the table.Tensor of all values in the table. Indexed in parallel with `keys`. -func LookupTableExportV2(scope *Scope, table_handle tf.Output, Tkeys tf.DataType, Tvalues tf.DataType) (keys tf.Output, values tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"Tkeys": Tkeys, "Tvalues": Tvalues} - opspec := tf.OpSpec{ - Type: "LookupTableExportV2", - Input: []tf.Input{ - table_handle, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Updates the table to associates keys with values. -// -// The tensor `keys` must be of the same type as the keys of the table. -// The tensor `values` must be of the type of the table values. -// -// Arguments: -// table_handle: Handle to the table. -// keys: Any shape. Keys to look up. -// values: Values to associate with keys. -// -// Returns the created operation. -func LookupTableInsertV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LookupTableInsertV2", - Input: []tf.Input{ - table_handle, keys, values, - }, - } - return scope.AddOperation(opspec) -} - -// Looks up keys in a table, outputs the corresponding values. -// -// The tensor `keys` must of the same type as the keys of the table. -// The output `values` is of the type of the table values. -// -// The scalar `default_value` is the value output for keys not present in the -// table. It must also be of the same type as the table values. -// -// Arguments: -// table_handle: Handle to the table. -// keys: Any shape. Keys to look up. -// -// -// Returns Same shape as `keys`. Values found in the table, or `default_values` -// for missing keys. -func LookupTableFindV2(scope *Scope, table_handle tf.Output, keys tf.Output, default_value tf.Output) (values tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LookupTableFindV2", - Input: []tf.Input{ - table_handle, keys, default_value, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // FakeQuantWithMinMaxArgsAttr is an optional argument to FakeQuantWithMinMaxArgs. type FakeQuantWithMinMaxArgsAttr func(optionalAttr) @@ -5404,6 +5154,435 @@ func ExtractGlimpse(scope *Scope, input tf.Output, size tf.Output, offsets tf.Ou return op.Output(0) } +// Draw bounding boxes on a batch of images. +// +// Outputs a copy of `images` but draws on top of the pixels zero or more bounding +// boxes specified by the locations in `boxes`. The coordinates of the each +// bounding box in `boxes` are encoded as `[y_min, x_min, y_max, x_max]`. The +// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and +// height of the underlying image. +// +// For example, if an image is 100 x 200 pixels and the bounding box is +// `[0.1, 0.2, 0.5, 0.9]`, the bottom-left and upper-right coordinates of the +// bounding box will be `(10, 40)` to `(50, 180)`. +// +// Parts of the bounding box may fall outside the image. +// +// Arguments: +// images: 4-D with shape `[batch, height, width, depth]`. A batch of images. +// boxes: 3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding +// boxes. +// +// Returns 4-D with the same shape as `images`. The batch of input images with +// bounding boxes drawn on the images. +func DrawBoundingBoxes(scope *Scope, images tf.Output, boxes tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "DrawBoundingBoxes", + Input: []tf.Input{ + images, boxes, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Convert one or more images from HSV to RGB. +// +// Outputs a tensor of the same shape as the `images` tensor, containing the RGB +// value of the pixels. The output is only well defined if the value in `images` +// are in `[0,1]`. +// +// See `rgb_to_hsv` for a description of the HSV encoding. +// +// Arguments: +// images: 1-D or higher rank. HSV data to convert. Last dimension must be size 3. +// +// Returns `images` converted to RGB. +func HSVToRGB(scope *Scope, images tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "HSVToRGB", + Input: []tf.Input{ + images, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Decode the first frame of a GIF-encoded image to a uint8 tensor. +// +// GIF with frame or transparency compression are not supported +// convert animated GIF from compressed to uncompressed by: +// +// convert $src.gif -coalesce $dst.gif +// +// Arguments: +// contents: 0-D. The GIF-encoded image. +// +// Returns 4-D with shape `[num_frames, height, width, 3]`. RGB order +func DecodeGif(scope *Scope, contents tf.Output) (image tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "DecodeGif", + Input: []tf.Input{ + contents, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DecodePngAttr is an optional argument to DecodePng. +type DecodePngAttr func(optionalAttr) + +// DecodePngChannels sets the optional channels attribute to value. +// +// value: Number of color channels for the decoded image. +// If not specified, defaults to 0 +func DecodePngChannels(value int64) DecodePngAttr { + return func(m optionalAttr) { + m["channels"] = value + } +} + +// DecodePngDtype sets the optional dtype attribute to value. +// If not specified, defaults to DT_UINT8 +func DecodePngDtype(value tf.DataType) DecodePngAttr { + return func(m optionalAttr) { + m["dtype"] = value + } +} + +// Decode a PNG-encoded image to a uint8 or uint16 tensor. +// +// The attr `channels` indicates the desired number of color channels for the +// decoded image. +// +// Accepted values are: +// +// * 0: Use the number of channels in the PNG-encoded image. +// * 1: output a grayscale image. +// * 3: output an RGB image. +// * 4: output an RGBA image. +// +// If needed, the PNG-encoded image is transformed to match the requested number +// of color channels. +// +// Arguments: +// contents: 0-D. The PNG-encoded image. +// +// Returns 3-D with shape `[height, width, channels]`. +func DecodePng(scope *Scope, contents tf.Output, optional ...DecodePngAttr) (image tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DecodePng", + Input: []tf.Input{ + contents, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Adjust the contrast of one or more images. +// +// `images` is a tensor of at least 3 dimensions. The last 3 dimensions are +// interpreted as `[height, width, channels]`. The other dimensions only +// represent a collection of images, such as `[batch, height, width, channels].` +// +// Contrast is adjusted independently for each channel of each image. +// +// For each channel, the Op first computes the mean of the image pixels in the +// channel and then adjusts each component of each pixel to +// `(x - mean) * contrast_factor + mean`. +// +// Arguments: +// images: Images to adjust. At least 3-D. +// contrast_factor: A float multiplier for adjusting contrast. +// +// Returns The contrast-adjusted image or images. +func AdjustContrastv2(scope *Scope, images tf.Output, contrast_factor tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "AdjustContrastv2", + Input: []tf.Input{ + images, contrast_factor, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DecodeJpegAttr is an optional argument to DecodeJpeg. +type DecodeJpegAttr func(optionalAttr) + +// DecodeJpegChannels sets the optional channels attribute to value. +// +// value: Number of color channels for the decoded image. +// If not specified, defaults to 0 +func DecodeJpegChannels(value int64) DecodeJpegAttr { + return func(m optionalAttr) { + m["channels"] = value + } +} + +// DecodeJpegRatio sets the optional ratio attribute to value. +// +// value: Downscaling ratio. +// If not specified, defaults to 1 +func DecodeJpegRatio(value int64) DecodeJpegAttr { + return func(m optionalAttr) { + m["ratio"] = value + } +} + +// DecodeJpegFancyUpscaling sets the optional fancy_upscaling attribute to value. +// +// value: If true use a slower but nicer upscaling of the +// chroma planes (yuv420/422 only). +// If not specified, defaults to true +func DecodeJpegFancyUpscaling(value bool) DecodeJpegAttr { + return func(m optionalAttr) { + m["fancy_upscaling"] = value + } +} + +// DecodeJpegTryRecoverTruncated sets the optional try_recover_truncated attribute to value. +// +// value: If true try to recover an image from truncated input. +// If not specified, defaults to false +func DecodeJpegTryRecoverTruncated(value bool) DecodeJpegAttr { + return func(m optionalAttr) { + m["try_recover_truncated"] = value + } +} + +// DecodeJpegAcceptableFraction sets the optional acceptable_fraction attribute to value. +// +// value: The minimum required fraction of lines before a truncated +// input is accepted. +// If not specified, defaults to 1 +func DecodeJpegAcceptableFraction(value float32) DecodeJpegAttr { + return func(m optionalAttr) { + m["acceptable_fraction"] = value + } +} + +// DecodeJpegDctMethod sets the optional dct_method attribute to value. +// +// value: string specifying a hint about the algorithm used for +// decompression. Defaults to "" which maps to a system-specific +// default. Currently valid values are ["INTEGER_FAST", +// "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal +// jpeg library changes to a version that does not have that specific +// option.) +// If not specified, defaults to "" +func DecodeJpegDctMethod(value string) DecodeJpegAttr { + return func(m optionalAttr) { + m["dct_method"] = value + } +} + +// Decode a JPEG-encoded image to a uint8 tensor. +// +// The attr `channels` indicates the desired number of color channels for the +// decoded image. +// +// Accepted values are: +// +// * 0: Use the number of channels in the JPEG-encoded image. +// * 1: output a grayscale image. +// * 3: output an RGB image. +// +// If needed, the JPEG-encoded image is transformed to match the requested number +// of color channels. +// +// The attr `ratio` allows downscaling the image by an integer factor during +// decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than +// downscaling the image later. +// +// Arguments: +// contents: 0-D. The JPEG-encoded image. +// +// Returns 3-D with shape `[height, width, channels]`.. +func DecodeJpeg(scope *Scope, contents tf.Output, optional ...DecodeJpegAttr) (image tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DecodeJpeg", + Input: []tf.Input{ + contents, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResizeNearestNeighborGradAttr is an optional argument to ResizeNearestNeighborGrad. +type ResizeNearestNeighborGradAttr func(optionalAttr) + +// ResizeNearestNeighborGradAlignCorners sets the optional align_corners attribute to value. +// +// value: If true, rescale grads by (orig_height - 1) / (height - 1), which +// exactly aligns the 4 corners of grads and original_image. If false, rescale by +// orig_height / height. Treat similarly the width dimension. +// If not specified, defaults to false +func ResizeNearestNeighborGradAlignCorners(value bool) ResizeNearestNeighborGradAttr { + return func(m optionalAttr) { + m["align_corners"] = value + } +} + +// Computes the gradient of nearest neighbor interpolation. +// +// Arguments: +// grads: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `orig_height, orig_width`. The +// original input size. +// +// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. Gradients +// with respect to the input image. +func ResizeNearestNeighborGrad(scope *Scope, grads tf.Output, size tf.Output, optional ...ResizeNearestNeighborGradAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResizeNearestNeighborGrad", + Input: []tf.Input{ + grads, size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResizeNearestNeighborAttr is an optional argument to ResizeNearestNeighbor. +type ResizeNearestNeighborAttr func(optionalAttr) + +// ResizeNearestNeighborAlignCorners sets the optional align_corners attribute to value. +// +// value: If true, rescale input by (new_height - 1) / (height - 1), which +// exactly aligns the 4 corners of images and resized images. If false, rescale +// by new_height / height. Treat similarly the width dimension. +// If not specified, defaults to false +func ResizeNearestNeighborAlignCorners(value bool) ResizeNearestNeighborAttr { + return func(m optionalAttr) { + m["align_corners"] = value + } +} + +// Resize `images` to `size` using nearest neighbor interpolation. +// +// Arguments: +// images: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The +// new size for the images. +// +// Returns 4-D with shape +// `[batch, new_height, new_width, channels]`. +func ResizeNearestNeighbor(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeNearestNeighborAttr) (resized_images tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResizeNearestNeighbor", + Input: []tf.Input{ + images, size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the set of files matching one or more glob patterns. +// +// Note that this routine only supports wildcard characters in the +// basename portion of the pattern, not in the directory portion. +// +// Arguments: +// pattern: Shell wildcard pattern(s). Scalar or vector of type string. +// +// Returns A vector of matching filenames. +func MatchingFiles(scope *Scope, pattern tf.Output) (filenames tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "MatchingFiles", + Input: []tf.Input{ + pattern, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Shuffle dimensions of x according to a permutation. +// +// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy: +// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]` +func Transpose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Transpose", + Input: []tf.Input{ + x, perm, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Reads and outputs the entire contents of the input filename. +func ReadFile(scope *Scope, filename tf.Output) (contents tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ReadFile", + Input: []tf.Input{ + filename, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Computes softmax cross entropy cost and gradients to backpropagate. // // Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept @@ -6560,6 +6739,95 @@ func Softsign(scope *Scope, features tf.Output) (activations tf.Output) { return op.Output(0) } +// ResizeBilinearAttr is an optional argument to ResizeBilinear. +type ResizeBilinearAttr func(optionalAttr) + +// ResizeBilinearAlignCorners sets the optional align_corners attribute to value. +// +// value: If true, rescale input by (new_height - 1) / (height - 1), which +// exactly aligns the 4 corners of images and resized images. If false, rescale +// by new_height / height. Treat similarly the width dimension. +// If not specified, defaults to false +func ResizeBilinearAlignCorners(value bool) ResizeBilinearAttr { + return func(m optionalAttr) { + m["align_corners"] = value + } +} + +// Resize `images` to `size` using bilinear interpolation. +// +// Input images can be of different types but output images are always float. +// +// Arguments: +// images: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The +// new size for the images. +// +// Returns 4-D with shape +// `[batch, new_height, new_width, channels]`. +func ResizeBilinear(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeBilinearAttr) (resized_images tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResizeBilinear", + Input: []tf.Input{ + images, size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ProdAttr is an optional argument to Prod. +type ProdAttr func(optionalAttr) + +// ProdKeepDims sets the optional keep_dims attribute to value. +// +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func ProdKeepDims(value bool) ProdAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// Computes the product of elements across dimensions of a tensor. +// +// Reduces `input` along the dimensions given in `reduction_indices`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `reduction_indices`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. +// +// Arguments: +// input: The tensor to reduce. +// reduction_indices: The dimensions to reduce. +// +// Returns The reduced tensor. +func Prod(scope *Scope, input tf.Output, reduction_indices tf.Output, optional ...ProdAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Prod", + Input: []tf.Input{ + input, reduction_indices, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // DepthwiseConv2dNativeAttr is an optional argument to DepthwiseConv2dNative. type DepthwiseConv2dNativeAttr func(optionalAttr) @@ -6770,6 +7038,181 @@ func BiasAddV1(scope *Scope, value tf.Output, bias tf.Output) (output tf.Output) return op.Output(0) } +// EncodeJpegAttr is an optional argument to EncodeJpeg. +type EncodeJpegAttr func(optionalAttr) + +// EncodeJpegFormat sets the optional format attribute to value. +// +// value: Per pixel image format. +// If not specified, defaults to "" +func EncodeJpegFormat(value string) EncodeJpegAttr { + return func(m optionalAttr) { + m["format"] = value + } +} + +// EncodeJpegQuality sets the optional quality attribute to value. +// +// value: Quality of the compression from 0 to 100 (higher is better and slower). +// If not specified, defaults to 95 +func EncodeJpegQuality(value int64) EncodeJpegAttr { + return func(m optionalAttr) { + m["quality"] = value + } +} + +// EncodeJpegProgressive sets the optional progressive attribute to value. +// +// value: If True, create a JPEG that loads progressively (coarse to fine). +// If not specified, defaults to false +func EncodeJpegProgressive(value bool) EncodeJpegAttr { + return func(m optionalAttr) { + m["progressive"] = value + } +} + +// EncodeJpegOptimizeSize sets the optional optimize_size attribute to value. +// +// value: If True, spend CPU/RAM to reduce size with no quality change. +// If not specified, defaults to false +func EncodeJpegOptimizeSize(value bool) EncodeJpegAttr { + return func(m optionalAttr) { + m["optimize_size"] = value + } +} + +// EncodeJpegChromaDownsampling sets the optional chroma_downsampling attribute to value. +// +// value: See http://en.wikipedia.org/wiki/Chroma_subsampling. +// If not specified, defaults to true +func EncodeJpegChromaDownsampling(value bool) EncodeJpegAttr { + return func(m optionalAttr) { + m["chroma_downsampling"] = value + } +} + +// EncodeJpegDensityUnit sets the optional density_unit attribute to value. +// +// value: Unit used to specify `x_density` and `y_density`: +// pixels per inch (`'in'`) or centimeter (`'cm'`). +// If not specified, defaults to "in" +func EncodeJpegDensityUnit(value string) EncodeJpegAttr { + return func(m optionalAttr) { + m["density_unit"] = value + } +} + +// EncodeJpegXDensity sets the optional x_density attribute to value. +// +// value: Horizontal pixels per density unit. +// If not specified, defaults to 300 +func EncodeJpegXDensity(value int64) EncodeJpegAttr { + return func(m optionalAttr) { + m["x_density"] = value + } +} + +// EncodeJpegYDensity sets the optional y_density attribute to value. +// +// value: Vertical pixels per density unit. +// If not specified, defaults to 300 +func EncodeJpegYDensity(value int64) EncodeJpegAttr { + return func(m optionalAttr) { + m["y_density"] = value + } +} + +// EncodeJpegXmpMetadata sets the optional xmp_metadata attribute to value. +// +// value: If not empty, embed this XMP metadata in the image header. +// If not specified, defaults to "" +func EncodeJpegXmpMetadata(value string) EncodeJpegAttr { + return func(m optionalAttr) { + m["xmp_metadata"] = value + } +} + +// JPEG-encode an image. +// +// `image` is a 3-D uint8 Tensor of shape `[height, width, channels]`. +// +// The attr `format` can be used to override the color format of the encoded +// output. Values can be: +// +// * `''`: Use a default format based on the number of channels in the image. +// * `grayscale`: Output a grayscale JPEG image. The `channels` dimension +// of `image` must be 1. +// * `rgb`: Output an RGB JPEG image. The `channels` dimension +// of `image` must be 3. +// +// If `format` is not specified or is the empty string, a default format is picked +// in function of the number of channels in `image`: +// +// * 1: Output a grayscale image. +// * 3: Output an RGB image. +// +// Arguments: +// image: 3-D with shape `[height, width, channels]`. +// +// Returns 0-D. JPEG-encoded image. +func EncodeJpeg(scope *Scope, image tf.Output, optional ...EncodeJpegAttr) (contents tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "EncodeJpeg", + Input: []tf.Input{ + image, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Gradients for batch normalization. +// +// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization() +// +// This op is deprecated. See `tf.nn.batch_normalization`. +// +// Arguments: +// t: A 4D input Tensor. +// m: A 1D mean Tensor with size matching the last dimension of t. +// This is the first output from tf.nn.moments, +// or a saved moving average thereof. +// v: A 1D variance Tensor with size matching the last dimension of t. +// This is the second output from tf.nn.moments, +// or a saved moving average thereof. +// gamma: A 1D gamma Tensor with size matching the last dimension of t. +// If "scale_after_normalization" is true, this Tensor will be multiplied +// with the normalized Tensor. +// backprop: 4D backprop Tensor. +// variance_epsilon: A small float number to avoid dividing by 0. +// scale_after_normalization: A bool indicating whether the resulted tensor +// needs to be multiplied with gamma. +// +// Returns 4D backprop tensor for input.1D backprop tensor for mean.1D backprop tensor for variance.1D backprop tensor for beta.1D backprop tensor for gamma. +func BatchNormWithGlobalNormalizationGrad(scope *Scope, t tf.Output, m tf.Output, v tf.Output, gamma tf.Output, backprop tf.Output, variance_epsilon float32, scale_after_normalization bool) (dx tf.Output, dm tf.Output, dv tf.Output, db tf.Output, dg tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} + opspec := tf.OpSpec{ + Type: "BatchNormWithGlobalNormalizationGrad", + Input: []tf.Input{ + t, m, v, gamma, backprop, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) +} + // Conv2DBackpropInputAttr is an optional argument to Conv2DBackpropInput. type Conv2DBackpropInputAttr func(optionalAttr) @@ -7160,6 +7603,51 @@ func SaveSlices(scope *Scope, filename tf.Output, tensor_names tf.Output, shapes return scope.AddOperation(opspec) } +// Writes contents to the file at input filename. Creates file if not existing. +// +// Arguments: +// filename: scalar. The name of the file to which we write the contents. +// contents: scalar. The content to be written to the output file. +// +// Returns the created operation. +func WriteFile(scope *Scope, filename tf.Output, contents tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "WriteFile", + Input: []tf.Input{ + filename, contents, + }, + } + return scope.AddOperation(opspec) +} + +// Computes the Cholesky decomposition of one or more square matrices. +// +// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +// form square matrices, with the same constraints as the single matrix Cholesky +// decomposition above. The output is a tensor of the same shape as the input +// containing the Cholesky decompositions for all input submatrices `[..., :, :]`. +// +// Arguments: +// input: Shape is `[..., M, M]`. +// +// Returns Shape is `[..., M, M]`. +func Cholesky(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Cholesky", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Returns the rank of a tensor. // // This operation returns an integer representing the rank of `input`. @@ -7243,54 +7731,6 @@ func DecodeCSV(scope *Scope, records tf.Output, record_defaults []tf.Output, opt return output } -// BiasAddGradAttr is an optional argument to BiasAddGrad. -type BiasAddGradAttr func(optionalAttr) - -// BiasAddGradDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the bias tensor will be added to the last dimension -// of the value tensor. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// The tensor will be added to "in_channels", the third-to-the-last -// dimension. -// If not specified, defaults to "NHWC" -func BiasAddGradDataFormat(value string) BiasAddGradAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// The backward operation for "BiasAdd" on the "bias" tensor. -// -// It accumulates all the values from out_backprop into the feature dimension. -// For NHWC data format, the feature dimension is the last. For NCHW data format, -// the feature dimension is the third-to-last. -// -// Arguments: -// out_backprop: Any number of dimensions. -// -// Returns 1-D with size the feature dimension of `out_backprop`. -func BiasAddGrad(scope *Scope, out_backprop tf.Output, optional ...BiasAddGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "BiasAddGrad", - Input: []tf.Input{ - out_backprop, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Convert JSON-encoded Example records to binary protocol buffer strings. // // This op translates a tensor containing Example records, encoded using @@ -8024,27 +8464,51 @@ func ParameterizedTruncatedNormal(scope *Scope, shape tf.Output, means tf.Output return op.Output(0) } -// Convert one or more images from HSV to RGB. +// EncodePngAttr is an optional argument to EncodePng. +type EncodePngAttr func(optionalAttr) + +// EncodePngCompression sets the optional compression attribute to value. // -// Outputs a tensor of the same shape as the `images` tensor, containing the RGB -// value of the pixels. The output is only well defined if the value in `images` -// are in `[0,1]`. +// value: Compression level. +// If not specified, defaults to -1 +func EncodePngCompression(value int64) EncodePngAttr { + return func(m optionalAttr) { + m["compression"] = value + } +} + +// PNG-encode an image. // -// See `rgb_to_hsv` for a description of the HSV encoding. +// `image` is a 3-D uint8 or uint16 Tensor of shape `[height, width, channels]` +// where `channels` is: +// +// * 1: for grayscale. +// * 2: for grayscale + alpha. +// * 3: for RGB. +// * 4: for RGBA. +// +// The ZLIB compression level, `compression`, can be -1 for the PNG-encoder +// default or a value from 0 to 9. 9 is the highest compression level, generating +// the smallest output, but is slower. // // Arguments: -// images: 1-D or higher rank. HSV data to convert. Last dimension must be size 3. +// image: 3-D with shape `[height, width, channels]`. // -// Returns `images` converted to RGB. -func HSVToRGB(scope *Scope, images tf.Output) (output tf.Output) { +// Returns 0-D. PNG-encoded image. +func EncodePng(scope *Scope, image tf.Output, optional ...EncodePngAttr) (contents tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "HSVToRGB", + Type: "EncodePng", Input: []tf.Input{ - images, + image, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) @@ -8976,29 +9440,6 @@ func SparseSparseMinimum(scope *Scope, a_indices tf.Output, a_values tf.Output, return op.Output(0), op.Output(1) } -// Returns the set of files matching one or more glob patterns. -// -// Note that this routine only supports wildcard characters in the -// basename portion of the pattern, not in the directory portion. -// -// Arguments: -// pattern: Shell wildcard pattern(s). Scalar or vector of type string. -// -// Returns A vector of matching filenames. -func MatchingFiles(scope *Scope, pattern tf.Output) (filenames tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "MatchingFiles", - Input: []tf.Input{ - pattern, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Computes the gradient of the sigmoid of `x` wrt its input. // // Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and @@ -10269,117 +10710,6 @@ func QuantizedRelu(scope *Scope, features tf.Output, min_features tf.Output, max return op.Output(0), op.Output(1), op.Output(2) } -// InitializeTableFromTextFileV2Attr is an optional argument to InitializeTableFromTextFileV2. -type InitializeTableFromTextFileV2Attr func(optionalAttr) - -// InitializeTableFromTextFileV2VocabSize sets the optional vocab_size attribute to value. -// -// value: Number of elements of the file, use -1 if unknown. -// If not specified, defaults to -1 -// -// REQUIRES: value >= -1 -func InitializeTableFromTextFileV2VocabSize(value int64) InitializeTableFromTextFileV2Attr { - return func(m optionalAttr) { - m["vocab_size"] = value - } -} - -// InitializeTableFromTextFileV2Delimiter sets the optional delimiter attribute to value. -// -// value: Delimiter to separate fields in a line. -// If not specified, defaults to "\t" -func InitializeTableFromTextFileV2Delimiter(value string) InitializeTableFromTextFileV2Attr { - return func(m optionalAttr) { - m["delimiter"] = value - } -} - -// Initializes a table from a text file. -// -// It inserts one key-value pair into the table for each line of the file. -// The key and value is extracted from the whole line content, elements from the -// split line based on `delimiter` or the line number (starting from zero). -// Where to extract the key and value from a line is specified by `key_index` and -// `value_index`. -// -// - A value of -1 means use the line number(starting from zero), expects `int64`. -// - A value of -2 means use the whole line content, expects `string`. -// - A value >= 0 means use the index (starting at zero) of the split line based -// on `delimiter`. -// -// Arguments: -// table_handle: Handle to a table which will be initialized. -// filename: Filename of a vocabulary text file. -// key_index: Column index in a line to get the table `key` values from. -// value_index: Column index that represents information of a line to get the table -// `value` values from. -// -// Returns the created operation. -func InitializeTableFromTextFileV2(scope *Scope, table_handle tf.Output, filename tf.Output, key_index int64, value_index int64, optional ...InitializeTableFromTextFileV2Attr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"key_index": key_index, "value_index": value_index} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "InitializeTableFromTextFileV2", - Input: []tf.Input{ - table_handle, filename, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// ResourceSparseApplyProximalGradientDescentAttr is an optional argument to ResourceSparseApplyProximalGradientDescent. -type ResourceSparseApplyProximalGradientDescentAttr func(optionalAttr) - -// ResourceSparseApplyProximalGradientDescentUseLocking sets the optional use_locking attribute to value. -// -// value: If True, the subtraction will be protected by a lock; -// otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceSparseApplyProximalGradientDescentUseLocking(value bool) ResourceSparseApplyProximalGradientDescentAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Sparse update '*var' as FOBOS algorithm with fixed learning rate. -// -// That is for rows we have grad for, we update var as follows: -// prox_v = var - alpha * grad -// var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0} -// -// Arguments: -// var_: Should be from a Variable(). -// alpha: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// -// Returns the created operation. -func ResourceSparseApplyProximalGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyProximalGradientDescentAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceSparseApplyProximalGradientDescent", - Input: []tf.Input{ - var_, alpha, l1, l2, grad, indices, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - // Computes rectified linear gradients for a Relu operation. // // Arguments: @@ -10420,51 +10750,6 @@ func ReciprocalGrad(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } -// Computes the Cholesky decomposition of one or more square matrices. -// -// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -// form square matrices, with the same constraints as the single matrix Cholesky -// decomposition above. The output is a tensor of the same shape as the input -// containing the Cholesky decompositions for all input submatrices `[..., :, :]`. -// -// Arguments: -// input: Shape is `[..., M, M]`. -// -// Returns Shape is `[..., M, M]`. -func Cholesky(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Cholesky", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Writes contents to the file at input filename. Creates file if not existing. -// -// Arguments: -// filename: scalar. The name of the file to which we write the contents. -// contents: scalar. The content to be written to the output file. -// -// Returns the created operation. -func WriteFile(scope *Scope, filename tf.Output, contents tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "WriteFile", - Input: []tf.Input{ - filename, contents, - }, - } - return scope.AddOperation(opspec) -} - // Reverses specific dimensions of a tensor. // // NOTE `tf.reverse` has now changed behavior in preparation for 1.0. @@ -10627,6 +10912,35 @@ func IFFT3D(scope *Scope, input tf.Output) (output tf.Output) { return op.Output(0) } +// Looks up keys in a table, outputs the corresponding values. +// +// The tensor `keys` must of the same type as the keys of the table. +// The output `values` is of the type of the table values. +// +// The scalar `default_value` is the value output for keys not present in the +// table. It must also be of the same type as the table values. +// +// Arguments: +// table_handle: Handle to the table. +// keys: Any shape. Keys to look up. +// +// +// Returns Same shape as `keys`. Values found in the table, or `default_values` +// for missing keys. +func LookupTableFindV2(scope *Scope, table_handle tf.Output, keys tf.Output, default_value tf.Output) (values tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LookupTableFindV2", + Input: []tf.Input{ + table_handle, keys, default_value, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Given a quantized tensor described by (input, input_min, input_max), outputs a // // range that covers the actual values present in that tensor. This op is @@ -11189,122 +11503,6 @@ func IFFT2D(scope *Scope, input tf.Output) (output tf.Output) { return op.Output(0) } -// MutableHashTableOfTensorsV2Attr is an optional argument to MutableHashTableOfTensorsV2. -type MutableHashTableOfTensorsV2Attr func(optionalAttr) - -// MutableHashTableOfTensorsV2Container sets the optional container attribute to value. -// -// value: If non-empty, this table is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func MutableHashTableOfTensorsV2Container(value string) MutableHashTableOfTensorsV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// MutableHashTableOfTensorsV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this table is shared under the given name across -// multiple sessions. -// If not specified, defaults to "" -func MutableHashTableOfTensorsV2SharedName(value string) MutableHashTableOfTensorsV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// MutableHashTableOfTensorsV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. -// If not specified, defaults to false -func MutableHashTableOfTensorsV2UseNodeNameSharing(value bool) MutableHashTableOfTensorsV2Attr { - return func(m optionalAttr) { - m["use_node_name_sharing"] = value - } -} - -// MutableHashTableOfTensorsV2ValueShape sets the optional value_shape attribute to value. -// If not specified, defaults to <> -func MutableHashTableOfTensorsV2ValueShape(value tf.Shape) MutableHashTableOfTensorsV2Attr { - return func(m optionalAttr) { - m["value_shape"] = value - } -} - -// Creates an empty hash table. -// -// This op creates a mutable hash table, specifying the type of its keys and -// values. Each value must be a vector. Data can be inserted into the table using -// the insert operations. It does not support the initialization operation. -// -// Arguments: -// key_dtype: Type of the table keys. -// value_dtype: Type of the table values. -// -// Returns Handle to a table. -func MutableHashTableOfTensorsV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableOfTensorsV2Attr) (table_handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MutableHashTableOfTensorsV2", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceApplyProximalAdagradAttr is an optional argument to ResourceApplyProximalAdagrad. -type ResourceApplyProximalAdagradAttr func(optionalAttr) - -// ResourceApplyProximalAdagradUseLocking sets the optional use_locking attribute to value. -// -// value: If True, updating of the var and accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceApplyProximalAdagradUseLocking(value bool) ResourceApplyProximalAdagradAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update '*var' and '*accum' according to FOBOS with Adagrad learning rate. -// -// accum += grad * grad -// prox_v = var - lr * grad * (1 / sqrt(accum)) -// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} -// -// Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// grad: The gradient. -// -// Returns the created operation. -func ResourceApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, optional ...ResourceApplyProximalAdagradAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceApplyProximalAdagrad", - Input: []tf.Input{ - var_, accum, lr, l1, l2, grad, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - // TensorArrayV3Attr is an optional argument to TensorArrayV3. type TensorArrayV3Attr func(optionalAttr) @@ -11619,54 +11817,6 @@ func FractionalMaxPoolGrad(scope *Scope, orig_input tf.Output, orig_output tf.Ou return op.Output(0) } -// AvgPool3DGradAttr is an optional argument to AvgPool3DGrad. -type AvgPool3DGradAttr func(optionalAttr) - -// AvgPool3DGradDataFormat sets the optional data_format attribute to value. -// -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func AvgPool3DGradDataFormat(value string) AvgPool3DGradAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Computes gradients of average pooling function. -// -// Arguments: -// orig_input_shape: The original input dimensions. -// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. -// ksize: 1-D tensor of length 5. The size of the window for each dimension of -// the input tensor. Must have `ksize[0] = ksize[4] = 1`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. -// -// Returns The backprop for input. -func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "AvgPool3DGrad", - Input: []tf.Input{ - orig_input_shape, grad, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // QuantizedRelu6Attr is an optional argument to QuantizedRelu6. type QuantizedRelu6Attr func(optionalAttr) @@ -12745,6 +12895,54 @@ func Tanh(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } +// AvgPool3DGradAttr is an optional argument to AvgPool3DGrad. +type AvgPool3DGradAttr func(optionalAttr) + +// AvgPool3DGradDataFormat sets the optional data_format attribute to value. +// +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func AvgPool3DGradDataFormat(value string) AvgPool3DGradAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Computes gradients of average pooling function. +// +// Arguments: +// orig_input_shape: The original input dimensions. +// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. +// ksize: 1-D tensor of length 5. The size of the window for each dimension of +// the input tensor. Must have `ksize[0] = ksize[4] = 1`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. +// +// Returns The backprop for input. +func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DGradAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "AvgPool3DGrad", + Input: []tf.Input{ + orig_input_shape, grad, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // TextLineReaderV2Attr is an optional argument to TextLineReaderV2. type TextLineReaderV2Attr func(optionalAttr) @@ -13390,39 +13588,6 @@ func ResourceApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output, accum_ return scope.AddOperation(opspec) } -// Shuffle dimensions of x according to a permutation. -// -// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy: -// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]` -func Transpose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Transpose", - Input: []tf.Input{ - x, perm, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Reads and outputs the entire contents of the input filename. -func ReadFile(scope *Scope, filename tf.Output) (contents tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ReadFile", - Input: []tf.Input{ - filename, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Output a fact about factorials. func Fact(scope *Scope) (fact tf.Output) { if scope.Err() != nil { @@ -14260,37 +14425,6 @@ func Gather(scope *Scope, params tf.Output, indices tf.Output, optional ...Gathe return op.Output(0) } -// Adjust the contrast of one or more images. -// -// `images` is a tensor of at least 3 dimensions. The last 3 dimensions are -// interpreted as `[height, width, channels]`. The other dimensions only -// represent a collection of images, such as `[batch, height, width, channels].` -// -// Contrast is adjusted independently for each channel of each image. -// -// For each channel, the Op first computes the mean of the image pixels in the -// channel and then adjusts each component of each pixel to -// `(x - mean) * contrast_factor + mean`. -// -// Arguments: -// images: Images to adjust. At least 3-D. -// contrast_factor: A float multiplier for adjusting contrast. -// -// Returns The contrast-adjusted image or images. -func AdjustContrastv2(scope *Scope, images tf.Output, contrast_factor tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "AdjustContrastv2", - Input: []tf.Input{ - images, contrast_factor, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Computes softsign gradients for a softsign operation. // // Arguments: @@ -14386,31 +14520,6 @@ func Dilation2D(scope *Scope, input tf.Output, filter tf.Output, strides []int64 return op.Output(0) } -// Decode the first frame of a GIF-encoded image to a uint8 tensor. -// -// GIF with frame or transparency compression are not supported -// convert animated GIF from compressed to uncompressed by: -// -// convert $src.gif -coalesce $dst.gif -// -// Arguments: -// contents: 0-D. The GIF-encoded image. -// -// Returns 4-D with shape `[num_frames, height, width, 3]`. RGB order -func DecodeGif(scope *Scope, contents tf.Output) (image tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "DecodeGif", - Input: []tf.Input{ - contents, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // EncodeBase64Attr is an optional argument to EncodeBase64. type EncodeBase64Attr func(optionalAttr) @@ -14672,6 +14781,70 @@ func SparseSegmentSqrtN(scope *Scope, data tf.Output, indices tf.Output, segment return op.Output(0) } +// ResizeBilinearGradAttr is an optional argument to ResizeBilinearGrad. +type ResizeBilinearGradAttr func(optionalAttr) + +// ResizeBilinearGradAlignCorners sets the optional align_corners attribute to value. +// +// value: If true, rescale grads by (orig_height - 1) / (height - 1), which +// exactly aligns the 4 corners of grads and original_image. If false, rescale by +// orig_height / height. Treat similarly the width dimension. +// If not specified, defaults to false +func ResizeBilinearGradAlignCorners(value bool) ResizeBilinearGradAttr { + return func(m optionalAttr) { + m["align_corners"] = value + } +} + +// Computes the gradient of bilinear interpolation. +// +// Arguments: +// grads: 4-D with shape `[batch, height, width, channels]`. +// original_image: 4-D with shape `[batch, orig_height, orig_width, channels]`, +// The image tensor that was resized. +// +// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. +// Gradients with respect to the input image. Input image must have been +// float or double. +func ResizeBilinearGrad(scope *Scope, grads tf.Output, original_image tf.Output, optional ...ResizeBilinearGradAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResizeBilinearGrad", + Input: []tf.Input{ + grads, original_image, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the number of elements in the given table. +// +// Arguments: +// table_handle: Handle to the table. +// +// Returns Scalar that contains number of elements in the table. +func LookupTableSizeV2(scope *Scope, table_handle tf.Output) (size tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LookupTableSizeV2", + Input: []tf.Input{ + table_handle, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Component-wise divides a SparseTensor by a dense Tensor. // // *Limitation*: this Op only broadcasts the dense side to the sparse side, but not @@ -14727,95 +14900,6 @@ func ReadVariableOp(scope *Scope, resource tf.Output, dtype tf.DataType) (value return op.Output(0) } -// ProdAttr is an optional argument to Prod. -type ProdAttr func(optionalAttr) - -// ProdKeepDims sets the optional keep_dims attribute to value. -// -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func ProdKeepDims(value bool) ProdAttr { - return func(m optionalAttr) { - m["keep_dims"] = value - } -} - -// Computes the product of elements across dimensions of a tensor. -// -// Reduces `input` along the dimensions given in `reduction_indices`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `reduction_indices`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. -// -// Arguments: -// input: The tensor to reduce. -// reduction_indices: The dimensions to reduce. -// -// Returns The reduced tensor. -func Prod(scope *Scope, input tf.Output, reduction_indices tf.Output, optional ...ProdAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Prod", - Input: []tf.Input{ - input, reduction_indices, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResizeBilinearAttr is an optional argument to ResizeBilinear. -type ResizeBilinearAttr func(optionalAttr) - -// ResizeBilinearAlignCorners sets the optional align_corners attribute to value. -// -// value: If true, rescale input by (new_height - 1) / (height - 1), which -// exactly aligns the 4 corners of images and resized images. If false, rescale -// by new_height / height. Treat similarly the width dimension. -// If not specified, defaults to false -func ResizeBilinearAlignCorners(value bool) ResizeBilinearAttr { - return func(m optionalAttr) { - m["align_corners"] = value - } -} - -// Resize `images` to `size` using bilinear interpolation. -// -// Input images can be of different types but output images are always float. -// -// Arguments: -// images: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The -// new size for the images. -// -// Returns 4-D with shape -// `[batch, new_height, new_width, channels]`. -func ResizeBilinear(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeBilinearAttr) (resized_images tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResizeBilinear", - Input: []tf.Input{ - images, size, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Computes the absolute value of a tensor. // // Given a tensor `x`, this operation returns a tensor containing the absolute @@ -14988,6 +15072,108 @@ func SparseSegmentMeanGrad(scope *Scope, grad tf.Output, indices tf.Output, segm return op.Output(0) } +// Converts one or more images from RGB to HSV. +// +// Outputs a tensor of the same shape as the `images` tensor, containing the HSV +// value of the pixels. The output is only well defined if the value in `images` +// are in `[0,1]`. +// +// `output[..., 0]` contains hue, `output[..., 1]` contains saturation, and +// `output[..., 2]` contains value. All HSV values are in `[0,1]`. A hue of 0 +// corresponds to pure red, hue 1/3 is pure green, and 2/3 is pure blue. +// +// Arguments: +// images: 1-D or higher rank. RGB data to convert. Last dimension must be size 3. +// +// Returns `images` converted to HSV. +func RGBToHSV(scope *Scope, images tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "RGBToHSV", + Input: []tf.Input{ + images, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MatrixSolveLsAttr is an optional argument to MatrixSolveLs. +type MatrixSolveLsAttr func(optionalAttr) + +// MatrixSolveLsFast sets the optional fast attribute to value. +// If not specified, defaults to true +func MatrixSolveLsFast(value bool) MatrixSolveLsAttr { + return func(m optionalAttr) { + m["fast"] = value + } +} + +// Solves one or more linear least-squares problems. +// +// `matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions +// form matrices of size `[M, N]`. Rhs is a tensor of shape `[..., M, K]`. +// The output is a tensor shape `[..., N, K]` where each output matrix solves +// each of the equations matrix[..., :, :] * output[..., :, :] = rhs[..., :, :] +// in the least squares sense. +// +// matrix and right-hand sides in the batch: +// +// `matrix`=\\(A \in \Re^{m \times n}\\), +// `rhs`=\\(B \in \Re^{m \times k}\\), +// `output`=\\(X \in \Re^{n \times k}\\), +// `l2_regularizer`=\\(\lambda\\). +// +// If `fast` is `True`, then the solution is computed by solving the normal +// equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then +// \\(X = (A^T A + \lambda I)^{-1} A^T B\\), which solves the least-squares +// problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||A Z - B||_F^2 + +// \lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as +// \\(X = A^T (A A^T + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the +// minimum-norm solution to the under-determined linear system, i.e. +// \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||Z||_F^2 \\), subject to +// \\(A Z = B\\). Notice that the fast path is only numerically stable when +// \\(A\\) is numerically full rank and has a condition number +// \\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach} } }\\) or\\(\lambda\\) is +// sufficiently large. +// +// If `fast` is `False` an algorithm based on the numerically robust complete +// orthogonal decomposition is used. This computes the minimum-norm +// least-squares solution, even when \\(A\\) is rank deficient. This path is +// typically 6-7 times slower than the fast path. If `fast` is `False` then +// `l2_regularizer` is ignored. +// +// Arguments: +// matrix: Shape is `[..., M, N]`. +// rhs: Shape is `[..., M, K]`. +// l2_regularizer: Scalar tensor. +// +// @compatibility(numpy) +// Equivalent to np.linalg.lstsq +// @end_compatibility +// +// Returns Shape is `[..., N, K]`. +func MatrixSolveLs(scope *Scope, matrix tf.Output, rhs tf.Output, l2_regularizer tf.Output, optional ...MatrixSolveLsAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MatrixSolveLs", + Input: []tf.Input{ + matrix, rhs, l2_regularizer, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // QuantizedReluXAttr is an optional argument to QuantizedReluX. type QuantizedReluXAttr func(optionalAttr) @@ -15770,6 +15956,30 @@ func TanhGrad(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } +// Outputs all keys and values in the table. +// +// Arguments: +// table_handle: Handle to the table. +// +// +// +// Returns Vector of all keys present in the table.Tensor of all values in the table. Indexed in parallel with `keys`. +func LookupTableExportV2(scope *Scope, table_handle tf.Output, Tkeys tf.DataType, Tvalues tf.DataType) (keys tf.Output, values tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"Tkeys": Tkeys, "Tvalues": Tvalues} + opspec := tf.OpSpec{ + Type: "LookupTableExportV2", + Input: []tf.Input{ + table_handle, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + // AddManySparseToTensorsMapAttr is an optional argument to AddManySparseToTensorsMap. type AddManySparseToTensorsMapAttr func(optionalAttr) @@ -15877,6 +16087,153 @@ func StringToHashBucketFast(scope *Scope, input tf.Output, num_buckets int64) (o return op.Output(0) } +// TensorArrayGatherV3Attr is an optional argument to TensorArrayGatherV3. +type TensorArrayGatherV3Attr func(optionalAttr) + +// TensorArrayGatherV3ElementShape sets the optional element_shape attribute to value. +// +// value: The expected shape of an element, if known. Used to +// validate the shapes of TensorArray elements. If this shape is not +// fully specified, gathering zero-size TensorArrays is an error. +// If not specified, defaults to +func TensorArrayGatherV3ElementShape(value tf.Shape) TensorArrayGatherV3Attr { + return func(m optionalAttr) { + m["element_shape"] = value + } +} + +// Gather specific elements from the TensorArray into output `value`. +// +// All elements selected by `indices` must have the same shape. +// +// Arguments: +// handle: The handle to a TensorArray. +// indices: The locations in the TensorArray from which to read tensor elements. +// flow_in: A float scalar that enforces proper chaining of operations. +// dtype: The type of the elem that is returned. +// +// Returns All of the elements in the TensorArray, concatenated along a new +// axis (the new dimension 0). +func TensorArrayGatherV3(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV3Attr) (value tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "TensorArrayGatherV3", + Input: []tf.Input{ + handle, indices, flow_in, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Deprecated. Disallowed in GraphDef version >= 2. +// +// DEPRECATED at GraphDef version 2: Use AdjustContrastv2 instead +func AdjustContrast(scope *Scope, images tf.Output, contrast_factor tf.Output, min_value tf.Output, max_value tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "AdjustContrast", + Input: []tf.Input{ + images, contrast_factor, min_value, max_value, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MaxPoolGradGradAttr is an optional argument to MaxPoolGradGrad. +type MaxPoolGradGradAttr func(optionalAttr) + +// MaxPoolGradGradDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func MaxPoolGradGradDataFormat(value string) MaxPoolGradGradAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Computes second-order gradients of the maxpooling function. +// +// Arguments: +// orig_input: The original input tensor. +// orig_output: The original output tensor. +// grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. +// +// Returns Gradients of gradients w.r.t. the input to `max_pool`. +func MaxPoolGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradGradAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MaxPoolGradGrad", + Input: []tf.Input{ + orig_input, orig_output, grad, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// 3D real-valued fast Fourier transform. +// +// Computes the 3-dimensional discrete Fourier transform of a real-valued signal +// over the inner-most 3 dimensions of `input`. +// +// Since the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the +// `fft_length / 2 + 1` unique components of the FFT for the inner-most dimension +// of `output`: the zero-frequency term, followed by the `fft_length / 2` +// positive-frequency terms. +// +// Arguments: +// input: A float32 tensor. +// fft_length: An int32 tensor of shape [3]. The FFT length for each dimension. +// +// Returns A complex64 tensor of the same rank as `input`. The inner-most 3 +// dimensions of `input` are replaced with the their 3D Fourier transform. The +// inner-most dimension contains `fft_length / 2 + 1` unique frequency +// components. +// +// @compatibility(numpy) +// Equivalent to np.fft.rfftn with 3 dimensions. +// @end_compatibility +func RFFT3D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "RFFT3D", + Input: []tf.Input{ + input, fft_length, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // UniqueWithCountsAttr is an optional argument to UniqueWithCounts. type UniqueWithCountsAttr func(optionalAttr) @@ -16708,6 +17065,30 @@ func FractionalAvgPool(scope *Scope, value tf.Output, pooling_ratio []float32, o return op.Output(0), op.Output(1), op.Output(2) } +// Updates the table to associates keys with values. +// +// The tensor `keys` must be of the same type as the keys of the table. +// The tensor `values` must be of the type of the table values. +// +// Arguments: +// table_handle: Handle to the table. +// keys: Any shape. Keys to look up. +// values: Values to associate with keys. +// +// Returns the created operation. +func LookupTableInsertV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LookupTableInsertV2", + Input: []tf.Input{ + table_handle, keys, values, + }, + } + return scope.AddOperation(opspec) +} + // Produces the average pool of the input tensor for quantized types. // // Arguments: @@ -16997,41 +17378,6 @@ func ComplexAbs(scope *Scope, x tf.Output, optional ...ComplexAbsAttr) (y tf.Out return op.Output(0) } -// Draw bounding boxes on a batch of images. -// -// Outputs a copy of `images` but draws on top of the pixels zero or more bounding -// boxes specified by the locations in `boxes`. The coordinates of the each -// bounding box in `boxes` are encoded as `[y_min, x_min, y_max, x_max]`. The -// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and -// height of the underlying image. -// -// For example, if an image is 100 x 200 pixels and the bounding box is -// `[0.1, 0.2, 0.5, 0.9]`, the bottom-left and upper-right coordinates of the -// bounding box will be `(10, 40)` to `(50, 180)`. -// -// Parts of the bounding box may fall outside the image. -// -// Arguments: -// images: 4-D with shape `[batch, height, width, depth]`. A batch of images. -// boxes: 3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding -// boxes. -// -// Returns 4-D with the same shape as `images`. The batch of input images with -// bounding boxes drawn on the images. -func DrawBoundingBoxes(scope *Scope, images tf.Output, boxes tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "DrawBoundingBoxes", - Input: []tf.Input{ - images, boxes, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Returns the element-wise max of two SparseTensors. // // Assumes the two SparseTensors have the same shape, i.e., no broadcasting. @@ -17501,28 +17847,6 @@ func Log(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } -// Computes rectified linear 6 gradients for a Relu6 operation. -// -// Arguments: -// gradients: The backpropagated gradients to the corresponding Relu6 operation. -// features: The features passed as input to the corresponding Relu6 operation. -// -// Returns The gradients: -// `gradients * (features > 0) * (features < 6)`. -func Relu6Grad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Relu6Grad", - Input: []tf.Input{ - gradients, features, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // ResizeBicubicAttr is an optional argument to ResizeBicubic. type ResizeBicubicAttr func(optionalAttr) @@ -17568,6 +17892,28 @@ func ResizeBicubic(scope *Scope, images tf.Output, size tf.Output, optional ...R return op.Output(0) } +// Computes rectified linear 6 gradients for a Relu6 operation. +// +// Arguments: +// gradients: The backpropagated gradients to the corresponding Relu6 operation. +// features: The features passed as input to the corresponding Relu6 operation. +// +// Returns The gradients: +// `gradients * (features > 0) * (features < 6)`. +func Relu6Grad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Relu6Grad", + Input: []tf.Input{ + gradients, features, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Computes natural logarithm of (1 + x) element-wise. // // I.e., \\(y = \log_e (1 + x)\\). @@ -17681,181 +18027,6 @@ func RFFT2D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Outp return op.Output(0) } -// Gradients for batch normalization. -// -// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization() -// -// This op is deprecated. See `tf.nn.batch_normalization`. -// -// Arguments: -// t: A 4D input Tensor. -// m: A 1D mean Tensor with size matching the last dimension of t. -// This is the first output from tf.nn.moments, -// or a saved moving average thereof. -// v: A 1D variance Tensor with size matching the last dimension of t. -// This is the second output from tf.nn.moments, -// or a saved moving average thereof. -// gamma: A 1D gamma Tensor with size matching the last dimension of t. -// If "scale_after_normalization" is true, this Tensor will be multiplied -// with the normalized Tensor. -// backprop: 4D backprop Tensor. -// variance_epsilon: A small float number to avoid dividing by 0. -// scale_after_normalization: A bool indicating whether the resulted tensor -// needs to be multiplied with gamma. -// -// Returns 4D backprop tensor for input.1D backprop tensor for mean.1D backprop tensor for variance.1D backprop tensor for beta.1D backprop tensor for gamma. -func BatchNormWithGlobalNormalizationGrad(scope *Scope, t tf.Output, m tf.Output, v tf.Output, gamma tf.Output, backprop tf.Output, variance_epsilon float32, scale_after_normalization bool) (dx tf.Output, dm tf.Output, dv tf.Output, db tf.Output, dg tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} - opspec := tf.OpSpec{ - Type: "BatchNormWithGlobalNormalizationGrad", - Input: []tf.Input{ - t, m, v, gamma, backprop, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) -} - -// EncodeJpegAttr is an optional argument to EncodeJpeg. -type EncodeJpegAttr func(optionalAttr) - -// EncodeJpegFormat sets the optional format attribute to value. -// -// value: Per pixel image format. -// If not specified, defaults to "" -func EncodeJpegFormat(value string) EncodeJpegAttr { - return func(m optionalAttr) { - m["format"] = value - } -} - -// EncodeJpegQuality sets the optional quality attribute to value. -// -// value: Quality of the compression from 0 to 100 (higher is better and slower). -// If not specified, defaults to 95 -func EncodeJpegQuality(value int64) EncodeJpegAttr { - return func(m optionalAttr) { - m["quality"] = value - } -} - -// EncodeJpegProgressive sets the optional progressive attribute to value. -// -// value: If True, create a JPEG that loads progressively (coarse to fine). -// If not specified, defaults to false -func EncodeJpegProgressive(value bool) EncodeJpegAttr { - return func(m optionalAttr) { - m["progressive"] = value - } -} - -// EncodeJpegOptimizeSize sets the optional optimize_size attribute to value. -// -// value: If True, spend CPU/RAM to reduce size with no quality change. -// If not specified, defaults to false -func EncodeJpegOptimizeSize(value bool) EncodeJpegAttr { - return func(m optionalAttr) { - m["optimize_size"] = value - } -} - -// EncodeJpegChromaDownsampling sets the optional chroma_downsampling attribute to value. -// -// value: See http://en.wikipedia.org/wiki/Chroma_subsampling. -// If not specified, defaults to true -func EncodeJpegChromaDownsampling(value bool) EncodeJpegAttr { - return func(m optionalAttr) { - m["chroma_downsampling"] = value - } -} - -// EncodeJpegDensityUnit sets the optional density_unit attribute to value. -// -// value: Unit used to specify `x_density` and `y_density`: -// pixels per inch (`'in'`) or centimeter (`'cm'`). -// If not specified, defaults to "in" -func EncodeJpegDensityUnit(value string) EncodeJpegAttr { - return func(m optionalAttr) { - m["density_unit"] = value - } -} - -// EncodeJpegXDensity sets the optional x_density attribute to value. -// -// value: Horizontal pixels per density unit. -// If not specified, defaults to 300 -func EncodeJpegXDensity(value int64) EncodeJpegAttr { - return func(m optionalAttr) { - m["x_density"] = value - } -} - -// EncodeJpegYDensity sets the optional y_density attribute to value. -// -// value: Vertical pixels per density unit. -// If not specified, defaults to 300 -func EncodeJpegYDensity(value int64) EncodeJpegAttr { - return func(m optionalAttr) { - m["y_density"] = value - } -} - -// EncodeJpegXmpMetadata sets the optional xmp_metadata attribute to value. -// -// value: If not empty, embed this XMP metadata in the image header. -// If not specified, defaults to "" -func EncodeJpegXmpMetadata(value string) EncodeJpegAttr { - return func(m optionalAttr) { - m["xmp_metadata"] = value - } -} - -// JPEG-encode an image. -// -// `image` is a 3-D uint8 Tensor of shape `[height, width, channels]`. -// -// The attr `format` can be used to override the color format of the encoded -// output. Values can be: -// -// * `''`: Use a default format based on the number of channels in the image. -// * `grayscale`: Output a grayscale JPEG image. The `channels` dimension -// of `image` must be 1. -// * `rgb`: Output an RGB JPEG image. The `channels` dimension -// of `image` must be 3. -// -// If `format` is not specified or is the empty string, a default format is picked -// in function of the number of channels in `image`: -// -// * 1: Output a grayscale image. -// * 3: Output an RGB image. -// -// Arguments: -// image: 3-D with shape `[height, width, channels]`. -// -// Returns 0-D. JPEG-encoded image. -func EncodeJpeg(scope *Scope, image tf.Output, optional ...EncodeJpegAttr) (contents tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "EncodeJpeg", - Input: []tf.Input{ - image, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Computes sin of x element-wise. func Sin(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { @@ -18164,6 +18335,117 @@ func ArgMin(scope *Scope, input tf.Output, dimension tf.Output) (output tf.Outpu return op.Output(0) } +// ResourceSparseApplyProximalGradientDescentAttr is an optional argument to ResourceSparseApplyProximalGradientDescent. +type ResourceSparseApplyProximalGradientDescentAttr func(optionalAttr) + +// ResourceSparseApplyProximalGradientDescentUseLocking sets the optional use_locking attribute to value. +// +// value: If True, the subtraction will be protected by a lock; +// otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceSparseApplyProximalGradientDescentUseLocking(value bool) ResourceSparseApplyProximalGradientDescentAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Sparse update '*var' as FOBOS algorithm with fixed learning rate. +// +// That is for rows we have grad for, we update var as follows: +// prox_v = var - alpha * grad +// var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0} +// +// Arguments: +// var_: Should be from a Variable(). +// alpha: Scaling factor. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// +// Returns the created operation. +func ResourceSparseApplyProximalGradientDescent(scope *Scope, var_ tf.Output, alpha tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyProximalGradientDescentAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceSparseApplyProximalGradientDescent", + Input: []tf.Input{ + var_, alpha, l1, l2, grad, indices, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// InitializeTableFromTextFileV2Attr is an optional argument to InitializeTableFromTextFileV2. +type InitializeTableFromTextFileV2Attr func(optionalAttr) + +// InitializeTableFromTextFileV2VocabSize sets the optional vocab_size attribute to value. +// +// value: Number of elements of the file, use -1 if unknown. +// If not specified, defaults to -1 +// +// REQUIRES: value >= -1 +func InitializeTableFromTextFileV2VocabSize(value int64) InitializeTableFromTextFileV2Attr { + return func(m optionalAttr) { + m["vocab_size"] = value + } +} + +// InitializeTableFromTextFileV2Delimiter sets the optional delimiter attribute to value. +// +// value: Delimiter to separate fields in a line. +// If not specified, defaults to "\t" +func InitializeTableFromTextFileV2Delimiter(value string) InitializeTableFromTextFileV2Attr { + return func(m optionalAttr) { + m["delimiter"] = value + } +} + +// Initializes a table from a text file. +// +// It inserts one key-value pair into the table for each line of the file. +// The key and value is extracted from the whole line content, elements from the +// split line based on `delimiter` or the line number (starting from zero). +// Where to extract the key and value from a line is specified by `key_index` and +// `value_index`. +// +// - A value of -1 means use the line number(starting from zero), expects `int64`. +// - A value of -2 means use the whole line content, expects `string`. +// - A value >= 0 means use the index (starting at zero) of the split line based +// on `delimiter`. +// +// Arguments: +// table_handle: Handle to a table which will be initialized. +// filename: Filename of a vocabulary text file. +// key_index: Column index in a line to get the table `key` values from. +// value_index: Column index that represents information of a line to get the table +// `value` values from. +// +// Returns the created operation. +func InitializeTableFromTextFileV2(scope *Scope, table_handle tf.Output, filename tf.Output, key_index int64, value_index int64, optional ...InitializeTableFromTextFileV2Attr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"key_index": key_index, "value_index": value_index} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "InitializeTableFromTextFileV2", + Input: []tf.Input{ + table_handle, filename, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + // Computes atan of x element-wise. func Atan(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { @@ -18628,33 +18910,36 @@ func Less(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } -// FakeQuantWithMinMaxVarsGradientAttr is an optional argument to FakeQuantWithMinMaxVarsGradient. -type FakeQuantWithMinMaxVarsGradientAttr func(optionalAttr) +// BiasAddGradAttr is an optional argument to BiasAddGrad. +type BiasAddGradAttr func(optionalAttr) -// FakeQuantWithMinMaxVarsGradientNumBits sets the optional num_bits attribute to value. +// BiasAddGradDataFormat sets the optional data_format attribute to value. // -// value: The bitwidth of the quantization; between 2 and 8, inclusive. -// If not specified, defaults to 8 -func FakeQuantWithMinMaxVarsGradientNumBits(value int64) FakeQuantWithMinMaxVarsGradientAttr { +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the bias tensor will be added to the last dimension +// of the value tensor. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// The tensor will be added to "in_channels", the third-to-the-last +// dimension. +// If not specified, defaults to "NHWC" +func BiasAddGradDataFormat(value string) BiasAddGradAttr { return func(m optionalAttr) { - m["num_bits"] = value + m["data_format"] = value } } -// Compute gradients for a FakeQuantWithMinMaxVars operation. +// The backward operation for "BiasAdd" on the "bias" tensor. +// +// It accumulates all the values from out_backprop into the feature dimension. +// For NHWC data format, the feature dimension is the last. For NCHW data format, +// the feature dimension is the third-to-last. // // Arguments: -// gradients: Backpropagated gradients above the FakeQuantWithMinMaxVars operation. -// inputs: Values passed as inputs to the FakeQuantWithMinMaxVars operation. -// min, max: Quantization interval, scalar floats. +// out_backprop: Any number of dimensions. // -// -// -// Returns Backpropagated gradients w.r.t. inputs: -// `gradients * (inputs >= min && inputs <= max)`.Backpropagated gradients w.r.t. min parameter: -// `sum(gradients * (inputs < min))`.Backpropagated gradients w.r.t. max parameter: -// `sum(gradients * (inputs > max))`. -func FakeQuantWithMinMaxVarsGradient(scope *Scope, gradients tf.Output, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsGradientAttr) (backprops_wrt_input tf.Output, backprop_wrt_min tf.Output, backprop_wrt_max tf.Output) { +// Returns 1-D with size the feature dimension of `out_backprop`. +func BiasAddGrad(scope *Scope, out_backprop tf.Output, optional ...BiasAddGradAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -18663,31 +18948,13 @@ func FakeQuantWithMinMaxVarsGradient(scope *Scope, gradients tf.Output, inputs t a(attrs) } opspec := tf.OpSpec{ - Type: "FakeQuantWithMinMaxVarsGradient", + Type: "BiasAddGrad", Input: []tf.Input{ - gradients, inputs, min, max, + out_backprop, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// Returns the min of x and y (i.e. x < y ? x : y) element-wise. -// -// *NOTE*: `Minimum` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Minimum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Minimum", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) return op.Output(0) } @@ -19996,65 +20263,6 @@ func QuantizeDownAndShrinkRange(scope *Scope, input tf.Output, input_min tf.Outp return op.Output(0), op.Output(1), op.Output(2) } -// DecodePngAttr is an optional argument to DecodePng. -type DecodePngAttr func(optionalAttr) - -// DecodePngChannels sets the optional channels attribute to value. -// -// value: Number of color channels for the decoded image. -// If not specified, defaults to 0 -func DecodePngChannels(value int64) DecodePngAttr { - return func(m optionalAttr) { - m["channels"] = value - } -} - -// DecodePngDtype sets the optional dtype attribute to value. -// If not specified, defaults to DT_UINT8 -func DecodePngDtype(value tf.DataType) DecodePngAttr { - return func(m optionalAttr) { - m["dtype"] = value - } -} - -// Decode a PNG-encoded image to a uint8 or uint16 tensor. -// -// The attr `channels` indicates the desired number of color channels for the -// decoded image. -// -// Accepted values are: -// -// * 0: Use the number of channels in the PNG-encoded image. -// * 1: output a grayscale image. -// * 3: output an RGB image. -// * 4: output an RGBA image. -// -// If needed, the PNG-encoded image is transformed to match the requested number -// of color channels. -// -// Arguments: -// contents: 0-D. The PNG-encoded image. -// -// Returns 3-D with shape `[height, width, channels]`. -func DecodePng(scope *Scope, contents tf.Output, optional ...DecodePngAttr) (image tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DecodePng", - Input: []tf.Input{ - contents, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // AudioSummaryV2Attr is an optional argument to AudioSummaryV2. type AudioSummaryV2Attr func(optionalAttr) @@ -20219,31 +20427,188 @@ func AudioSummary(scope *Scope, tag tf.Output, tensor tf.Output, sample_rate flo return op.Output(0) } -// ResizeNearestNeighborAttr is an optional argument to ResizeNearestNeighbor. -type ResizeNearestNeighborAttr func(optionalAttr) - -// ResizeNearestNeighborAlignCorners sets the optional align_corners attribute to value. +// Replaces the contents of the table with the specified keys and values. // -// value: If true, rescale input by (new_height - 1) / (height - 1), which -// exactly aligns the 4 corners of images and resized images. If false, rescale -// by new_height / height. Treat similarly the width dimension. -// If not specified, defaults to false -func ResizeNearestNeighborAlignCorners(value bool) ResizeNearestNeighborAttr { +// The tensor `keys` must be of the same type as the keys of the table. +// The tensor `values` must be of the type of the table values. +// +// Arguments: +// table_handle: Handle to the table. +// keys: Any shape. Keys to look up. +// values: Values to associate with keys. +// +// Returns the created operation. +func LookupTableImportV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LookupTableImportV2", + Input: []tf.Input{ + table_handle, keys, values, + }, + } + return scope.AddOperation(opspec) +} + +// HashTableV2Attr is an optional argument to HashTableV2. +type HashTableV2Attr func(optionalAttr) + +// HashTableV2Container sets the optional container attribute to value. +// +// value: If non-empty, this table is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func HashTableV2Container(value string) HashTableV2Attr { return func(m optionalAttr) { - m["align_corners"] = value + m["container"] = value } } -// Resize `images` to `size` using nearest neighbor interpolation. +// HashTableV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this table is shared under the given name across +// multiple sessions. +// If not specified, defaults to "" +func HashTableV2SharedName(value string) HashTableV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// HashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. +// +// value: If true and shared_name is empty, the table is shared +// using the node name. +// If not specified, defaults to false +func HashTableV2UseNodeNameSharing(value bool) HashTableV2Attr { + return func(m optionalAttr) { + m["use_node_name_sharing"] = value + } +} + +// Creates a non-initialized hash table. +// +// This op creates a hash table, specifying the type of its keys and values. +// Before using the table you will have to initialize it. After initialization the +// table will be immutable. // // Arguments: -// images: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The -// new size for the images. +// key_dtype: Type of the table keys. +// value_dtype: Type of the table values. // -// Returns 4-D with shape -// `[batch, new_height, new_width, channels]`. -func ResizeNearestNeighbor(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeNearestNeighborAttr) (resized_images tf.Output) { +// Returns Handle to a table. +func HashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...HashTableV2Attr) (table_handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "HashTableV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MutableHashTableV2Attr is an optional argument to MutableHashTableV2. +type MutableHashTableV2Attr func(optionalAttr) + +// MutableHashTableV2Container sets the optional container attribute to value. +// +// value: If non-empty, this table is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func MutableHashTableV2Container(value string) MutableHashTableV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MutableHashTableV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this table is shared under the given name across +// multiple sessions. +// If not specified, defaults to "" +func MutableHashTableV2SharedName(value string) MutableHashTableV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// MutableHashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. +// +// value: If true and shared_name is empty, the table is shared +// using the node name. +// If not specified, defaults to false +func MutableHashTableV2UseNodeNameSharing(value bool) MutableHashTableV2Attr { + return func(m optionalAttr) { + m["use_node_name_sharing"] = value + } +} + +// Creates an empty hash table. +// +// This op creates a mutable hash table, specifying the type of its keys and +// values. Each value must be a scalar. Data can be inserted into the table using +// the insert operations. It does not support the initialization operation. +// +// Arguments: +// key_dtype: Type of the table keys. +// value_dtype: Type of the table values. +// +// Returns Handle to a table. +func MutableHashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableV2Attr) (table_handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MutableHashTableV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceApplyProximalAdagradAttr is an optional argument to ResourceApplyProximalAdagrad. +type ResourceApplyProximalAdagradAttr func(optionalAttr) + +// ResourceApplyProximalAdagradUseLocking sets the optional use_locking attribute to value. +// +// value: If True, updating of the var and accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceApplyProximalAdagradUseLocking(value bool) ResourceApplyProximalAdagradAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' and '*accum' according to FOBOS with Adagrad learning rate. +// +// accum += grad * grad +// prox_v = var - lr * grad * (1 / sqrt(accum)) +// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} +// +// Arguments: +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// grad: The gradient. +// +// Returns the created operation. +func ResourceApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, optional ...ResourceApplyProximalAdagradAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -20252,12 +20617,164 @@ func ResizeNearestNeighbor(scope *Scope, images tf.Output, size tf.Output, optio a(attrs) } opspec := tf.OpSpec{ - Type: "ResizeNearestNeighbor", + Type: "ResourceApplyProximalAdagrad", Input: []tf.Input{ - images, size, + var_, accum, lr, l1, l2, grad, }, Attrs: attrs, } + return scope.AddOperation(opspec) +} + +// MutableHashTableOfTensorsV2Attr is an optional argument to MutableHashTableOfTensorsV2. +type MutableHashTableOfTensorsV2Attr func(optionalAttr) + +// MutableHashTableOfTensorsV2Container sets the optional container attribute to value. +// +// value: If non-empty, this table is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func MutableHashTableOfTensorsV2Container(value string) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MutableHashTableOfTensorsV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this table is shared under the given name across +// multiple sessions. +// If not specified, defaults to "" +func MutableHashTableOfTensorsV2SharedName(value string) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// MutableHashTableOfTensorsV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. +// If not specified, defaults to false +func MutableHashTableOfTensorsV2UseNodeNameSharing(value bool) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["use_node_name_sharing"] = value + } +} + +// MutableHashTableOfTensorsV2ValueShape sets the optional value_shape attribute to value. +// If not specified, defaults to <> +func MutableHashTableOfTensorsV2ValueShape(value tf.Shape) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["value_shape"] = value + } +} + +// Creates an empty hash table. +// +// This op creates a mutable hash table, specifying the type of its keys and +// values. Each value must be a vector. Data can be inserted into the table using +// the insert operations. It does not support the initialization operation. +// +// Arguments: +// key_dtype: Type of the table keys. +// value_dtype: Type of the table values. +// +// Returns Handle to a table. +func MutableHashTableOfTensorsV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableOfTensorsV2Attr) (table_handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MutableHashTableOfTensorsV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Table initializer that takes two tensors for keys and values respectively. +// +// Arguments: +// table_handle: Handle to a table which will be initialized. +// keys: Keys of type Tkey. +// values: Values of type Tval. +// +// Returns the created operation. +func InitializeTableV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "InitializeTableV2", + Input: []tf.Input{ + table_handle, keys, values, + }, + } + return scope.AddOperation(opspec) +} + +// FakeQuantWithMinMaxVarsGradientAttr is an optional argument to FakeQuantWithMinMaxVarsGradient. +type FakeQuantWithMinMaxVarsGradientAttr func(optionalAttr) + +// FakeQuantWithMinMaxVarsGradientNumBits sets the optional num_bits attribute to value. +// +// value: The bitwidth of the quantization; between 2 and 8, inclusive. +// If not specified, defaults to 8 +func FakeQuantWithMinMaxVarsGradientNumBits(value int64) FakeQuantWithMinMaxVarsGradientAttr { + return func(m optionalAttr) { + m["num_bits"] = value + } +} + +// Compute gradients for a FakeQuantWithMinMaxVars operation. +// +// Arguments: +// gradients: Backpropagated gradients above the FakeQuantWithMinMaxVars operation. +// inputs: Values passed as inputs to the FakeQuantWithMinMaxVars operation. +// min, max: Quantization interval, scalar floats. +// +// +// +// Returns Backpropagated gradients w.r.t. inputs: +// `gradients * (inputs >= min && inputs <= max)`.Backpropagated gradients w.r.t. min parameter: +// `sum(gradients * (inputs < min))`.Backpropagated gradients w.r.t. max parameter: +// `sum(gradients * (inputs > max))`. +func FakeQuantWithMinMaxVarsGradient(scope *Scope, gradients tf.Output, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsGradientAttr) (backprops_wrt_input tf.Output, backprop_wrt_min tf.Output, backprop_wrt_max tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "FakeQuantWithMinMaxVarsGradient", + Input: []tf.Input{ + gradients, inputs, min, max, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Returns the min of x and y (i.e. x < y ? x : y) element-wise. +// +// *NOTE*: `Minimum` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Minimum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Minimum", + Input: []tf.Input{ + x, y, + }, + } op := scope.AddOperation(opspec) return op.Output(0) } @@ -20385,6 +20902,84 @@ func TFRecordReaderV2(scope *Scope, optional ...TFRecordReaderV2Attr) (reader_ha return op.Output(0) } +// Adjust the saturation of one or more images. +// +// `images` is a tensor of at least 3 dimensions. The last dimension is +// interpretted as channels, and must be three. +// +// The input image is considered in the RGB colorspace. Conceptually, the RGB +// colors are first mapped into HSV. A scale is then applied all the saturation +// values, and then remapped back to RGB colorspace. +// +// Arguments: +// images: Images to adjust. At least 3-D. +// scale: A float scale to add to the saturation. +// +// Returns The hue-adjusted image or images. +func AdjustSaturation(scope *Scope, images tf.Output, scale tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "AdjustSaturation", + Input: []tf.Input{ + images, scale, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// SelfAdjointEigV2Attr is an optional argument to SelfAdjointEigV2. +type SelfAdjointEigV2Attr func(optionalAttr) + +// SelfAdjointEigV2ComputeV sets the optional compute_v attribute to value. +// +// value: If `True` then eigenvectors will be computed and returned in `v`. +// Otherwise, only the eigenvalues will be computed. +// If not specified, defaults to true +func SelfAdjointEigV2ComputeV(value bool) SelfAdjointEigV2Attr { + return func(m optionalAttr) { + m["compute_v"] = value + } +} + +// Computes the eigen decomposition of one or more square self-adjoint matrices. +// +// Computes the eigenvalues and (optionally) eigenvectors of each inner matrix in +// `input` such that `input[..., :, :] = v[..., :, :] * diag(e[..., :])`. +// +// ```prettyprint +// # a is a tensor. +// # e is a tensor of eigenvalues. +// # v is a tensor of eigenvectors. +// e, v = self_adjoint_eig(a) +// e = self_adjoint_eig(a, compute_v=False) +// ``` +// +// Arguments: +// input: `Tensor` input of shape `[N, N]`. +// +// Returns Eigenvalues. Shape is `[N]`.Eigenvectors. Shape is `[N, N]`. +func SelfAdjointEigV2(scope *Scope, input tf.Output, optional ...SelfAdjointEigV2Attr) (e tf.Output, v tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SelfAdjointEigV2", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + // MatrixSolveAttr is an optional argument to MatrixSolve. type MatrixSolveAttr func(optionalAttr) @@ -21033,371 +21628,6 @@ func SoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, labels tf.O return op.Output(0), op.Output(1) } -// Computes the number of elements in the given table. -// -// Arguments: -// table_handle: Handle to the table. -// -// Returns Scalar that contains number of elements in the table. -func LookupTableSizeV2(scope *Scope, table_handle tf.Output) (size tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LookupTableSizeV2", - Input: []tf.Input{ - table_handle, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResizeBilinearGradAttr is an optional argument to ResizeBilinearGrad. -type ResizeBilinearGradAttr func(optionalAttr) - -// ResizeBilinearGradAlignCorners sets the optional align_corners attribute to value. -// -// value: If true, rescale grads by (orig_height - 1) / (height - 1), which -// exactly aligns the 4 corners of grads and original_image. If false, rescale by -// orig_height / height. Treat similarly the width dimension. -// If not specified, defaults to false -func ResizeBilinearGradAlignCorners(value bool) ResizeBilinearGradAttr { - return func(m optionalAttr) { - m["align_corners"] = value - } -} - -// Computes the gradient of bilinear interpolation. -// -// Arguments: -// grads: 4-D with shape `[batch, height, width, channels]`. -// original_image: 4-D with shape `[batch, orig_height, orig_width, channels]`, -// The image tensor that was resized. -// -// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. -// Gradients with respect to the input image. Input image must have been -// float or double. -func ResizeBilinearGrad(scope *Scope, grads tf.Output, original_image tf.Output, optional ...ResizeBilinearGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResizeBilinearGrad", - Input: []tf.Input{ - grads, original_image, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResizeNearestNeighborGradAttr is an optional argument to ResizeNearestNeighborGrad. -type ResizeNearestNeighborGradAttr func(optionalAttr) - -// ResizeNearestNeighborGradAlignCorners sets the optional align_corners attribute to value. -// -// value: If true, rescale grads by (orig_height - 1) / (height - 1), which -// exactly aligns the 4 corners of grads and original_image. If false, rescale by -// orig_height / height. Treat similarly the width dimension. -// If not specified, defaults to false -func ResizeNearestNeighborGradAlignCorners(value bool) ResizeNearestNeighborGradAttr { - return func(m optionalAttr) { - m["align_corners"] = value - } -} - -// Computes the gradient of nearest neighbor interpolation. -// -// Arguments: -// grads: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `orig_height, orig_width`. The -// original input size. -// -// Returns 4-D with shape `[batch, orig_height, orig_width, channels]`. Gradients -// with respect to the input image. -func ResizeNearestNeighborGrad(scope *Scope, grads tf.Output, size tf.Output, optional ...ResizeNearestNeighborGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResizeNearestNeighborGrad", - Input: []tf.Input{ - grads, size, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// DecodeJpegAttr is an optional argument to DecodeJpeg. -type DecodeJpegAttr func(optionalAttr) - -// DecodeJpegChannels sets the optional channels attribute to value. -// -// value: Number of color channels for the decoded image. -// If not specified, defaults to 0 -func DecodeJpegChannels(value int64) DecodeJpegAttr { - return func(m optionalAttr) { - m["channels"] = value - } -} - -// DecodeJpegRatio sets the optional ratio attribute to value. -// -// value: Downscaling ratio. -// If not specified, defaults to 1 -func DecodeJpegRatio(value int64) DecodeJpegAttr { - return func(m optionalAttr) { - m["ratio"] = value - } -} - -// DecodeJpegFancyUpscaling sets the optional fancy_upscaling attribute to value. -// -// value: If true use a slower but nicer upscaling of the -// chroma planes (yuv420/422 only). -// If not specified, defaults to true -func DecodeJpegFancyUpscaling(value bool) DecodeJpegAttr { - return func(m optionalAttr) { - m["fancy_upscaling"] = value - } -} - -// DecodeJpegTryRecoverTruncated sets the optional try_recover_truncated attribute to value. -// -// value: If true try to recover an image from truncated input. -// If not specified, defaults to false -func DecodeJpegTryRecoverTruncated(value bool) DecodeJpegAttr { - return func(m optionalAttr) { - m["try_recover_truncated"] = value - } -} - -// DecodeJpegAcceptableFraction sets the optional acceptable_fraction attribute to value. -// -// value: The minimum required fraction of lines before a truncated -// input is accepted. -// If not specified, defaults to 1 -func DecodeJpegAcceptableFraction(value float32) DecodeJpegAttr { - return func(m optionalAttr) { - m["acceptable_fraction"] = value - } -} - -// DecodeJpegDctMethod sets the optional dct_method attribute to value. -// -// value: string specifying a hint about the algorithm used for -// decompression. Defaults to "" which maps to a system-specific -// default. Currently valid values are ["INTEGER_FAST", -// "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal -// jpeg library changes to a version that does not have that specific -// option.) -// If not specified, defaults to "" -func DecodeJpegDctMethod(value string) DecodeJpegAttr { - return func(m optionalAttr) { - m["dct_method"] = value - } -} - -// Decode a JPEG-encoded image to a uint8 tensor. -// -// The attr `channels` indicates the desired number of color channels for the -// decoded image. -// -// Accepted values are: -// -// * 0: Use the number of channels in the JPEG-encoded image. -// * 1: output a grayscale image. -// * 3: output an RGB image. -// -// If needed, the JPEG-encoded image is transformed to match the requested number -// of color channels. -// -// The attr `ratio` allows downscaling the image by an integer factor during -// decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than -// downscaling the image later. -// -// Arguments: -// contents: 0-D. The JPEG-encoded image. -// -// Returns 3-D with shape `[height, width, channels]`.. -func DecodeJpeg(scope *Scope, contents tf.Output, optional ...DecodeJpegAttr) (image tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DecodeJpeg", - Input: []tf.Input{ - contents, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// TensorArrayGatherV3Attr is an optional argument to TensorArrayGatherV3. -type TensorArrayGatherV3Attr func(optionalAttr) - -// TensorArrayGatherV3ElementShape sets the optional element_shape attribute to value. -// -// value: The expected shape of an element, if known. Used to -// validate the shapes of TensorArray elements. If this shape is not -// fully specified, gathering zero-size TensorArrays is an error. -// If not specified, defaults to -func TensorArrayGatherV3ElementShape(value tf.Shape) TensorArrayGatherV3Attr { - return func(m optionalAttr) { - m["element_shape"] = value - } -} - -// Gather specific elements from the TensorArray into output `value`. -// -// All elements selected by `indices` must have the same shape. -// -// Arguments: -// handle: The handle to a TensorArray. -// indices: The locations in the TensorArray from which to read tensor elements. -// flow_in: A float scalar that enforces proper chaining of operations. -// dtype: The type of the elem that is returned. -// -// Returns All of the elements in the TensorArray, concatenated along a new -// axis (the new dimension 0). -func TensorArrayGatherV3(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV3Attr) (value tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "TensorArrayGatherV3", - Input: []tf.Input{ - handle, indices, flow_in, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MaxPoolGradGradAttr is an optional argument to MaxPoolGradGrad. -type MaxPoolGradGradAttr func(optionalAttr) - -// MaxPoolGradGradDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func MaxPoolGradGradDataFormat(value string) MaxPoolGradGradAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Computes second-order gradients of the maxpooling function. -// -// Arguments: -// orig_input: The original input tensor. -// orig_output: The original output tensor. -// grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. -// -// Returns Gradients of gradients w.r.t. the input to `max_pool`. -func MaxPoolGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MaxPoolGradGrad", - Input: []tf.Input{ - orig_input, orig_output, grad, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// 3D real-valued fast Fourier transform. -// -// Computes the 3-dimensional discrete Fourier transform of a real-valued signal -// over the inner-most 3 dimensions of `input`. -// -// Since the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the -// `fft_length / 2 + 1` unique components of the FFT for the inner-most dimension -// of `output`: the zero-frequency term, followed by the `fft_length / 2` -// positive-frequency terms. -// -// Arguments: -// input: A float32 tensor. -// fft_length: An int32 tensor of shape [3]. The FFT length for each dimension. -// -// Returns A complex64 tensor of the same rank as `input`. The inner-most 3 -// dimensions of `input` are replaced with the their 3D Fourier transform. The -// inner-most dimension contains `fft_length / 2 + 1` unique frequency -// components. -// -// @compatibility(numpy) -// Equivalent to np.fft.rfftn with 3 dimensions. -// @end_compatibility -func RFFT3D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "RFFT3D", - Input: []tf.Input{ - input, fft_length, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Deprecated. Disallowed in GraphDef version >= 2. -// -// DEPRECATED at GraphDef version 2: Use AdjustContrastv2 instead -func AdjustContrast(scope *Scope, images tf.Output, contrast_factor tf.Output, min_value tf.Output, max_value tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "AdjustContrast", - Input: []tf.Input{ - images, contrast_factor, min_value, max_value, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Store the input tensor in the state of the current session. // // Arguments: @@ -21419,25 +21649,6 @@ func GetSessionHandleV2(scope *Scope, value tf.Output) (handle tf.Output) { return op.Output(0) } -// Restore a Reader to its initial clean state. -// -// Arguments: -// reader_handle: Handle to a Reader. -// -// Returns the created operation. -func ReaderResetV2(scope *Scope, reader_handle tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ReaderResetV2", - Input: []tf.Input{ - reader_handle, - }, - } - return scope.AddOperation(opspec) -} - // Adjust the hue of one or more images. // // `images` is a tensor of at least 3 dimensions. The last dimension is @@ -21466,232 +21677,21 @@ func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Outpu return op.Output(0) } -// SelfAdjointEigV2Attr is an optional argument to SelfAdjointEigV2. -type SelfAdjointEigV2Attr func(optionalAttr) - -// SelfAdjointEigV2ComputeV sets the optional compute_v attribute to value. -// -// value: If `True` then eigenvectors will be computed and returned in `v`. -// Otherwise, only the eigenvalues will be computed. -// If not specified, defaults to true -func SelfAdjointEigV2ComputeV(value bool) SelfAdjointEigV2Attr { - return func(m optionalAttr) { - m["compute_v"] = value - } -} - -// Computes the eigen decomposition of one or more square self-adjoint matrices. -// -// Computes the eigenvalues and (optionally) eigenvectors of each inner matrix in -// `input` such that `input[..., :, :] = v[..., :, :] * diag(e[..., :])`. -// -// ```prettyprint -// # a is a tensor. -// # e is a tensor of eigenvalues. -// # v is a tensor of eigenvectors. -// e, v = self_adjoint_eig(a) -// e = self_adjoint_eig(a, compute_v=False) -// ``` +// Restore a Reader to its initial clean state. // // Arguments: -// input: `Tensor` input of shape `[N, N]`. +// reader_handle: Handle to a Reader. // -// Returns Eigenvalues. Shape is `[N]`.Eigenvectors. Shape is `[N, N]`. -func SelfAdjointEigV2(scope *Scope, input tf.Output, optional ...SelfAdjointEigV2Attr) (e tf.Output, v tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "SelfAdjointEigV2", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Adjust the saturation of one or more images. -// -// `images` is a tensor of at least 3 dimensions. The last dimension is -// interpretted as channels, and must be three. -// -// The input image is considered in the RGB colorspace. Conceptually, the RGB -// colors are first mapped into HSV. A scale is then applied all the saturation -// values, and then remapped back to RGB colorspace. -// -// Arguments: -// images: Images to adjust. At least 3-D. -// scale: A float scale to add to the saturation. -// -// Returns The hue-adjusted image or images. -func AdjustSaturation(scope *Scope, images tf.Output, scale tf.Output) (output tf.Output) { +// Returns the created operation. +func ReaderResetV2(scope *Scope, reader_handle tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "AdjustSaturation", + Type: "ReaderResetV2", Input: []tf.Input{ - images, scale, + reader_handle, }, } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// EncodePngAttr is an optional argument to EncodePng. -type EncodePngAttr func(optionalAttr) - -// EncodePngCompression sets the optional compression attribute to value. -// -// value: Compression level. -// If not specified, defaults to -1 -func EncodePngCompression(value int64) EncodePngAttr { - return func(m optionalAttr) { - m["compression"] = value - } -} - -// PNG-encode an image. -// -// `image` is a 3-D uint8 or uint16 Tensor of shape `[height, width, channels]` -// where `channels` is: -// -// * 1: for grayscale. -// * 2: for grayscale + alpha. -// * 3: for RGB. -// * 4: for RGBA. -// -// The ZLIB compression level, `compression`, can be -1 for the PNG-encoder -// default or a value from 0 to 9. 9 is the highest compression level, generating -// the smallest output, but is slower. -// -// Arguments: -// image: 3-D with shape `[height, width, channels]`. -// -// Returns 0-D. PNG-encoded image. -func EncodePng(scope *Scope, image tf.Output, optional ...EncodePngAttr) (contents tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "EncodePng", - Input: []tf.Input{ - image, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MatrixSolveLsAttr is an optional argument to MatrixSolveLs. -type MatrixSolveLsAttr func(optionalAttr) - -// MatrixSolveLsFast sets the optional fast attribute to value. -// If not specified, defaults to true -func MatrixSolveLsFast(value bool) MatrixSolveLsAttr { - return func(m optionalAttr) { - m["fast"] = value - } -} - -// Solves one or more linear least-squares problems. -// -// `matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions -// form matrices of size `[M, N]`. Rhs is a tensor of shape `[..., M, K]`. -// The output is a tensor shape `[..., N, K]` where each output matrix solves -// each of the equations matrix[..., :, :] * output[..., :, :] = rhs[..., :, :] -// in the least squares sense. -// -// matrix and right-hand sides in the batch: -// -// `matrix`=\\(A \in \Re^{m \times n}\\), -// `rhs`=\\(B \in \Re^{m \times k}\\), -// `output`=\\(X \in \Re^{n \times k}\\), -// `l2_regularizer`=\\(\lambda\\). -// -// If `fast` is `True`, then the solution is computed by solving the normal -// equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then -// \\(X = (A^T A + \lambda I)^{-1} A^T B\\), which solves the least-squares -// problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||A Z - B||_F^2 + -// \lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as -// \\(X = A^T (A A^T + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the -// minimum-norm solution to the under-determined linear system, i.e. -// \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||Z||_F^2 \\), subject to -// \\(A Z = B\\). Notice that the fast path is only numerically stable when -// \\(A\\) is numerically full rank and has a condition number -// \\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach} } }\\) or\\(\lambda\\) is -// sufficiently large. -// -// If `fast` is `False` an algorithm based on the numerically robust complete -// orthogonal decomposition is used. This computes the minimum-norm -// least-squares solution, even when \\(A\\) is rank deficient. This path is -// typically 6-7 times slower than the fast path. If `fast` is `False` then -// `l2_regularizer` is ignored. -// -// Arguments: -// matrix: Shape is `[..., M, N]`. -// rhs: Shape is `[..., M, K]`. -// l2_regularizer: Scalar tensor. -// -// @compatibility(numpy) -// Equivalent to np.linalg.lstsq -// @end_compatibility -// -// Returns Shape is `[..., N, K]`. -func MatrixSolveLs(scope *Scope, matrix tf.Output, rhs tf.Output, l2_regularizer tf.Output, optional ...MatrixSolveLsAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MatrixSolveLs", - Input: []tf.Input{ - matrix, rhs, l2_regularizer, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Converts one or more images from RGB to HSV. -// -// Outputs a tensor of the same shape as the `images` tensor, containing the HSV -// value of the pixels. The output is only well defined if the value in `images` -// are in `[0,1]`. -// -// `output[..., 0]` contains hue, `output[..., 1]` contains saturation, and -// `output[..., 2]` contains value. All HSV values are in `[0,1]`. A hue of 0 -// corresponds to pure red, hue 1/3 is pure green, and 2/3 is pure blue. -// -// Arguments: -// images: 1-D or higher rank. RGB data to convert. Last dimension must be size 3. -// -// Returns `images` converted to HSV. -func RGBToHSV(scope *Scope, images tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "RGBToHSV", - Input: []tf.Input{ - images, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } From 9b336b4a33158061535fd6ba4973605248055b69 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Thu, 4 May 2017 13:05:05 -0800 Subject: [PATCH 29/43] Open sourced op level cost prediction Change: 155123817 --- tensorflow/core/grappler/costs/BUILD | 27 + .../grappler/costs/op_level_cost_estimator.cc | 554 ++++++++++++++++++ .../grappler/costs/op_level_cost_estimator.h | 143 +++++ .../costs/op_level_cost_estimator_test.cc | 113 ++++ tensorflow/core/grappler/costs/utils.cc | 4 +- 5 files changed, 840 insertions(+), 1 deletion(-) create mode 100644 tensorflow/core/grappler/costs/op_level_cost_estimator.cc create mode 100644 tensorflow/core/grappler/costs/op_level_cost_estimator.h create mode 100644 tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index 22f4708d032..372092f42a9 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -111,6 +111,7 @@ cc_library( name = "utils", srcs = ["utils.cc"], hdrs = ["utils.h"], + defines = if_cuda(["GOOGLE_CUDA=1"]), visibility = ["//visibility:public"], deps = [ ":op_performance_data_cc", @@ -167,3 +168,29 @@ cc_library( "//tensorflow/core/kernels:ops_util", ], ) + +cc_library( + name = "op_level_cost_estimator", + srcs = ["op_level_cost_estimator.cc"], + hdrs = ["op_level_cost_estimator.h"], + visibility = ["//visibility:public"], + deps = [ + ":cost_estimator", + ":op_performance_data_cc", + ":utils", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + ], +) + +cc_test( + name = "op_level_cost_estimator_test", + srcs = ["op_level_cost_estimator_test.cc"], + deps = [ + ":op_level_cost_estimator", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc new file mode 100644 index 00000000000..baed7a88997 --- /dev/null +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -0,0 +1,554 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/grappler/costs/utils.h" + +namespace tensorflow { +namespace grappler { + +constexpr int kOpsPerMac = 2; +constexpr char kConv2d[] = "Conv2D"; +constexpr char kConv2dBackPropFilter[] = "Conv2DBackpropFilter"; +constexpr char kConv2dBackPropInput[] = "Conv2DBackpropInput"; +constexpr char kMatMul[] = "MatMul"; +constexpr char kSparseMatMul[] = "SparseMatMul"; +constexpr char kIdentity[] = "Identity"; +constexpr char kNoOp[] = "NoOp"; +constexpr char kReshape[] = "Reshape"; + +OpLevelCostEstimator::OpLevelCostEstimator() { + // Syntactic sugar to build and return a lambda that takes an OpInfo and + // returns a cost. + typedef Costs (OpLevelCostEstimator::*CostImpl)(const OpInfo& op_feature) + const; + auto wrap = [this](CostImpl impl) -> std::function { + return [this, impl](const OpInfo& op) { return (this->*impl)(op); }; + }; + + device_cost_impl_ = { + {kConv2d, wrap(&OpLevelCostEstimator::PredictConv2D)}, + {kConv2dBackPropFilter, + wrap(&OpLevelCostEstimator::PredictConv2DBackPropFilter)}, + {kConv2dBackPropInput, + wrap(&OpLevelCostEstimator::PredictConv2DBackPropInput)}, + {kMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)}, + {kSparseMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)}, + {kIdentity, wrap(&OpLevelCostEstimator::PredictNoOp)}, + {kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)}, + {kReshape, wrap(&OpLevelCostEstimator::PredictNoOp)}}; +} + +Costs OpLevelCostEstimator::PredictCosts(const OpInfo& op_features) const { + auto it = device_cost_impl_.find(op_features.op()); + if (it == device_cost_impl_.end()) { + VLOG(1) << "Missing implementation for op: " << op_features.op(); + Costs costs; + costs = DummyExecutionTime(op_features); + return costs; + } + + std::function estimator = it->second; + Costs costs = estimator(op_features); + VLOG(1) << "Operation " << op_features.op() << " takes " + << costs.execution_time.count() << " ns."; + return costs; +} + +std::pair OpLevelCostEstimator::GetDeviceInfo( + const OpInfo::DeviceProperties& device) const { + double gflops = -1; + double bandwidth = -1; + if (device.bandwidth() > 0) { + bandwidth = device.bandwidth() / 1e6; + } + + if (device.type() == "CPU") { + const OpInfo::DeviceProperties local_cpu = GetLocalCPUInfo(); + // Check if vector instructions are available, and refine performance + // prediction based on this. + gflops = local_cpu.num_cores() * local_cpu.frequency(); + if (bandwidth < 0) { + if (local_cpu.bandwidth() > 0) { + bandwidth = local_cpu.bandwidth() / 1e6; + } else { + bandwidth = 32; + } + } + } else if (device.type() == "GPU") { + const OpInfo::DeviceProperties local_gpu = GetLocalGPUInfo(0); + const string architecture = local_gpu.environment().at("architecture"); + int cores_per_multiprocessor; + if (architecture < "3") { + // Fermi + cores_per_multiprocessor = 32; + } else if (architecture < "4") { + // Kepler + cores_per_multiprocessor = 192; + } else if (architecture < "6") { + // Maxwell + cores_per_multiprocessor = 128; + } else { + // Pascal. + cores_per_multiprocessor = 64; + } + gflops = local_gpu.num_cores() * local_gpu.frequency() * + cores_per_multiprocessor * kOpsPerMac; + if (bandwidth < 0) { + CHECK(local_gpu.bandwidth() > 0); + bandwidth = local_gpu.bandwidth() / 1e6; + } + } + + return std::make_pair(gflops, bandwidth); +} + +Costs OpLevelCostEstimator::DummyExecutionTime( + const OpInfo& op_features) const { + Costs costs = PredictOpCountBasedCost(0, op_features); + costs.inaccurate = true; + return costs; +} + +Costs OpLevelCostEstimator::PredictOpCountBasedCost( + double operations, const OpInfo& op_features) const { + std::pair device_perf = GetDeviceInfo(op_features.device()); + Costs::NanoSeconds compute_cost(operations / device_perf.first); + VLOG(1) << "Op:" << op_features.op() << " GOps:" << operations / 1e9 + << " Execution Time (ns):" << compute_cost.count(); + + bool found_unknown_shapes = false; + double total_input_size = + CalculateInputSize(op_features, &found_unknown_shapes); + double total_output_size = + CalculateOutputSize(op_features, &found_unknown_shapes); + double total_io_size = total_input_size + total_output_size; + + Costs::NanoSeconds memory_cost(total_io_size / device_perf.second); + VLOG(1) << "Op:" << op_features.op() << " Size (KB):" << (total_io_size) / 1e3 + << " Memory Time (ns):" << memory_cost.count(); + + Costs costs; + costs.compute_time = compute_cost; + costs.memory_time = memory_cost; + costs.execution_time = compute_cost + memory_cost; + costs.inaccurate = found_unknown_shapes; + return costs; +} + +int64 OpLevelCostEstimator::CountConv2DOperations( + const OpInfo& op_features, bool* found_unknown_shapes) const { + return CountConv2DOperations(op_features, nullptr, found_unknown_shapes); +} + +namespace { + +string GetDataFormat(const OpInfo& op_features) { + string data_format = "NHWC"; // Default format. + if (op_features.attr().find("data_format") != op_features.attr().end()) { + data_format = op_features.attr().at("data_format").s(); + } + return data_format; +} + +Padding GetPadding(const OpInfo& op_features) { + if (op_features.attr().find("padding") != op_features.attr().end() && + op_features.attr().at("padding").s() == "VALID") { + return Padding::VALID; + } + return Padding::SAME; // Default padding. +} + +std::vector GetStrides(const OpInfo& op_features) { + if (op_features.attr().find("strides") != op_features.attr().end()) { + const auto strides = op_features.attr().at("strides").list().i(); + return {strides[0], strides[1], strides[2], strides[3]}; + } + return {1, 1, 1, 1}; +} + +int64 GetOutputSize(const int64 input, const int64 filter, const int64 stride, + const Padding& padding) { + // Logic for calculating output shape is from GetWindowedOutputSizeVerbose() + // function in third_party/tensorflow/core/framework/common_shape_fns.cc. + if (padding == Padding::VALID) { + return (input - filter + stride) / stride; + } else { // SAME. + return (input + stride - 1) / stride; + } +} + +// Return a minimum shape if the shape is unknown. If known, return the original +// shape. +TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape, + int rank, bool* found_unknown_shapes) { + auto shape = original_shape; + if (shape.unknown_rank()) { + *found_unknown_shapes = true; + } + if (shape.unknown_rank() || shape.dim_size() == 0) { + TensorShapeProto::Dim dim; + VLOG(1) << "WARNING: Use minimum shape because the shape is unknown."; + // The size of each dimension is at least 1, if unknown. + dim.set_size(1); + for (int i = 0; i < rank; i++) { + *shape.add_dim() = dim; + } + } else { + CHECK_EQ(shape.dim_size(), rank); + for (int i = 0; i < rank; i++) { + if (shape.dim(i).size() == -1) { + *found_unknown_shapes = true; + VLOG(1) + << "WARNING: Use minimum dim size 1 because the shape is unknown."; + // The size of each dimension is at least 1, if unknown. + shape.mutable_dim(i)->set_size(1); + } + } + } + return shape; +} +} // namespace + +// Helper to translate the positional arguments into named fields. +OpLevelCostEstimator::ConvolutionDimensions +OpLevelCostEstimator::ConvolutionDimensionsFromInputs( + const TensorShapeProto& original_image_shape, + const TensorShapeProto& original_filter_shape, const OpInfo& op_features, + bool* found_unknown_shapes) { + auto image_shape = + MaybeGetMinimumShape(original_image_shape, 4, found_unknown_shapes); + auto filter_shape = + MaybeGetMinimumShape(original_filter_shape, 4, found_unknown_shapes); + + int x_index, y_index, channel_index; + const string& data_format = GetDataFormat(op_features); + if (data_format == "NCHW") { + x_index = 2; + y_index = 3; + channel_index = 1; + } else { + x_index = 1; + y_index = 2; + channel_index = 3; + } + int64 batch = image_shape.dim(0).size(); + int64 ix = image_shape.dim(x_index).size(); + int64 iy = image_shape.dim(y_index).size(); + int64 iz = image_shape.dim(channel_index).size(); + int64 kx = filter_shape.dim(0).size(); + int64 ky = filter_shape.dim(1).size(); + std::vector strides = GetStrides(op_features); + const auto padding = GetPadding(op_features); + int64 sx = strides[x_index]; + int64 sy = strides[y_index]; + int64 ox = GetOutputSize(ix, kx, sx, padding); + int64 oy = GetOutputSize(iy, ky, sy, padding); + int64 oz = filter_shape.dim(3).size(); + // Only check equality when both sizes are known (in other words, when + // neither is set to a minimum dimension size of 1). + if (iz != 1 && filter_shape.dim(2).size() != 1) { + CHECK_EQ(iz, filter_shape.dim(2).size()); + } else { + iz = std::max(iz, filter_shape.dim(2).size()); + } + OpLevelCostEstimator::ConvolutionDimensions conv_dims = { + batch, ix, iy, iz, kx, ky, oz, ox, oy, sx, sy, padding}; + + VLOG(1) << "Batch Size:" << batch; + VLOG(1) << "Image Dims:" << ix << "," << iy; + VLOG(1) << "Input Features:" << iz; + VLOG(1) << "Kernel Dims:" << kx << "," << ky; + VLOG(1) << "Output Features:" << oz; + VLOG(1) << "Output Dims:" << ox << "," << oy; + VLOG(1) << "Strides:" << sx << "," << sy; + VLOG(1) << "Padding:" << (padding == Padding::VALID ? "VALID" : "SAME"); + return conv_dims; +} + +int64 OpLevelCostEstimator::CountConv2DOperations( + const OpInfo& op_features, ConvolutionDimensions* conv_info, + bool* found_unknown_shapes) const { + if (op_features.op() != kConv2d) { + LOG(ERROR) << "Invalid Operation"; + return 0; + } + ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs( + op_features.inputs(0).shape(), op_features.inputs(1).shape(), op_features, + found_unknown_shapes); + + int64 ops = conv_dims.batch; + ops *= conv_dims.ox * conv_dims.oy; + ops *= conv_dims.kx * conv_dims.ky; + ops *= conv_dims.iz * conv_dims.oz; + ops *= kOpsPerMac; + VLOG(1) << "Operations for Conv2D" << ops; + + if (conv_info != nullptr) { + *conv_info = conv_dims; + } + return ops; +} + +int64 OpLevelCostEstimator::CountMatMulOperations( + const OpInfo& op_features, bool* found_unknown_shapes) const { + return CountMatMulOperations(op_features, nullptr, found_unknown_shapes); +} + +int64 OpLevelCostEstimator::CountMatMulOperations( + const OpInfo& op_features, MatMulDimensions* mat_mul, + bool* found_unknown_shapes) const { + double ops = 0; + + // TODO(nishantpatil): Create separate estimator for Sparse Matmul + if ((op_features.op() != kMatMul) && (op_features.op() != kSparseMatMul)) { + LOG(ERROR) << "Invalid Operation"; + return ops; + } + + // first matrix + auto& a_matrix = op_features.inputs(0); + auto& b_matrix = op_features.inputs(1); + + bool transpose_a = false; + bool transpose_b = false; + + double m_dim, n_dim, k_dim, k_dim_b = 0; + + for (const auto& item : op_features.attr()) { + VLOG(1) << "Key:" << item.first + << " Value:" << SummarizeAttrValue(item.second); + if (item.first == "transpose_a" && item.second.b() == true) + transpose_a = true; + if (item.first == "transpose_b" && item.second.b() == true) + transpose_b = true; + } + VLOG(1) << "transpose_a:" << transpose_a; + VLOG(1) << "transpose_b:" << transpose_b; + auto a_matrix_shape = + MaybeGetMinimumShape(a_matrix.shape(), 2, found_unknown_shapes); + auto b_matrix_shape = + MaybeGetMinimumShape(b_matrix.shape(), 2, found_unknown_shapes); + if (transpose_a) { + m_dim = a_matrix_shape.dim(1).size(); + k_dim = a_matrix_shape.dim(0).size(); + } else { + m_dim = a_matrix_shape.dim(0).size(); + k_dim = a_matrix_shape.dim(1).size(); + } + if (transpose_b) { + k_dim_b = b_matrix_shape.dim(1).size(); + n_dim = b_matrix_shape.dim(0).size(); + } else { + k_dim_b = b_matrix_shape.dim(0).size(); + n_dim = b_matrix_shape.dim(1).size(); + } + + VLOG(1) << "M, N, K: " << m_dim << "," << n_dim << "," << k_dim; + // Only check equality when both sizes are known (in other words, when + // neither is set to a minimum dimension size of 1). + if (k_dim_b != 1 && k_dim != 1 && k_dim_b != k_dim) { + LOG(ERROR) << "Incompatible Matrix dimensions"; + return ops; + } else { + // One of k_dim and k_dim_b might be 1 (mininum dimension size). + k_dim = std::max(k_dim, k_dim_b); + } + + ops = m_dim * n_dim * k_dim * 2; + VLOG(1) << "Operations for Matmul" << ops; + + if (mat_mul != nullptr) { + mat_mul->m = m_dim; + mat_mul->n = n_dim; + mat_mul->k = k_dim; + } + return ops; +} + +// TODO(cliffy): Dedup this method and CountConv2DBackPropFilterOperations. +int64 OpLevelCostEstimator::CountConv2DBackPropInputOperations( + const OpInfo& op_features, ConvolutionDimensions* returned_conv_dims, + bool* found_unknown_shapes) const { + int64 ops = 0; + + if (op_features.op() != kConv2dBackPropInput) { + LOG(ERROR) << "Invalid Operation"; + return ops; + } + + if (op_features.attr().find("_output_shapes") == op_features.attr().end()) { + // Need _output_shapes for input shape. + LOG(ERROR) << "No output shape in Conv2DBackPropInput op feaure."; + return ops; + } + + const auto& input_shape = + op_features.attr().at("_output_shapes").list().shape(0); + ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs( + input_shape, op_features.inputs(1).shape(), op_features, + found_unknown_shapes); + + ops = conv_dims.batch; + ops *= conv_dims.ox * conv_dims.oy; + ops *= conv_dims.kx * conv_dims.ky; + ops *= conv_dims.iz * conv_dims.oz; + ops *= kOpsPerMac; + + VLOG(1) << "Operations for Conv2DBackPropInput" << ops; + + if (returned_conv_dims != nullptr) { + *returned_conv_dims = conv_dims; + } + return ops; +} + +int64 OpLevelCostEstimator::CountConv2DBackPropFilterOperations( + const OpInfo& op_features, ConvolutionDimensions* returned_conv_dims, + bool* found_unknown_shapes) const { + int64 ops = 0; + if (op_features.op() != kConv2dBackPropFilter) { + LOG(ERROR) << "Invalid Operation"; + return ops; + } + + if (op_features.attr().find("_output_shapes") == op_features.attr().end()) { + // Need _output_shapes for filter shape. + LOG(ERROR) << "No output shape in Conv2DBackPropFilter op feaure."; + return ops; + } + + const auto& filter_shape = + op_features.attr().at("_output_shapes").list().shape(0); + ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs( + op_features.inputs(0).shape(), filter_shape, op_features, + found_unknown_shapes); + + ops = conv_dims.batch; + ops *= conv_dims.ox * conv_dims.oy; + ops *= conv_dims.kx * conv_dims.ky; + ops *= conv_dims.iz * conv_dims.oz; + ops *= kOpsPerMac; + + VLOG(1) << "Operations for Conv2DBackPropFilter" << ops; + + if (returned_conv_dims != nullptr) { + *returned_conv_dims = conv_dims; + } + return ops; +} + +int64 OpLevelCostEstimator::CalculateSingleInputSize( + const OpInfo::TensorProperties& input, bool* found_unknown_shapes) const { + VLOG(1) << " with " << input.dtype() << " input of shape " + << input.shape().DebugString(); + int64 input_size = 1; + int num_dims = std::max(1, input.shape().dim_size()); + auto input_shape = + MaybeGetMinimumShape(input.shape(), num_dims, found_unknown_shapes); + for (const auto& dim : input_shape.dim()) { + input_size *= dim.size(); + } + return input_size * DataTypeSize(input.dtype()); +} + +int64 OpLevelCostEstimator::CalculateInputSize( + const OpInfo& op_features, bool* found_unknown_shapes) const { + int64 total_input_size = 0; + for (auto& input : op_features.inputs()) { + int64 input_size = CalculateSingleInputSize(input, found_unknown_shapes); + total_input_size += input_size; + VLOG(1) << "Input Size: " << input_size + << " Total Input Size:" << total_input_size; + } + return total_input_size; +} + +int64 OpLevelCostEstimator::CalculateOutputSize( + const OpInfo& op_features, bool* found_unknown_shapes) const { + int64 total_output_size = 0; + // use float as default for calculations + DataType dt = DT_FLOAT; + for (const auto& item : op_features.attr()) { + VLOG(1) << "Key:" << item.first + << " Value:" << SummarizeAttrValue(item.second); + if (item.first == "_output_shapes") { + for (const auto& original_output_shape : item.second.list().shape()) { + int64 output_size = 1; + int num_dims = std::max(1, original_output_shape.dim_size()); + auto output_shape = MaybeGetMinimumShape( + original_output_shape, num_dims, found_unknown_shapes); + for (const auto& dim : output_shape.dim()) { + output_size *= dim.size(); + } + output_size *= DataTypeSize(dt); + total_output_size += output_size; + VLOG(1) << "Output Size: " << output_size + << " Total Output Size:" << total_output_size; + } + } + if (item.first == "T") { + dt = item.second.type(); + } + } + return total_output_size; +} + +Costs OpLevelCostEstimator::PredictConv2D(const OpInfo& op_features) const { + bool found_unknown_shapes = false; + auto costs = PredictOpCountBasedCost( + CountConv2DOperations(op_features, &found_unknown_shapes), op_features); + costs.inaccurate = found_unknown_shapes; + return costs; +} + +Costs OpLevelCostEstimator::PredictConv2DBackPropInput( + const OpInfo& op_features) const { + bool found_unknown_shapes = false; + auto costs = + PredictOpCountBasedCost(CountConv2DBackPropInputOperations( + op_features, nullptr, &found_unknown_shapes), + op_features); + costs.inaccurate = found_unknown_shapes; + return costs; +} + +Costs OpLevelCostEstimator::PredictConv2DBackPropFilter( + const OpInfo& op_features) const { + bool found_unknown_shapes = false; + auto costs = + PredictOpCountBasedCost(CountConv2DBackPropFilterOperations( + op_features, nullptr, &found_unknown_shapes), + op_features); + costs.inaccurate = found_unknown_shapes; + return costs; +} + +Costs OpLevelCostEstimator::PredictMatMul(const OpInfo& op_features) const { + bool found_unknown_shapes = false; + auto costs = PredictOpCountBasedCost( + CountMatMulOperations(op_features, &found_unknown_shapes), op_features); + costs.inaccurate = found_unknown_shapes; + return costs; +} + +Costs OpLevelCostEstimator::PredictNoOp(const OpInfo& op_features) const { + VLOG(1) << "Op:" << op_features.op() << " Execution Time 0 (ns)"; + return Costs::ZeroCosts(); +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h new file mode 100644 index 00000000000..5bb20cc6bbf --- /dev/null +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h @@ -0,0 +1,143 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_OP_LEVEL_COST_ESTIMATOR_H_ +#define TENSORFLOW_CORE_GRAPPLER_COSTS_OP_LEVEL_COST_ESTIMATOR_H_ + +#include +#include +#include + +#include "tensorflow/core/graph/types.h" +#include "tensorflow/core/grappler/costs/cost_estimator.h" +#include "tensorflow/core/grappler/costs/op_performance_data.pb.h" +#include "tensorflow/core/util/padding.h" + +namespace tensorflow { +namespace grappler { + +class OpLevelCostEstimator { + public: + OpLevelCostEstimator(); + virtual ~OpLevelCostEstimator() {} + + Costs PredictCosts(const OpInfo& op_features) const; + + protected: + // Returns an estimate of device performance (in billions of operations + // executed per second) and memory bandwith (in GigaBytes/second) for the + // specified device. + virtual std::pair GetDeviceInfo( + const OpInfo::DeviceProperties& device) const; + + // For operations for which we haven't yet built estimates, returns a dummy + // value based on input size. + Costs DummyExecutionTime(const OpInfo& op_features) const; + + // Naive cost estimate based on operations divided by device ops/sec. + Costs PredictOpCountBasedCost(double operations, + const OpInfo& op_features) const; + + // This family of routines counts the number of operations to perform the + // specified TensorFlow Op. + struct MatMulDimensions { + int m; + int n; + int k; + }; + struct ConvolutionDimensions { + int64 batch; // Batch size. + int64 ix; // Input size x. + int64 iy; // Input size y. + int64 iz; // Input depth. + int64 kx; // Kernel x. + int64 ky; // Kernel y. + int64 oz; // Output depth. + int64 ox; // Output size x. + int64 oy; // Output size y. + int64 sx; // Stride x. + int64 sy; // Stride y. + Padding padding; // SAME or VALID. + }; + int64 CountConv2DOperations(const OpInfo& op_features, + bool* found_unknown_shapes) const; + int64 CountConv2DOperations(const OpInfo& op_features, + ConvolutionDimensions* conv_info, + bool* found_unknown_shapes) const; + int64 CountMatMulOperations(const OpInfo& op_features, + bool* found_unknown_shapes) const; + int64 CountMatMulOperations(const OpInfo& op_features, + MatMulDimensions* mat_mul, + bool* found_unknown_shapes) const; + int64 CountConv2DBackPropInputOperations(const OpInfo& op_features, + ConvolutionDimensions* conv_info, + bool* found_unknown_shapes) const; + int64 CountConv2DBackPropFilterOperations(const OpInfo& op_features, + ConvolutionDimensions* conv_info, + bool* found_unknown_shapes) const; + + // Calculate the total size in bytes of a single input to a TensorFlow op. + int64 CalculateSingleInputSize(const OpInfo::TensorProperties& input, + bool* found_unknown_shapes) const; + + // Calculate the total size in bytes of the all + // the inputs of specified TensorFlow Op + int64 CalculateInputSize(const OpInfo& op_features, + bool* found_unknown_shapes) const; + + // Calculate the total size in bytes of the all + // the outputs of specified TensorFlow Op + int64 CalculateOutputSize(const OpInfo& op_features, + bool* found_unknown_shapes) const; + + // This family of routines predicts the costs to + // perform the specified TensorFlow Op on the + // device represented by a subclass. The default + // implementation just divides the operations to + // perform the op (from the "Count" routines, + // above) by the device peak operations per + // second. Override to supply a better estimate. + // Implementation of costs other than + // execution_time is optional, depending on the + // device. + Costs PredictConv2D(const OpInfo& op_features) const; + Costs PredictConv2DBackPropInput(const OpInfo& op_features) const; + Costs PredictConv2DBackPropFilter(const OpInfo& op_features) const; + Costs PredictMatMul(const OpInfo& op_features) const; + Costs PredictNoOp(const OpInfo& op_features) const; + + // Utility function for safe division. Returns 0 + // if rhs is 0 or negative. + static double SafeDiv(const double lhs, const double rhs) { + if (rhs > 0) { + return lhs / rhs; + } else { + return 0.0; + } + } + + static ConvolutionDimensions ConvolutionDimensionsFromInputs( + const TensorShapeProto& original_image_shape, + const TensorShapeProto& original_filter_shape, const OpInfo& op_features, + bool* found_unknown_shapes); + + private: + typedef std::function CostImpl; + std::map device_cost_impl_; +}; + +} // end namespace grappler +} // end namespace tensorflow +#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_OP_LEVEL_COST_ESTIMATOR_H_ diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc new file mode 100644 index 00000000000..e0b0348c8ec --- /dev/null +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc @@ -0,0 +1,113 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { + +namespace { +// Wrangles the minimum number of proto fields to set up a matrix. +void DescribeMatrix(int rows, int columns, OpInfo *op_features) { + auto input = op_features->add_inputs(); + auto shape = input->mutable_shape(); + auto shape_rows = shape->add_dim(); + shape_rows->set_size(rows); + auto shape_columns = shape->add_dim(); + shape_columns->set_size(columns); + input->set_dtype(DT_FLOAT); +} + +// Returns an OpInfo for MatMul with the minimum set of fields set up. +OpInfo DescribeMatMul(int m, int n, int l, int k) { + OpInfo op_features; + auto device = op_features.mutable_device(); + device->set_type("CPU"); + op_features.set_op("MatMul"); + + DescribeMatrix(m, l, &op_features); + DescribeMatrix(k, n, &op_features); + return op_features; +} + +// Returns an OpInfo for MatMul with unknown input shapes. +OpInfo DescribeMatMulUnknownShape() { + OpInfo op_features; + auto device = op_features.mutable_device(); + device->set_type("CPU"); + op_features.set_op("MatMul"); + + auto input = op_features.add_inputs(); + auto shape = input->mutable_shape(); + shape->set_unknown_rank(true); + + input = op_features.add_inputs(); + shape = input->mutable_shape(); + shape->set_unknown_rank(true); + + return op_features; +} + +// Wrangles the minimum number of proto fields to set up a 4D Tensor for cost +// estimation purposes. +void DescribeTensor4D(int dim0, int dim1, int dim2, int dim3, + OpInfo *op_features) { + auto input = op_features->add_inputs(); + auto shape = input->mutable_shape(); + shape->add_dim()->set_size(dim0); + shape->add_dim()->set_size(dim1); + shape->add_dim()->set_size(dim2); + shape->add_dim()->set_size(dim3); +} + +// Returns an OpInfo for Conv2D with the minimum set of fields set up. +OpInfo DescribeConvolution(int batch, int ix, int iy, int iz1, int iz2, int kx, + int ky, int oz) { + OpInfo op_features; + auto device = op_features.mutable_device(); + device->set_type("CPU"); + op_features.set_op("Conv2D"); + + DescribeTensor4D(batch, ix, iy, iz1, &op_features); + DescribeTensor4D(kx, ky, iz2, oz, &op_features); + return op_features; +} +} // namespace + +TEST(OpLevelCostEstimatorTest, UnknownOrPartialShape) { + OpLevelCostEstimator estimator; + + EXPECT_EQ(false, + estimator.PredictCosts(DescribeMatMul(2, 4, 7, 7)).inaccurate); + EXPECT_EQ(true, + estimator.PredictCosts(DescribeMatMul(-1, 4, 7, 7)).inaccurate); + EXPECT_EQ(true, + estimator.PredictCosts(DescribeMatMul(2, 4, -1, 7)).inaccurate); + + EXPECT_EQ( + false, + estimator.PredictCosts(DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256)) + .inaccurate); + EXPECT_EQ( + true, + estimator.PredictCosts(DescribeConvolution(16, -1, 19, 48, 48, 5, 5, 256)) + .inaccurate); +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc index 4e35de9d4a6..0852cb4fd3a 100644 --- a/tensorflow/core/grappler/costs/utils.cc +++ b/tensorflow/core/grappler/costs/utils.cc @@ -147,7 +147,7 @@ OpInfo::DeviceProperties GetLocalCPUInfo() { // Combine cpu family and model into the model string. device.set_model( strings::StrCat((port::CPUFamily() << 4) + port::CPUModelNum())); - device.set_frequency(port::NominalCPUFrequency()); + device.set_frequency(port::NominalCPUFrequency() * 1e-9); device.set_num_cores(port::NumSchedulableCPUs()); device.set_l1_cache_size(Eigen::l1CacheSize()); device.set_l2_cache_size(Eigen::l2CacheSize()); @@ -195,6 +195,8 @@ OpInfo::DeviceProperties GetLocalGPUInfo(int gpu_id) { properties.memoryClockRate * 2); } + (*device.mutable_environment())["architecture"] = + strings::StrCat(properties.major, ".", properties.minor); (*device.mutable_environment())["cuda"] = strings::StrCat(CUDA_VERSION); (*device.mutable_environment())["cudnn"] = strings::StrCat(CUDNN_VERSION); #endif From 8db659168077605d4dc9846f47e730b5fb05f5e4 Mon Sep 17 00:00:00 2001 From: Justine Tunney Date: Thu, 4 May 2017 13:51:15 -0800 Subject: [PATCH 30/43] Upgrade TypeScript to 2.3.1 Change: 155130334 --- tensorflow/tensorboard/package.json | 2 +- tensorflow/workspace.bzl | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tensorflow/tensorboard/package.json b/tensorflow/tensorboard/package.json index 69f08495a30..d424f103dd7 100644 --- a/tensorflow/tensorboard/package.json +++ b/tensorflow/tensorboard/package.json @@ -30,7 +30,7 @@ "merge2": "~0.3.6", "minimist": "~1.2.0", "tsify": "^0.14.8", - "typescript": "2.2.2", + "typescript": "2.3.1", "typings": "1.4.0", "vinyl-source-stream": "^1.1.0", "vulcanize": "^1.14.0", diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 0f69d53ba4e..50e329f8c09 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -668,13 +668,13 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "com_microsoft_typescript", licenses = ["notice"], # Apache 2.0 sha256_urls = { - "43a7c763fe024d5add8d5365e5a7981f4a359ba5bf86481f545a0db8f60d48cc": [ - "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/Microsoft/TypeScript/v2.2.2/lib/tsc.js", - "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.2.2/lib/tsc.js", + "8465342c318f9c4cf0a29b109fa63ee3742dd4dc7080d05d9fd8f604814d04cf": [ + "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/tsc.js", + "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/tsc.js", ], - "aecec1e47a3b3d872e214cb9adb82b30d6bd0471ea0aad7311ad81428566627c": [ - "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/Microsoft/TypeScript/v2.2.2/lib/lib.es6.d.ts", - "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.2.2/lib/lib.es6.d.ts", + "a67e36da3029d232e4e938e61a0a3302f516d71e7100d54dbf5362ad8618e994": [ + "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/lib.es6.d.ts", + "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/lib.es6.d.ts", ], }, extra_build_file_content = "\n".join([ From d1001f4bedf91f50971174f9131d5b72d6120d73 Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Thu, 4 May 2017 13:54:01 -0800 Subject: [PATCH 31/43] [tf distributions] Move conditional back to contrib (not slated for core); move transformed into core. (I accidentally moved the wrong file in the previous change!) Also move the Identity bijector & test into core. I can't move the TransformedDistribution test into core since it relies on linalg. Change: 155130709 --- tensorflow/contrib/distributions/BUILD | 19 ------------ tensorflow/contrib/distributions/__init__.py | 4 +-- .../kernel_tests/bijectors/affine_test.py | 2 +- .../kernel_tests/bijectors/chain_test.py | 2 +- .../bijectors/cholesky_outer_product_test.py | 4 +-- .../python/kernel_tests/bijectors/exp_test.py | 4 +-- .../kernel_tests/bijectors/invert_test.py | 4 +-- .../bijectors/power_transform_test.py | 4 +-- .../kernel_tests/bijectors/sigmoid_test.py | 4 +-- .../bijectors/softmax_centered_test.py | 2 +- .../kernel_tests/bijectors/softplus_test.py | 4 +-- .../python/ops/bijectors/__init__.py | 2 +- .../python/ops/bijectors/identity.py | 29 ------------------- .../python/ops}/conditional_distribution.py | 0 .../conditional_transformed_distribution.py | 4 +-- .../python/ops/mvn_linear_operator.py | 2 +- .../python/ops/relaxed_bernoulli.py | 2 +- .../python/ops/relaxed_onehot_categorical.py | 2 +- .../python/ops/vector_student_t.py | 2 +- .../python/kernel_tests/distributions/BUILD | 17 +++++++++++ .../distributions/identity_bijector_test.py} | 10 +++---- .../ops/distributions}/bijector_test_util.py | 0 .../ops/distributions/identity_bijector.py} | 0 .../transformed_distribution.py | 4 +-- 24 files changed, 48 insertions(+), 79 deletions(-) delete mode 100644 tensorflow/contrib/distributions/python/ops/bijectors/identity.py rename tensorflow/{python/ops/distributions => contrib/distributions/python/ops}/conditional_distribution.py (100%) rename tensorflow/{contrib/distributions/python/kernel_tests/bijectors/identity_test.py => python/kernel_tests/distributions/identity_bijector_test.py} (84%) rename tensorflow/{contrib/distributions/python/ops/bijectors => python/ops/distributions}/bijector_test_util.py (100%) rename tensorflow/{contrib/distributions/python/ops/bijectors/identity_impl.py => python/ops/distributions/identity_bijector.py} (100%) rename tensorflow/{contrib/distributions/python/ops => python/ops/distributions}/transformed_distribution.py (99%) diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 9f675c66135..0c818dee031 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -710,25 +710,6 @@ cuda_py_test( ], ) -cuda_py_test( - name = "identity_test", - size = "small", - srcs = ["python/kernel_tests/bijectors/identity_test.py"], - additional_deps = [ - ":bijectors_py", - ":distributions_py", - "//third_party/py/numpy", - "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - ], -) - cuda_py_test( name = "inline_test", size = "small", diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 6ea74fab0e4..ea12e13010a 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -25,6 +25,7 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops import bijectors from tensorflow.contrib.distributions.python.ops.binomial import * from tensorflow.contrib.distributions.python.ops.chi2 import * +from tensorflow.contrib.distributions.python.ops.conditional_distribution import * from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import * from tensorflow.contrib.distributions.python.ops.deterministic import * from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform @@ -44,12 +45,10 @@ from tensorflow.contrib.distributions.python.ops.quantized_distribution import * from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import * from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import * from tensorflow.contrib.distributions.python.ops.sample_stats import * -from tensorflow.contrib.distributions.python.ops.transformed_distribution import * from tensorflow.contrib.distributions.python.ops.wishart import * from tensorflow.python.ops.distributions.bernoulli import * from tensorflow.python.ops.distributions.beta import * from tensorflow.python.ops.distributions.categorical import * -from tensorflow.python.ops.distributions.conditional_distribution import * from tensorflow.python.ops.distributions.dirichlet import * from tensorflow.python.ops.distributions.dirichlet_multinomial import * from tensorflow.python.ops.distributions.distribution import * @@ -60,6 +59,7 @@ from tensorflow.python.ops.distributions.laplace import * from tensorflow.python.ops.distributions.multinomial import * from tensorflow.python.ops.distributions.normal import * from tensorflow.python.ops.distributions.student_t import * +from tensorflow.python.ops.distributions.transformed_distribution import * from tensorflow.python.ops.distributions.uniform import * # pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py index 13554f76642..e8fd6aa2f73 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py @@ -23,9 +23,9 @@ import itertools import numpy as np from tensorflow.contrib.distributions.python.ops.bijectors.affine import Affine -from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py index 994e21dd487..20e75430844 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py @@ -20,12 +20,12 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency from tensorflow.contrib.distributions.python.ops.bijectors.chain import Chain from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py index a4688829f1f..0ff35304283 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py @@ -19,11 +19,11 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.distributions.python.ops import bijectors -from tensorflow.contrib.distributions.python.ops import transformed_distribution as transformed_distribution_lib -from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions import gamma as gamma_lib +from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py index c30ce60cacc..9970c0b4d86 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py @@ -20,9 +20,9 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite -from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp +from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py index a4688829f1f..0ff35304283 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py @@ -19,11 +19,11 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.distributions.python.ops import bijectors -from tensorflow.contrib.distributions.python.ops import transformed_distribution as transformed_distribution_lib -from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions import gamma as gamma_lib +from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py index b30a3b599bb..de1659aa9f4 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py @@ -20,9 +20,9 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite -from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import PowerTransform +from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py index 6f1a6b1cf4b..e4f9d72785c 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py @@ -21,9 +21,9 @@ from __future__ import print_function import numpy as np from scipy import special -from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite -from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid +from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py index 173d52686d6..62e3869db09 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py @@ -20,9 +20,9 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite from tensorflow.python.platform import test diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py index 214b196b547..d9af9aec50d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py @@ -20,9 +20,9 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite -from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus +from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test rng = np.random.RandomState(42) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py index e1d31e373cc..1684a5fffe1 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py @@ -43,7 +43,6 @@ from tensorflow.contrib.distributions.python.ops.bijectors.chain import * from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import * from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector import * from tensorflow.contrib.distributions.python.ops.bijectors.exp import * -from tensorflow.contrib.distributions.python.ops.bijectors.identity import * from tensorflow.contrib.distributions.python.ops.bijectors.inline import * from tensorflow.contrib.distributions.python.ops.bijectors.invert import * from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import * @@ -52,6 +51,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered impo from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import * from tensorflow.contrib.distributions.python.ops.bijectors.softplus import * from tensorflow.python.ops.distributions.bijector import * +from tensorflow.python.ops.distributions.identity_bijector import Identity # pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/identity.py b/tensorflow/contrib/distributions/python/ops/bijectors/identity.py deleted file mode 100644 index 749dd268f98..00000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/identity.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Identity bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.identity_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented - -_allowed_symbols = ["Identity"] - -remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/python/ops/distributions/conditional_distribution.py b/tensorflow/contrib/distributions/python/ops/conditional_distribution.py similarity index 100% rename from tensorflow/python/ops/distributions/conditional_distribution.py rename to tensorflow/contrib/distributions/python/ops/conditional_distribution.py diff --git a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py index b0967802bd8..2e1e68cf058 100644 --- a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py @@ -17,9 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.distributions.python.ops import transformed_distribution +from tensorflow.contrib.distributions.python.ops import conditional_distribution from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import conditional_distribution +from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import util as distribution_util diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py index a66eb1674ca..fbd623ed3a1 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py @@ -20,7 +20,6 @@ from __future__ import print_function from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import bijectors -from tensorflow.contrib.distributions.python.ops import transformed_distribution from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util @@ -29,6 +28,7 @@ from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import normal +from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import util as distribution_util diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py index 581e190f73b..5b57a95c55e 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.distributions.python.ops import logistic -from tensorflow.contrib.distributions.python.ops import transformed_distribution # Bijectors must be directly imported because `remove_undocumented` prevents # individual file imports. from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid @@ -27,6 +26,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops +from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import util as distribution_util diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py index 00415f5e1aa..da1cd72a6f1 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py @@ -20,7 +20,6 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.distributions.python.ops import bijectors -from tensorflow.contrib.distributions.python.ops import transformed_distribution from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -30,6 +29,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import util as distribution_util diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py index 299ff36962e..ae804b61727 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py +++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py @@ -19,13 +19,13 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.distributions.python.ops import bijectors -from tensorflow.contrib.distributions.python.ops import transformed_distribution from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions import student_t +from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import util as distribution_util diff --git a/tensorflow/python/kernel_tests/distributions/BUILD b/tensorflow/python/kernel_tests/distributions/BUILD index 3630adc9549..50a07952004 100644 --- a/tensorflow/python/kernel_tests/distributions/BUILD +++ b/tensorflow/python/kernel_tests/distributions/BUILD @@ -249,6 +249,23 @@ cuda_py_test( ], ) +cuda_py_test( + name = "identity_bijector_test", + size = "small", + srcs = ["identity_bijector_test.py"], + additional_deps = [ + "//tensorflow/python/ops/distributions", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/identity_test.py b/tensorflow/python/kernel_tests/distributions/identity_bijector_test.py similarity index 84% rename from tensorflow/contrib/distributions/python/kernel_tests/bijectors/identity_test.py rename to tensorflow/python/kernel_tests/distributions/identity_bijector_test.py index 0969c293d40..e8f9d0b728d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/identity_test.py +++ b/tensorflow/python/kernel_tests/distributions/identity_bijector_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency -from tensorflow.contrib.distributions.python.ops.bijectors.identity import Identity +from tensorflow.python.ops.distributions import bijector_test_util +from tensorflow.python.ops.distributions import identity_bijector from tensorflow.python.platform import test @@ -28,7 +28,7 @@ class IdentityBijectorTest(test.TestCase): def testBijector(self): with self.test_session(): - bijector = Identity() + bijector = identity_bijector.Identity() self.assertEqual("identity", bijector.name) x = [[[0.], [1.]]] self.assertAllEqual(x, bijector.forward(x).eval()) @@ -38,8 +38,8 @@ class IdentityBijectorTest(test.TestCase): def testScalarCongruency(self): with self.test_session(): - bijector = Identity() - assert_scalar_congruency( + bijector = identity_bijector.Identity() + bijector_test_util.assert_scalar_congruency( bijector, lower_x=-2., upper_x=2.) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/bijector_test_util.py b/tensorflow/python/ops/distributions/bijector_test_util.py similarity index 100% rename from tensorflow/contrib/distributions/python/ops/bijectors/bijector_test_util.py rename to tensorflow/python/ops/distributions/bijector_test_util.py diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/identity_impl.py b/tensorflow/python/ops/distributions/identity_bijector.py similarity index 100% rename from tensorflow/contrib/distributions/python/ops/bijectors/identity_impl.py rename to tensorflow/python/ops/distributions/identity_bijector.py diff --git a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py b/tensorflow/python/ops/distributions/transformed_distribution.py similarity index 99% rename from tensorflow/contrib/distributions/python/ops/transformed_distribution.py rename to tensorflow/python/ops/distributions/transformed_distribution.py index e146e20d3ac..09b26a9fb73 100644 --- a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py +++ b/tensorflow/python/ops/distributions/transformed_distribution.py @@ -21,7 +21,6 @@ import numpy as np # Bijectors must be directly imported because `remove_undocumented` prevents # individual file imports. -from tensorflow.contrib.distributions.python.ops.bijectors.identity import Identity from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -32,6 +31,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distribution_lib +from tensorflow.python.ops.distributions import identity_bijector from tensorflow.python.ops.distributions import util as distribution_util __all__ = [ @@ -265,7 +265,7 @@ class TransformedDistribution(distribution_lib.Distribution): self._empty = constant_op.constant([], dtype=dtypes.int32, name="empty") if bijector is None: - bijector = Identity(validate_args=validate_args) + bijector = identity_bijector.Identity(validate_args=validate_args) # We will keep track of a static and dynamic version of # self._is_{batch,event}_override. This way we can do more prior to graph From e98357a9fd8a2e0962b2da06d769e3e58aedffd4 Mon Sep 17 00:00:00 2001 From: Anna R Date: Thu, 4 May 2017 14:32:19 -0800 Subject: [PATCH 32/43] Internal change. Change: 155135670 --- tensorflow/tools/pip_package/pip_smoke_test.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py index fa61a19b39f..4bb5c1b73c6 100644 --- a/tensorflow/tools/pip_package/pip_smoke_test.py +++ b/tensorflow/tools/pip_package/pip_smoke_test.py @@ -28,11 +28,13 @@ import subprocess PIP_PACKAGE_QUERY = """bazel query \ 'deps(//tensorflow/tools/pip_package:build_pip_package)'""" -PY_TEST_QUERY = """bazel query 'filter("^((?!(benchmark|manual|no_pip)).)*$", \ - deps(kind(py_test,\ - //tensorflow/python/... + \ - //tensorflow/tensorboard/... + \ - //tensorflow/contrib/...), 1))'""" +PY_TEST_QUERY = """bazel query 'deps(\ + filter("^((?!benchmark).)*$",\ + kind(py_test,\ + //tensorflow/python/... \ + + //tensorflow/tensorboard/... \ + + //tensorflow/contrib/... \ + - attr(tags, "manual|no_pip", //tensorflow/...))), 1)'""" # Hard-coded blacklist of files if not included in pip package # TODO(amitpatankar): Clean up blacklist. From f696b5d4398f8295fcb2be3c89a94e007ca62287 Mon Sep 17 00:00:00 2001 From: Anna R Date: Thu, 4 May 2017 14:38:05 -0800 Subject: [PATCH 33/43] Added traceback_with_start_lines property to op that includes function start line number as the last element in each traceback tuple. Change: 155136334 --- tensorflow/python/framework/ops.py | 35 ++++++++++++++----- tensorflow/python/framework/ops_test.py | 22 ++++++++++++ .../api/golden/tensorflow.-operation.pbtxt | 4 +++ 3 files changed, 53 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 93a29d0d8e9..05972022d03 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -70,25 +70,33 @@ def _override_helper(clazz_object, operator, func): setattr(clazz_object, operator, func) -def _convert_stack(stack): +def _convert_stack(stack, include_func_start_lineno=False): """Converts a stack extracted using _extract_stack() to a traceback stack. Args: - stack: A list of n 4-tuples, (filename, lineno, name, frame_globals). + stack: A list of n 5-tuples, + (filename, lineno, name, frame_globals, func_start_lineno). + include_func_start_lineno: True if function start line number should be + included as the 5th entry in return tuples. Returns: - A list of n 4-tuples (filename, lineno, name, code), where the code tuple - element is calculated from the corresponding elements of the input tuple. + A list of n 4-tuples or 5-tuples + (filename, lineno, name, code, [optional: func_start_lineno]), where the + code tuple element is calculated from the corresponding elements of the + input tuple. """ ret = [] - for filename, lineno, name, frame_globals in stack: + for filename, lineno, name, frame_globals, func_start_lineno in stack: linecache.checkcache(filename) line = linecache.getline(filename, lineno, frame_globals) if line: line = line.strip() else: line = None - ret.append((filename, lineno, name, line)) + if include_func_start_lineno: + ret.append((filename, lineno, name, line, func_start_lineno)) + else: + ret.append((filename, lineno, name, line)) return ret @@ -103,7 +111,8 @@ def _extract_stack(): be formatted etc. using traceback methods. Returns: - A list of 4-tuples (filename, lineno, name, frame_globals) corresponding to + A list of 5-tuples + (filename, lineno, name, frame_globals, func_start_lineno) corresponding to the call stack of the current thread. """ # pylint: enable=line-too-long @@ -118,7 +127,8 @@ def _extract_stack(): filename = co.co_filename name = co.co_name frame_globals = f.f_globals - ret.append((filename, lineno, name, frame_globals)) + func_start_lineno = co.co_firstlineno + ret.append((filename, lineno, name, frame_globals, func_start_lineno)) f = f.f_back ret.reverse() return ret @@ -1505,6 +1515,15 @@ class Operation(object): """Returns the call stack from when this operation was constructed.""" return _convert_stack(self._traceback) + @property + def traceback_with_start_lines(self): + """Same as traceback but includes start line of function definition. + + Returns: + A list of 5-tuples (filename, lineno, name, code, func_start_lineno). + """ + return _convert_stack(self._traceback, include_func_start_lineno=True) + def get_attr(self, name): """Returns the value of the attr of this op with the given `name`. diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 06d03121a0f..3e9f047a7de 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -22,6 +22,7 @@ import gc import weakref from tensorflow.core.framework import attr_value_pb2 +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.framework import common_shapes from tensorflow.python.framework import constant_op @@ -1703,5 +1704,26 @@ class NameScopeTest(test_util.TensorFlowTestCase): self.assertEqual("", g.get_name_scope()) +class TracebackTest(test_util.TensorFlowTestCase): + + def testTracebackWithStartLines(self): + with self.test_session() as sess: + a = constant_op.constant(2.0) + sess.run( + a, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assertTrue(sess.graph.get_operations()) + + # Tests that traceback_with_start_lines is the same as traceback + # but includes one more element at the end. + for op in sess.graph.get_operations(): + self.assertEquals(len(op.traceback), len(op.traceback_with_start_lines)) + for frame, frame_with_start_line in zip( + op.traceback, op.traceback_with_start_lines): + self.assertEquals(5, len(frame_with_start_line)) + self.assertEquals(frame, frame_with_start_line[:-1]) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/tools/api/golden/tensorflow.-operation.pbtxt b/tensorflow/tools/api/golden/tensorflow.-operation.pbtxt index 0f43a49ee96..64240f70698 100644 --- a/tensorflow/tools/api/golden/tensorflow.-operation.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-operation.pbtxt @@ -38,6 +38,10 @@ tf_class { name: "traceback" mtype: "" } + member { + name: "traceback_with_start_lines" + mtype: "" + } member { name: "type" mtype: "" From 6604bc6fdc456c4028b3de41c1eda0ae89007630 Mon Sep 17 00:00:00 2001 From: Justine Tunney Date: Thu, 4 May 2017 14:39:37 -0800 Subject: [PATCH 34/43] Clean up TensorBoard build and fix sync process Change: 155136555 --- tensorflow/BUILD | 42 +++++++++++++++++++ .../vz_data_summary/BUILD.OPENSOURCE | 34 --------------- .../components/vz_projector/BUILD.OPENSOURCE | 19 --------- tensorflow/tensorboard/defs.bzl | 16 +++++++ 4 files changed, 58 insertions(+), 53 deletions(-) delete mode 100644 tensorflow/tensorboard/components/vz_data_summary/BUILD.OPENSOURCE delete mode 100644 tensorflow/tensorboard/components/vz_projector/BUILD.OPENSOURCE diff --git a/tensorflow/BUILD b/tensorflow/BUILD index a2f7a9fb639..da8190523ab 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -318,6 +318,48 @@ filegroup( "//tensorflow/tensorboard/backend:all_files", "//tensorflow/tensorboard/backend/event_processing:all_files", "//tensorflow/tensorboard/components:all_files", + "//tensorflow/tensorboard/components/tf_audio_dashboard:all_files", + "//tensorflow/tensorboard/components/tf_audio_dashboard/demo:all_files", + "//tensorflow/tensorboard/components/tf_backend:all_files", + "//tensorflow/tensorboard/components/tf_backend_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_color_scale:all_files", + "//tensorflow/tensorboard/components/tf_color_scale/demo:all_files", + "//tensorflow/tensorboard/components/tf_color_scale_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_dashboard_common:all_files", + "//tensorflow/tensorboard/components/tf_dashboard_common/demo:all_files", + "//tensorflow/tensorboard/components/tf_dashboard_common_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_distribution_dashboard:all_files", + "//tensorflow/tensorboard/components/tf_distribution_dashboard/demo:all_files", + "//tensorflow/tensorboard/components/tf_globals:all_files", + "//tensorflow/tensorboard/components/tf_globals_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_graph_common:all_files", + "//tensorflow/tensorboard/components/tf_histogram_dashboard:all_files", + "//tensorflow/tensorboard/components/tf_histogram_dashboard/demo:all_files", + "//tensorflow/tensorboard/components/tf_image_dashboard:all_files", + "//tensorflow/tensorboard/components/tf_image_dashboard/demo:all_files", + "//tensorflow/tensorboard/components/tf_imports:all_files", + "//tensorflow/tensorboard/components/tf_imports_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_scalar_dashboard:all_files", + "//tensorflow/tensorboard/components/tf_scalar_dashboard/demo:all_files", + "//tensorflow/tensorboard/components/tf_storage:all_files", + "//tensorflow/tensorboard/components/tf_storage_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_text_dashboard:all_files", + "//tensorflow/tensorboard/components/tf_text_dashboard/demo:all_files", + "//tensorflow/tensorboard/components/vz_data_summary:all_files", + "//tensorflow/tensorboard/components/vz_distribution_chart:all_files", + "//tensorflow/tensorboard/components/vz_distribution_chart/demo:all_files", + "//tensorflow/tensorboard/components/vz_distribution_chart_d3v4:all_files", + "//tensorflow/tensorboard/components/vz_histogram_timeseries:all_files", + "//tensorflow/tensorboard/components/vz_histogram_timeseries/demo:all_files", + "//tensorflow/tensorboard/components/vz_histogram_timeseries_d3v4:all_files", + "//tensorflow/tensorboard/components/vz_line_chart:all_files", + "//tensorflow/tensorboard/components/vz_line_chart/demo:all_files", + "//tensorflow/tensorboard/components/vz_line_chart_d3v4:all_files", + "//tensorflow/tensorboard/components/vz_projector:all_files", + "//tensorflow/tensorboard/components/vz_projector_d3v4:all_files", + "//tensorflow/tensorboard/components/vz_sorting:all_files", + "//tensorflow/tensorboard/components/vz_sorting/test:all_files", + "//tensorflow/tensorboard/components/vz_sorting_d3v4:all_files", "//tensorflow/tensorboard/lib:all_files", "//tensorflow/tensorboard/plugins:all_files", "//tensorflow/tensorboard/plugins/projector:all_files", diff --git a/tensorflow/tensorboard/components/vz_data_summary/BUILD.OPENSOURCE b/tensorflow/tensorboard/components/vz_data_summary/BUILD.OPENSOURCE deleted file mode 100644 index 9743d70d947..00000000000 --- a/tensorflow/tensorboard/components/vz_data_summary/BUILD.OPENSOURCE +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the 'License'); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================= - -# Description: -# Package for the data-summary vz-element. -package(default_visibility = ["//tensorflow:internal"]) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/tensorboard/components/vz_projector/BUILD.OPENSOURCE b/tensorflow/tensorboard/components/vz_projector/BUILD.OPENSOURCE deleted file mode 100644 index 8c222be10e9..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/BUILD.OPENSOURCE +++ /dev/null @@ -1,19 +0,0 @@ -# Description: -# Package for the Embedding Projector component. -package(default_visibility = ["//tensorflow:internal"]) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/tensorboard/defs.bzl b/tensorflow/tensorboard/defs.bzl index 7bb5f961c97..22bb047075f 100644 --- a/tensorflow/tensorboard/defs.bzl +++ b/tensorflow/tensorboard/defs.bzl @@ -60,6 +60,22 @@ def tensorboard_typescript_genrule(name, srcs, typings=[], **kwargs): **kwargs ) +def tensorboard_karma_web_test_suite(**kwargs): + """Rules referencing this will be deleted from the codebase soon.""" + pass + +def tensorboard_ts_declaration(**kwargs): + """Rules referencing this will be deleted from the codebase soon.""" + pass + +def tensorboard_ts_development_sources(**kwargs): + """Rules referencing this will be deleted from the codebase soon.""" + pass + +def tensorboard_ts_devserver(**kwargs): + """Rules referencing this will be deleted from the codebase soon.""" + pass + def tensorboard_ts_library(**kwargs): """Rules referencing this will be deleted from the codebase soon.""" pass From 4c992d9d6d45cc2e23f20ee728880fc236901dd4 Mon Sep 17 00:00:00 2001 From: Yutaka Leon Date: Thu, 4 May 2017 15:08:49 -0800 Subject: [PATCH 35/43] Fix mac build. Change: 155140054 --- tensorflow/contrib/cmake/tf_core_ops.cmake | 1 + tensorflow/contrib/cmake/tf_python.cmake | 1 + 2 files changed, 2 insertions(+) diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 2a19433a7b2..eae00ab8756 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -22,6 +22,7 @@ set(tf_op_lib_names "image_ops" "io_ops" "linalg_ops" + "lookup_ops" "logging_ops" "math_ops" "nn_ops" diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index ad3b29c8ea5..9e2eb71b4c2 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -597,6 +597,7 @@ GENERATE_PYTHON_OP_LIB("image_ops") GENERATE_PYTHON_OP_LIB("io_ops") GENERATE_PYTHON_OP_LIB("linalg_ops") GENERATE_PYTHON_OP_LIB("logging_ops") +GENERATE_PYTHON_OP_LIB("lookup_ops") GENERATE_PYTHON_OP_LIB("nn_ops") GENERATE_PYTHON_OP_LIB("parsing_ops") GENERATE_PYTHON_OP_LIB("random_ops") From 7ff483746404a3ca7ed66ae4271b21bcb07082ee Mon Sep 17 00:00:00 2001 From: Yutaka Leon Date: Thu, 4 May 2017 15:09:27 -0800 Subject: [PATCH 36/43] Fix documentation in supervisor. Change: 155140112 --- tensorflow/docs_src/programmers_guide/supervisor.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/docs_src/programmers_guide/supervisor.md b/tensorflow/docs_src/programmers_guide/supervisor.md index 82ed1c2cf76..55a090df589 100644 --- a/tensorflow/docs_src/programmers_guide/supervisor.md +++ b/tensorflow/docs_src/programmers_guide/supervisor.md @@ -362,8 +362,8 @@ following keyword arguments to the `Supervisor()` constructor: If not specified, the supervisor uses the first op in the `tf.GraphKeys.LOCAL_INIT_OP` collection. If the collection is empty the supervisor adds an op to initialize all the tables and local variables in - the graph by calling `tf.initialize_all_tables()` and - `tf.initialize_all_local_variables()`. + the graph by calling `tf.tables_initializer()` and + `tf.local_variables_initializer()`. Pass `None` to not use a local init op. From ffc5e6fbf68b5229b68cfe96bbcaf58619277c06 Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Thu, 4 May 2017 15:35:27 -0800 Subject: [PATCH 37/43] [TF] optimization: SparseTensorDenseMatMul GPU kernel rewritten in pure cuda. Also added an additional GPU int32::max check that was missing. Performance seems to be between 1x-10x faster on average. The likely culprit on CPU slowdown was probably the unnecessary temp allocation for scratch space. Performance on a k40, compiled -c opt --config cuda --copt=-mavx: **BEFORE** Matrix sizes: A sparse [m, k] with % nonzero values between 1% and 80% B dense [k, n] % nnz n gpu m k dt(dense) dt(sparse) dt(sparse)/dt(dense) 0.01 50 True 100 100 0.000319954 0.000275495 0.861045 0.01 50 True 100 1000 0.000469565 0.000290895 0.619497 0.01 50 True 1000 100 0.000572815 0.000271131 0.473331 0.01 50 True 1000 1000 0.00133119 0.00042006 0.315554 0.01 50 False 100 100 0.00034191 0.000289171 0.845751 0.01 50 False 100 1000 0.0004796 0.00028483 0.593891 0.01 50 False 1000 100 0.000632371 0.000300461 0.475134 0.01 50 False 1000 1000 0.00134726 0.000576285 0.427746 0.01 100 True 100 100 0.000353755 0.00027729 0.783849 0.01 100 True 100 1000 0.000536649 0.00028337 0.528036 0.01 100 True 1000 100 0.000661941 0.00027933 0.421987 0.01 100 True 1000 1000 0.0014109 0.0006698 0.474732 0.01 100 False 100 100 0.00039546 0.00030159 0.762631 0.01 100 False 100 1000 0.00054909 0.00027276 0.49675 0.01 100 False 1000 100 0.000631344 0.00028231 0.447157 0.01 100 False 1000 1000 0.00141789 0.000657049 0.463398 0.2 50 True 100 100 0.00033689 0.000280155 0.831591 0.2 50 True 100 1000 0.000563495 0.00064159 1.13859 0.2 50 True 1000 100 0.00058635 0.00067611 1.15308 0.2 50 True 1000 1000 0.00153552 0.00486242 3.16662 0.2 50 False 100 100 0.000333545 0.000267555 0.802154 0.2 50 False 100 1000 0.000544 0.00066272 1.21824 0.2 50 False 1000 100 0.00058253 0.000670955 1.15179 0.2 50 False 1000 1000 0.00153017 0.00480928 3.14298 0.2 100 True 100 100 0.00036919 0.000288659 0.781872 0.2 100 True 100 1000 0.00067063 0.00110059 1.64113 0.2 100 True 1000 100 0.00066443 0.00108547 1.63369 0.2 100 True 1000 1000 0.00180991 0.00961579 5.31286 0.2 100 False 100 100 0.00040061 0.000325365 0.812174 0.2 100 False 100 1000 0.00066774 0.00111843 1.67494 0.2 100 False 1000 100 0.000696205 0.00108078 1.55239 0.2 100 False 1000 1000 0.00179788 0.00960569 5.34278 0.5 50 True 100 100 0.00034819 0.00033425 0.959963 0.5 50 True 100 1000 0.00075176 0.00134084 1.78359 0.5 50 True 1000 100 0.000642445 0.00133641 2.08019 0.5 50 True 1000 1000 0.00233791 0.0124282 5.31597 0.5 50 False 100 100 0.000345069 0.000334586 0.96962 0.5 50 False 100 1000 0.00071701 0.00135879 1.89508 0.5 50 False 1000 100 0.000632119 0.00134036 2.12043 0.5 50 False 1000 1000 0.00240216 0.0126202 5.25368 0.5 100 True 100 100 0.000393934 0.00040344 1.02413 0.5 100 True 100 1000 0.000957675 0.002709 2.82873 0.5 100 True 1000 100 0.000756125 0.00242428 3.20619 0.5 100 True 1000 1000 0.00298202 0.0241416 8.09572 0.5 100 False 100 100 0.000395606 0.000433675 1.09623 0.5 100 False 100 1000 0.000963565 0.00248293 2.57682 0.5 100 False 1000 100 0.00079523 0.0024281 3.05333 0.5 100 False 1000 1000 0.00299668 0.0242615 8.09614 0.8 50 True 100 100 0.00036806 0.00040923 1.11186 0.8 50 True 100 1000 0.00091419 0.00207383 2.26848 0.8 50 True 1000 100 0.000684329 0.00196612 2.87307 0.8 50 True 1000 1000 0.00302433 0.0199798 6.60637 0.8 50 False 100 100 0.000368149 0.000615025 1.67058 0.8 50 False 100 1000 0.0008786 0.00205821 2.3426 0.8 50 False 1000 100 0.00067889 0.00195498 2.87967 0.8 50 False 1000 1000 0.00290009 0.0191242 6.59434 0.8 100 True 100 100 0.000452549 0.00063767 1.40906 0.8 100 True 100 1000 0.00126929 0.00391422 3.08378 0.8 100 True 1000 100 0.000919235 0.00386167 4.20096 0.8 100 True 1000 1000 0.00423295 0.0431824 10.2015 0.8 100 False 100 100 0.000428261 0.000626891 1.46381 0.8 100 False 100 1000 0.00120801 0.00395877 3.27711 0.8 100 False 1000 100 0.00080466 0.00385143 4.78641 0.8 100 False 1000 1000 0.00370808 0.0403527 10.8824 **AFTER** Matrix sizes: A sparse [m, k] with % nonzero values between 1% and 80% B dense [k, n] % nnz n gpu m k dt(dense) dt(sparse) dt(sparse)/dt(dense) 0.01 50 True 100 100 0.000312485 0.00020528 0.656927 0.01 50 True 100 1000 0.0004655 0.00020095 0.431686 0.01 50 True 1000 100 0.000567449 0.000203935 0.359389 0.01 50 True 1000 1000 0.00132323 0.00027171 0.205339 0.01 50 False 100 100 0.000319945 0.000197511 0.617328 0.01 50 False 100 1000 0.000466419 0.000210185 0.450635 0.01 50 False 1000 100 0.0005581 0.000199865 0.358117 0.01 50 False 1000 1000 0.00129479 0.000451496 0.348702 0.01 100 True 100 100 0.000364131 0.000196835 0.540561 0.01 100 True 100 1000 0.00053398 0.000206494 0.386708 0.01 100 True 1000 100 0.00062722 0.000203185 0.323946 0.01 100 True 1000 1000 0.00138674 0.000335904 0.242227 0.01 100 False 100 100 0.000361339 0.000195 0.53966 0.01 100 False 100 1000 0.000531831 0.000207155 0.389513 0.01 100 False 1000 100 0.00062245 0.000197015 0.316515 0.01 100 False 1000 1000 0.0014007 0.000328825 0.234757 0.2 50 True 100 100 0.00033185 0.000262895 0.792209 0.2 50 True 100 1000 0.00054391 0.000586189 1.07773 0.2 50 True 1000 100 0.000581805 0.000531535 0.913597 0.2 50 True 1000 1000 0.00153913 0.00142783 0.927687 0.2 50 False 100 100 0.00033572 0.000266831 0.794803 0.2 50 False 100 1000 0.000534315 0.000585151 1.09514 0.2 50 False 1000 100 0.000580961 0.00033344 0.573947 0.2 50 False 1000 1000 0.0015055 0.00143968 0.956284 0.2 100 True 100 100 0.000371666 0.00026337 0.708621 0.2 100 True 100 1000 0.000667235 0.00056811 0.851439 0.2 100 True 1000 100 0.000671356 0.000400575 0.596666 0.2 100 True 1000 1000 0.00178568 0.00250393 1.40222 0.2 100 False 100 100 0.000370425 0.000254935 0.688223 0.2 100 False 100 1000 0.000661175 0.000601134 0.909191 0.2 100 False 1000 100 0.0006944 0.00039817 0.573401 0.2 100 False 1000 1000 0.00176969 0.0024947 1.40968 0.5 50 True 100 100 0.000346885 0.000263295 0.759028 0.5 50 True 100 1000 0.00073113 0.00107669 1.47263 0.5 50 True 1000 100 0.000672774 0.000493085 0.732914 0.5 50 True 1000 1000 0.00260436 0.003335 1.28054 0.5 50 False 100 100 0.00036242 0.000273196 0.753809 0.5 50 False 100 1000 0.000753295 0.00107086 1.42157 0.5 50 False 1000 100 0.00064886 0.000501654 0.773132 0.5 50 False 1000 1000 0.00241105 0.0033146 1.37475 0.5 100 True 100 100 0.000401269 0.00027831 0.693573 0.5 100 True 100 1000 0.00094245 0.00111468 1.18275 0.5 100 True 1000 100 0.00075719 0.00074962 0.990003 0.5 100 True 1000 1000 0.00297528 0.00601445 2.02147 0.5 100 False 100 100 0.000408576 0.00026246 0.642377 0.5 100 False 100 1000 0.00094272 0.00112762 1.19613 0.5 100 False 1000 100 0.000762925 0.00074343 0.974446 0.5 100 False 1000 1000 0.00314936 0.00604122 1.91824 0.8 50 True 100 100 0.00036589 0.000331376 0.905669 0.8 50 True 100 1000 0.00086403 0.00171248 1.98197 0.8 50 True 1000 100 0.00067048 0.000715261 1.06679 0.8 50 True 1000 1000 0.00284684 0.00527865 1.85422 0.8 50 False 100 100 0.000357161 0.000540144 1.51233 0.8 50 False 100 1000 0.000884765 0.00170428 1.92625 0.8 50 False 1000 100 0.000666975 0.000737065 1.10509 0.8 50 False 1000 1000 0.0028149 0.00530442 1.88441 0.8 100 True 100 100 0.00041237 0.00034323 0.832335 0.8 100 True 100 1000 0.00122102 0.00179725 1.47192 0.8 100 True 1000 100 0.000807976 0.00111246 1.37684 0.8 100 True 1000 1000 0.00379081 0.00968211 2.5541 0.8 100 False 100 100 0.000426315 0.000339085 0.795386 0.8 100 False 100 1000 0.00144096 0.00179819 1.2479 0.8 100 False 1000 100 0.000951196 0.0011155 1.17274 0.8 100 False 1000 1000 0.0039524 0.00980128 2.47983 Change: 155142876 --- .../kernels/sparse_tensor_dense_matmul_op.cc | 65 ++++---- .../kernels/sparse_tensor_dense_matmul_op.h | 9 +- .../sparse_tensor_dense_matmul_op_gpu.cu.cc | 144 ++++++------------ .../sparse_tensor_dense_matmul_op_test.py | 87 ++++++++--- 4 files changed, 145 insertions(+), 160 deletions(-) diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc index 30026f222a6..30c57ef287f 100644 --- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc +++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc @@ -65,7 +65,8 @@ class SparseTensorDenseMatMulOp : public OpKernel { OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a_indices->shape()), errors::InvalidArgument("Tensor 'a_indices' is not a matrix")); - OP_REQUIRES(ctx, a_indices->shape().dim_size(0) == a_values->NumElements(), + const int64 nnz = a_indices->shape().dim_size(0); + OP_REQUIRES(ctx, nnz == a_values->NumElements(), errors::InvalidArgument("Number of rows of a_indices does not " "match number of entries in a_values")); @@ -89,8 +90,28 @@ class SparseTensorDenseMatMulOp : public OpKernel { inner_left, " vs. ", inner_right, ". Did you forget a transpose? " "Dimensions of A: [", - a_shape_t(0), ", ", a_shape_t(1), "). Dimensions of B: ", - b->shape().DebugString())); + a_shape_t(0), ", ", a_shape_t(1), + "). Dimensions of B: ", b->shape().DebugString())); + + if (std::is_same::value) { + // The GPU implementation is optimized to use 32 bit indexing, so + // give a friendly error to the programmer early on if they + // exceed. + const int int32max = std::numeric_limits::max(); + OP_REQUIRES( + ctx, + (FastBoundsCheck(inner_left, int32max) && + FastBoundsCheck(inner_right, int32max) && + FastBoundsCheck(outer_left, int32max) && + FastBoundsCheck(outer_right, int32max) && + FastBoundsCheck(b->NumElements(), int32max) && + FastBoundsCheck(outer_left * outer_right, int32max) && + FastBoundsCheck(a_values->NumElements(), int32max)), + errors::InvalidArgument("Cannot use GPU for > 2^31 entry inputs")); + OP_REQUIRES(ctx, FastBoundsCheck(nnz * outer_right, int32max), + errors::InvalidArgument( + "Cannot use GPU when output.shape[1] * nnz(a) > 2^31")); + } TensorShape out_shape({outer_left, outer_right}); Tensor* out = nullptr; @@ -111,41 +132,13 @@ class SparseTensorDenseMatMulOp : public OpKernel { return; } - Tensor scratch; - - if (std::is_same::value) { - // The GPU implementation is optimized to use 32 bit indexing, so - // give a friendly error to the programmer early on if they exceed. - OP_REQUIRES( - ctx, - FastBoundsCheck(inner_left, std::numeric_limits::max()) && - FastBoundsCheck(inner_right, std::numeric_limits::max()) && - FastBoundsCheck(outer_left, std::numeric_limits::max()) && - FastBoundsCheck(outer_right, std::numeric_limits::max()) && - FastBoundsCheck(b->NumElements(), - std::numeric_limits::max()) && - FastBoundsCheck(out->NumElements(), - std::numeric_limits::max()) && - FastBoundsCheck(a_values->NumElements(), - std::numeric_limits::max()), - errors::InvalidArgument("Cannot use GPU for > 2^31 entry inputs")); - const int nnz = static_cast(a_values->NumElements()); - // Need nnz length vec scratch space on the GPU. - OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum::value, - TensorShape({nnz}), &scratch)); - } else { - // We don't need scratch space on the CPU. - OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum::value, - TensorShape({0}), &scratch)); - } - #define MAYBE_ADJOINT(ADJ_A, ADJ_B) \ if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) { \ Status functor_status = functor::SparseTensorDenseMatMulFunctor< \ Device, T, Tindices, ADJ_A, \ ADJ_B>::Compute(ctx->eigen_device(), out->matrix(), \ a_indices->matrix(), a_values->vec(), \ - b->matrix(), scratch.vec()); \ + b->matrix()); \ OP_REQUIRES_OK(ctx, functor_status); \ } @@ -189,10 +182,9 @@ namespace functor { Status SparseTensorDenseMatMulFunctor< \ GPUDevice, T, Tindices, ADJ_A, \ ADJ_B>::Compute(const GPUDevice& d, typename TTypes::Matrix out, \ - typename TTypes::ConstMatrix a_indices, \ + TTypes::ConstMatrix a_indices, \ typename TTypes::ConstVec a_values, \ - typename TTypes::ConstMatrix b, \ - typename TTypes::Vec scratch); \ + typename TTypes::ConstMatrix b); \ extern template struct SparseTensorDenseMatMulFunctor< \ GPUDevice, T, Tindices, ADJ_A, ADJ_B>; @@ -255,8 +247,7 @@ struct SparseTensorDenseMatMulFunctor { static Status Compute(const CPUDevice& d, typename TTypes::Matrix out, typename TTypes::ConstMatrix a_indices, typename TTypes::ConstVec a_values, - typename TTypes::ConstMatrix b, - typename TTypes::Vec scratch) { + typename TTypes::ConstMatrix b) { const std::size_t nnz = a_values.size(); const std::size_t rhs_right = (ADJ_B ? b.dimension(0) : b.dimension(1)); const std::size_t lhs_right = (ADJ_B ? b.dimension(1) : b.dimension(0)); diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h index e707743f782..da131904949 100644 --- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h +++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h @@ -28,11 +28,10 @@ namespace functor { template struct SparseTensorDenseMatMulFunctor { - static EIGEN_ALWAYS_INLINE Status - Compute(const Device& d, typename TTypes::Matrix out, - typename TTypes::ConstMatrix a_indices, - typename TTypes::ConstVec a_values, - typename TTypes::ConstMatrix b, typename TTypes::Vec scratch); + static EIGEN_ALWAYS_INLINE Status Compute( + const Device& d, typename TTypes::Matrix out, + typename TTypes::ConstMatrix a_indices, + typename TTypes::ConstVec a_values, typename TTypes::ConstMatrix b); }; template diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc index 7266e0cf812..e261e42e0d3 100644 --- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc +++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc @@ -20,71 +20,45 @@ limitations under the License. #include "tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; -namespace generator { - template -class SparseTensorDenseMatMulGPUGenerator { - public: - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SparseTensorDenseMatMulGPUGenerator( - typename TTypes::Tensor32Bit out, - typename TTypes::Tensor32Bit a_indices, - typename TTypes::Tensor32Bit a_values, - typename TTypes::Tensor32Bit b) - : out_(out), - lhs_index_a_(ADJ_A ? 1 : 0), - rhs_index_a_(ADJ_A ? 0 : 1), - a_indices_(a_indices), - a_values_(a_values), - lhs_right_size(ADJ_B ? b.dimension(1) : b.dimension(0)), - maybe_adjoint_b_( - functor::MaybeAdjoint::Tensor32Bit, - ADJ_B>(b)) {} - - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T - operator()(const Eigen::array& j_and_ix) const { -#ifdef __CUDA_ARCH__ - const int j = j_and_ix[0]; - const int ix = j_and_ix[1]; - int m = a_indices_(ix, lhs_index_a_); - int k = a_indices_(ix, rhs_index_a_); - assert(k < lhs_right_size); - assert(m < out_.dimension(0)); - // If asserts are disabled, the caller is violating the sparse - // tensor index contract, and so we return invalid results. - // Force returning NaNs to try to signal that something is amiss. - T b_value; - if (k >= lhs_right_size || m >= out_.dimension(0)) { - m = 0; - k = 0; - b_value = std::numeric_limits::quiet_NaN(); - } else { - b_value = maybe_adjoint_b_(k, j); +__global__ void SparseTensorDenseMatMulKernel(int nnz, int m, int b_rows, + int b_cols, int p, + const Tindices* a_indices, + const T* a_values, const T* b, + T* out) { + // out_{ij} = sum_k {a_ik b_kj} + // out = A * B', out_{ij} = sum_k {a_ik (b')_kj}; b'_{kj} = b_{jk} + const int n = (ADJ_B) ? b_cols : b_rows; + CUDA_1D_KERNEL_LOOP(index, nnz * p) { + const int a_ix = index / p; + const int j = index % p; + const int i = ldg(a_indices + 2 * a_ix + ((ADJ_A) ? 1 : 0)); + const int k = ldg(a_indices + 2 * a_ix + ((ADJ_A) ? 0 : 1)); + if (!FastBoundsCheck(i, m)) { + continue; // Nowhere to signal an error :( } - atomicAdd(&out_(m, j), a_values_(ix) * b_value); -#else - assert(false && "This should only be run on the device"); -#endif - // Return something - return T(0); + // out[i, j] + T* out_location = out + i * p + j; + if (!FastBoundsCheck(k, n)) { + CudaAtomicAdd(out_location, std::numeric_limits::quiet_NaN()); + continue; + } + + // a_value == (ADJ_A) ? a[k, i] : a[i, k] + const T a_value = ldg(a_values + a_ix); + + // b_value == (ADJ_B) ? b[j, k] : b[k, j] + const T b_value = ldg(b + ((ADJ_B) ? j * b_cols + k : k * b_cols + j)); + CudaAtomicAdd(out_location, a_value * b_value); } - - private: - mutable typename TTypes::Tensor32Bit out_; - const int lhs_index_a_; - const int rhs_index_a_; - typename TTypes::Tensor32Bit a_indices_; - typename TTypes::Tensor32Bit a_values_; - const int lhs_right_size; - functor::MaybeAdjoint::Tensor32Bit, ADJ_B> - maybe_adjoint_b_; -}; - -} // namespace generator +} namespace functor { @@ -94,51 +68,23 @@ struct SparseTensorDenseMatMulFunctor { Compute(const GPUDevice& d, typename TTypes::Matrix out, typename TTypes::ConstMatrix a_indices, typename TTypes::ConstVec a_values, - typename TTypes::ConstMatrix b, typename TTypes::Vec scratch) { - generator::SparseTensorDenseMatMulGPUGenerator - sparse_tensor_dense_matmul_generator(To32Bit(out), To32Bit(a_indices), - To32Bit(a_values), To32Bit(b)); - To32Bit(out).device(d) = To32Bit(out).constant(T(0)); + typename TTypes::ConstMatrix b) { + out.device(d) = out.constant(T(0)); int nnz = a_values.size(); - int n = (ADJ_B) ? b.dimension(0) : b.dimension(1); + // out = A * B, A is [m x n] and B is [n x p], out is [m x p] + int m = out.dimension(0); + int p = out.dimension(1); + int b_rows = b.dimension(0); + int b_cols = b.dimension(1); -#if !defined(EIGEN_HAS_INDEX_LIST) - Eigen::Tensor::Dimensions matrix_1_by_nnz{{ 1, nnz }}; - Eigen::array n_by_1{{ n, 1 }}; - Eigen::array reduce_on_rows{{ 0 }}; -#else - Eigen::IndexList, int> matrix_1_by_nnz; - matrix_1_by_nnz.set(1, nnz); - Eigen::IndexList > n_by_1; - n_by_1.set(0, n); - Eigen::IndexList > reduce_on_rows; -#endif + // TODO(ebrevdo): Should this be alpha * nnz instead of + // out.size()? Perhaps p * nnz ? + CudaLaunchConfig config = GetCudaLaunchConfig(p * nnz, d); - // How this works: the generator iterates over (j, ix) where j - // iterates from 0 .. n - 1 and ix iterates from - // 0 .. nnz - 1. A side effect of the generator is to accumulate - // the products of values in A and B into the appropriate location - // in the dense matrix out. In order to run the iteration, - // we take a smaller variable and broadcast to a size (n, nnz). - // This is the scratch variable. In order to enforce execution, - // we have to perform assignment back into scratch (taking the sum). - // We don't care what gets assigned to scratch - only the side effect - // of the execution in the generator. - // - // Note it's not sufficient that scratch be a scalar, and to - // broadcast it to a matrix. Eigen splits the computation not - // based on the largest intermediate shape (the size of the - // broadcast of scratch) but based on the output shape. So - // scratch needs to be a vector at least. - // - // Note also that only float type is supported because the - // atomicAdd operation is only supported for floats in hardware. - To32Bit(scratch).device(d) = - To32Bit(scratch) - .reshape(matrix_1_by_nnz) - .broadcast(n_by_1) - .generate(sparse_tensor_dense_matmul_generator) - .sum(reduce_on_rows); + SparseTensorDenseMatMulKernel + <<>>( + nnz, m, b_rows, b_cols, p, a_indices.data(), a_values.data(), + b.data(), out.data()); return Status::OK(); } diff --git a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py index 80991751860..a0bd178e247 100644 --- a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops @@ -161,6 +162,46 @@ class SparseTensorDenseMatMulTest(test.TestCase): sparse_ops.sparse_tensor_dense_matmul( sparse_t, dense_t, adjoint_a=True).eval() + def testInvalidIndicesForSparseTensorDenseMatmulOnGPU(self): + # Note: use_gpu=False because nice errors are only returned from CPU kerne + if not test.is_gpu_available(): + return + with self.test_session(use_gpu=True): + indices = np.array([[1, 10]]).astype(np.int64) + values = np.array([10]).astype(np.float32) + shape = [3, 2] + sparse_t = sparse_tensor.SparseTensor(indices, values, shape) + + # Test multiplying by both a small and large dense matrix, to hit + # both cases in the kernel. + dense_t = np.matrix([[1] * 5, [2] * 5], dtype=np.float32) + expected_t = np.array([[0] * 5, [np.nan] * 5, [0] * 5], dtype=np.float32) + self.assertAllClose(expected_t, + sparse_ops.sparse_tensor_dense_matmul( + sparse_t, dense_t).eval()) + dense_t = np.matrix([[1] * 500, [2] * 500], dtype=np.float32) + expected_t = np.array( + [[0] * 500, [np.nan] * 500, [0] * 500], dtype=np.float32) + self.assertAllClose(expected_t, + sparse_ops.sparse_tensor_dense_matmul( + sparse_t, dense_t).eval()) + + # Repeat with adjoint_a, now the error is that the sparse index + # is OOO w.r.t. the output. The GPU kernel can't do much here, + # so it just doesn't accumulate. + + dense_t = np.matrix([[1] * 5, [2] * 5, [3] * 5], dtype=np.float32) + expected_t = np.array([[0] * 5, [0] * 5], dtype=np.float32) + self.assertAllClose(expected_t, + sparse_ops.sparse_tensor_dense_matmul( + sparse_t, dense_t, adjoint_a=True).eval()) + + dense_t = np.matrix([[1] * 500, [2] * 500, [3] * 500], dtype=np.float32) + expected_t = np.array([[0] * 500, [0] * 500], dtype=np.float32) + self.assertAllClose(expected_t, + sparse_ops.sparse_tensor_dense_matmul( + sparse_t, dense_t, adjoint_a=True).eval()) + # Tests setting one dimension to be a high value. def _testLarge(self, np_dtype): r1 = np.random.randint(6000, 20000) @@ -175,9 +216,12 @@ class SparseTensorDenseMatMulTest(test.TestCase): y = _maybe_complex(np.random.randn(k, n).astype(np_dtype)) - self._testMatmul(x, y) + self._testMatmul(x, y, adjoint_a=False, adjoint_b=False) + self._testMatmul(x.transpose(), y, adjoint_a=True, adjoint_b=False) + self._testMatmul(x, y.transpose(), adjoint_a=False, adjoint_b=True) + self._testMatmul( + x.transpose(), y.transpose(), adjoint_a=True, adjoint_b=True) - def testLarge(self): np.random.seed(127) # Repeatable results self._testLarge(np.float32) self._testLarge(np.float64) @@ -221,7 +265,9 @@ def _sparse_tensor_dense_vs_dense_matmul_benchmark_dense(x, y, adjoint_a, lambda t, _: t < iterations, body, (t0, v0), parallel_iterations=1, - back_prop=False) + back_prop=False, + shape_invariants=(tensor_shape.TensorShape(()), + tensor_shape.TensorShape(None))) return [final] return _timeit @@ -246,7 +292,9 @@ def _sparse_tensor_dense_vs_dense_matmul_benchmark_sparse(x_ind, x_val, x_shape, lambda t, _: t < iterations, body, (t0, v0), parallel_iterations=1, - back_prop=False) + back_prop=False, + shape_invariants=(tensor_shape.TensorShape(()), + tensor_shape.TensorShape(None))) return [final] return _timeit @@ -291,7 +339,7 @@ def sparse_tensor_dense_vs_dense_matmul_benchmark(thresh, if skip_dense: delta_dense = float("nan") else: - with session.Session("", config=config, graph=ops.Graph()) as sess: + with session.Session(config=config, graph=ops.Graph()) as sess: if not use_gpu: with ops.device("/cpu:0"): x_t = constant_op.constant(x) @@ -299,12 +347,12 @@ def sparse_tensor_dense_vs_dense_matmul_benchmark(thresh, ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_dense( x_t, y_t, adjoint_a, adjoint_b) else: - x_t = constant_op.constant(x) - y_t = constant_op.constant(y) - ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_dense(x_t, y_t, - adjoint_a, - adjoint_b) - delta_dense = _timer(sess, ops_fn, 1000) + with ops.device("/gpu:0"): + x_t = constant_op.constant(x) + y_t = constant_op.constant(y) + ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_dense( + x_t, y_t, adjoint_a, adjoint_b) + delta_dense = _timer(sess, ops_fn, 200) # Using sparse_tensor_dense_matmul. with session.Session("", config=config, graph=ops.Graph()) as sess: @@ -317,13 +365,14 @@ def sparse_tensor_dense_vs_dense_matmul_benchmark(thresh, ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_sparse( x_ind, x_val, x_shape, y_t, adjoint_a, adjoint_b) else: - x_ind = constant_op.constant(np.vstack(np.where(x)).astype(np.int64).T) - x_val = constant_op.constant(x[np.where(x)]) - x_shape = constant_op.constant(np.array(x.shape).astype(np.int64)) - y_t = constant_op.constant(y) - ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_sparse( - x_ind, x_val, x_shape, y_t, adjoint_a, adjoint_b) - delta_sparse = _timer(sess, ops_fn, 1000) + with ops.device("/gpu:0"): + x_ind = constant_op.constant(np.vstack(np.where(x)).astype(np.int64).T) + x_val = constant_op.constant(x[np.where(x)]) + x_shape = constant_op.constant(np.array(x.shape).astype(np.int64)) + y_t = constant_op.constant(y) + ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_sparse( + x_ind, x_val, x_shape, y_t, adjoint_a, adjoint_b) + delta_sparse = _timer(sess, ops_fn, 200) print("%g \t %d \t %s \t %d \t %d \t %g \t %g \t %g" % (1 - thresh, n, use_gpu, m, k, delta_dense, delta_sparse, @@ -340,7 +389,7 @@ def main(_): "\t dt(sparse)/dt(dense)") for thresh in (0.99, 0.8, 0.5, 0.2): - for n in (1, 10, 25): + for n in (50, 100): for use_gpu in (True, False): for m in (100, 1000): for k in (100, 1000): From 81fbb1246c3a41dec826fd9958772c02f88982bb Mon Sep 17 00:00:00 2001 From: Suharsh Sivakumar Date: Thu, 4 May 2017 16:14:18 -0800 Subject: [PATCH 38/43] Include c++ gradients in c_api build rule. #6268 #9150 Change: 155146664 --- tensorflow/c/BUILD | 1 + tensorflow/cc/BUILD | 1 + 2 files changed, 2 insertions(+) diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 4ad69ae3fbd..3ab4e8efcdb 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -58,6 +58,7 @@ tf_cuda_library( "//tensorflow/cc/saved_model:loader", "//tensorflow/cc:gradients", "//tensorflow/cc:ops", + "//tensorflow/cc:grad_ops", "//tensorflow/cc:scope_internal", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 8810b8731ae..8d4260a0b9c 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -91,6 +91,7 @@ cc_library( deps = [ ":array_grad", ":math_grad", + ":nn_grad", ], ) From d48f3a9a3fcd91c6b97644f4007d92f17516b80a Mon Sep 17 00:00:00 2001 From: Justine Tunney Date: Thu, 4 May 2017 18:28:35 -0800 Subject: [PATCH 39/43] Automated rollback of change 155136555 Change: 155156366 --- tensorflow/BUILD | 42 ------------------- .../vz_data_summary/BUILD.OPENSOURCE | 34 +++++++++++++++ .../components/vz_projector/BUILD.OPENSOURCE | 19 +++++++++ tensorflow/tensorboard/defs.bzl | 16 ------- 4 files changed, 53 insertions(+), 58 deletions(-) create mode 100644 tensorflow/tensorboard/components/vz_data_summary/BUILD.OPENSOURCE create mode 100644 tensorflow/tensorboard/components/vz_projector/BUILD.OPENSOURCE diff --git a/tensorflow/BUILD b/tensorflow/BUILD index da8190523ab..a2f7a9fb639 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -318,48 +318,6 @@ filegroup( "//tensorflow/tensorboard/backend:all_files", "//tensorflow/tensorboard/backend/event_processing:all_files", "//tensorflow/tensorboard/components:all_files", - "//tensorflow/tensorboard/components/tf_audio_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_audio_dashboard/demo:all_files", - "//tensorflow/tensorboard/components/tf_backend:all_files", - "//tensorflow/tensorboard/components/tf_backend_d3v4:all_files", - "//tensorflow/tensorboard/components/tf_color_scale:all_files", - "//tensorflow/tensorboard/components/tf_color_scale/demo:all_files", - "//tensorflow/tensorboard/components/tf_color_scale_d3v4:all_files", - "//tensorflow/tensorboard/components/tf_dashboard_common:all_files", - "//tensorflow/tensorboard/components/tf_dashboard_common/demo:all_files", - "//tensorflow/tensorboard/components/tf_dashboard_common_d3v4:all_files", - "//tensorflow/tensorboard/components/tf_distribution_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_distribution_dashboard/demo:all_files", - "//tensorflow/tensorboard/components/tf_globals:all_files", - "//tensorflow/tensorboard/components/tf_globals_d3v4:all_files", - "//tensorflow/tensorboard/components/tf_graph_common:all_files", - "//tensorflow/tensorboard/components/tf_histogram_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_histogram_dashboard/demo:all_files", - "//tensorflow/tensorboard/components/tf_image_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_image_dashboard/demo:all_files", - "//tensorflow/tensorboard/components/tf_imports:all_files", - "//tensorflow/tensorboard/components/tf_imports_d3v4:all_files", - "//tensorflow/tensorboard/components/tf_scalar_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_scalar_dashboard/demo:all_files", - "//tensorflow/tensorboard/components/tf_storage:all_files", - "//tensorflow/tensorboard/components/tf_storage_d3v4:all_files", - "//tensorflow/tensorboard/components/tf_text_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_text_dashboard/demo:all_files", - "//tensorflow/tensorboard/components/vz_data_summary:all_files", - "//tensorflow/tensorboard/components/vz_distribution_chart:all_files", - "//tensorflow/tensorboard/components/vz_distribution_chart/demo:all_files", - "//tensorflow/tensorboard/components/vz_distribution_chart_d3v4:all_files", - "//tensorflow/tensorboard/components/vz_histogram_timeseries:all_files", - "//tensorflow/tensorboard/components/vz_histogram_timeseries/demo:all_files", - "//tensorflow/tensorboard/components/vz_histogram_timeseries_d3v4:all_files", - "//tensorflow/tensorboard/components/vz_line_chart:all_files", - "//tensorflow/tensorboard/components/vz_line_chart/demo:all_files", - "//tensorflow/tensorboard/components/vz_line_chart_d3v4:all_files", - "//tensorflow/tensorboard/components/vz_projector:all_files", - "//tensorflow/tensorboard/components/vz_projector_d3v4:all_files", - "//tensorflow/tensorboard/components/vz_sorting:all_files", - "//tensorflow/tensorboard/components/vz_sorting/test:all_files", - "//tensorflow/tensorboard/components/vz_sorting_d3v4:all_files", "//tensorflow/tensorboard/lib:all_files", "//tensorflow/tensorboard/plugins:all_files", "//tensorflow/tensorboard/plugins/projector:all_files", diff --git a/tensorflow/tensorboard/components/vz_data_summary/BUILD.OPENSOURCE b/tensorflow/tensorboard/components/vz_data_summary/BUILD.OPENSOURCE new file mode 100644 index 00000000000..9743d70d947 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_data_summary/BUILD.OPENSOURCE @@ -0,0 +1,34 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +# Description: +# Package for the data-summary vz-element. +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/tensorboard/components/vz_projector/BUILD.OPENSOURCE b/tensorflow/tensorboard/components/vz_projector/BUILD.OPENSOURCE new file mode 100644 index 00000000000..8c222be10e9 --- /dev/null +++ b/tensorflow/tensorboard/components/vz_projector/BUILD.OPENSOURCE @@ -0,0 +1,19 @@ +# Description: +# Package for the Embedding Projector component. +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/tensorboard/defs.bzl b/tensorflow/tensorboard/defs.bzl index 22bb047075f..7bb5f961c97 100644 --- a/tensorflow/tensorboard/defs.bzl +++ b/tensorflow/tensorboard/defs.bzl @@ -60,22 +60,6 @@ def tensorboard_typescript_genrule(name, srcs, typings=[], **kwargs): **kwargs ) -def tensorboard_karma_web_test_suite(**kwargs): - """Rules referencing this will be deleted from the codebase soon.""" - pass - -def tensorboard_ts_declaration(**kwargs): - """Rules referencing this will be deleted from the codebase soon.""" - pass - -def tensorboard_ts_development_sources(**kwargs): - """Rules referencing this will be deleted from the codebase soon.""" - pass - -def tensorboard_ts_devserver(**kwargs): - """Rules referencing this will be deleted from the codebase soon.""" - pass - def tensorboard_ts_library(**kwargs): """Rules referencing this will be deleted from the codebase soon.""" pass From afd69fc26f85782dd6ac44ef1e05ff0d147399a9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 4 May 2017 19:03:38 -0800 Subject: [PATCH 40/43] Add `categorical_column_with_vocabulary_list`. Change: 155158042 --- .../python/feature_column/feature_column.py | 174 +++++++-- .../feature_column/feature_column_test.py | 334 ++++++++++++++++-- 2 files changed, 464 insertions(+), 44 deletions(-) diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index 33bed3abcf1..ffdf8868e21 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -121,6 +121,8 @@ from __future__ import print_function import abc import collections +import numpy as np + from tensorflow.python.feature_column import lookup_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -433,6 +435,12 @@ def bucketized_column(source_column, boundaries): return _BucketizedColumn(source_column, tuple(boundaries)) +def _assert_string_or_int(dtype, prefix): + if (dtype != dtypes.string) and (not dtype.is_integer): + raise ValueError( + '{} dtype must be string or integer. dtype: {}.'.format(prefix, dtype)) + + def categorical_column_with_hash_bucket(key, hash_bucket_size, dtype=dtypes.string): @@ -475,9 +483,7 @@ def categorical_column_with_hash_bucket(key, 'hash_bucket_size: {}, key: {}'.format( hash_bucket_size, key)) - if dtype != dtypes.string and not dtype.is_integer: - raise ValueError('dtype must be string or integer. ' - 'dtype: {}, column_name: {}'.format(dtype, key)) + _assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) return _HashedCategoricalColumn(key, hash_bucket_size, dtype) @@ -485,7 +491,7 @@ def categorical_column_with_hash_bucket(key, def categorical_column_with_vocabulary_file( key, vocabulary_file, vocabulary_size, num_oov_buckets=0, default_value=None, dtype=dtypes.string): - """Creates a `_CategoricalColumn` with vocabulary file configuration. + """A `_CategoricalColumn` with a vocabulary file. Use this when your inputs are in string or integer format, and you have a vocabulary file that maps each value to an integer ID. By default, @@ -504,7 +510,7 @@ def categorical_column_with_vocabulary_file( ID 50-54. ```python states = categorical_column_with_vocabulary_file( - key='keywords', vocabulary_file='/us/states.txt', vocabulary_size=50, + key='states', vocabulary_file='/us/states.txt', vocabulary_size=50, num_oov_buckets=5) linear_prediction = make_linear_model(features, [states, ...]) ``` @@ -516,7 +522,7 @@ def categorical_column_with_vocabulary_file( others are assigned the corresponding line number 1-50. ```python states = categorical_column_with_vocabulary_file( - key='keywords', vocabulary_file='/us/states.txt', vocabulary_size=51, + key='states', vocabulary_file='/us/states.txt', vocabulary_size=51, default_value=0) linear_prediction, _, _ = make_linear_model(features, [states, ...]) @@ -530,7 +536,9 @@ def categorical_column_with_vocabulary_file( column name and the dictionary key for feature parsing configs, feature `Tensor` objects, and feature columns. vocabulary_file: The vocabulary file name. - vocabulary_size: Number of the elements in the vocabulary. + vocabulary_size: Number of the elements in the vocabulary. This must be no + greater than length of `vocabulary_file`, if less than length, later + values are ignored. num_oov_buckets: Non-negative integer, the number of out-of-vocabulary buckets. All out-of-vocabulary inputs will be assigned IDs in the range `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of @@ -542,7 +550,7 @@ def categorical_column_with_vocabulary_file( dtype: The type of features. Only string and integer types are supported. Returns: - A `_CategoricalColumn` with vocabulary file configuration. + A `_CategoricalColumn` with a vocabulary file. Raises: ValueError: `vocabulary_file` is missing. @@ -564,9 +572,8 @@ def categorical_column_with_vocabulary_file( if num_oov_buckets < 0: raise ValueError('Invalid num_oov_buckets {} in {}.'.format( num_oov_buckets, key)) - if dtype != dtypes.string and not dtype.is_integer: - raise ValueError('Invalid dtype {} in {}.'.format(dtype, key)) - return _VocabularyCategoricalColumn( + _assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) + return _VocabularyFileCategoricalColumn( key=key, vocabulary_file=vocabulary_file, vocabulary_size=vocabulary_size, @@ -575,6 +582,80 @@ def categorical_column_with_vocabulary_file( dtype=dtype) +def categorical_column_with_vocabulary_list( + key, vocabulary_list, dtype=None, default_value=-1): + """A `_CategoricalColumn` with in-memory vocabulary. + + Logic for feature f is: + id = f in vocabulary_list ? vocabulary_list.index(f) : default_value + + Use this when your inputs are in string or integer format, and you have an + in-memory vocabulary mapping each value to an integer ID. By default, + out-of-vocabulary values are ignored. Use `default_value` to specify how to + include out-of-vocabulary values. + + Inputs can be either `Tensor` or `SparseTensor`. If `Tensor`, missing values + can be represented by `-1` for int and `''` for string. Note that these values + are independent of the `default_value` argument. + + In the following examples, each input in `vocabulary_list` is assigned an ID + 0-4 corresponding to its index (e.g., input 'B' produces output 2). All other + inputs are assigned `default_value` 0. + + Linear model: + ```python + colors = categorical_column_with_vocabulary_list( + key='colors', vocabulary_list=('X', 'R', 'G', 'B', 'Y'), default_value=0) + linear_prediction, _, _ = make_linear_model(features, [colors, ...]) + ``` + + Embedding for a DNN model: + ```python + dense_tensor = make_input_layer(features, [embedding_column(colors, 3),...]) + ``` + + Args: + key: A unique string identifying the input feature. It is used as the + column name and the dictionary key for feature parsing configs, feature + `Tensor` objects, and feature columns. + vocabulary_list: An ordered iterable defining the vocabulary. Each feature + is mapped to the index of its value (if present) in `vocabulary_list`. + Must be castable to `dtype`. + dtype: The type of features. Only string and integer types are supported. + If `None`, it will be inferred from `vocabulary_list`. + default_value: The value to use for values not in `vocabulary_list`. + + Returns: + A `_CategoricalColumn` with in-memory vocabulary. + + Raises: + ValueError: if `vocabulary_list` is empty, or contains duplicate keys. + ValueError: if `dtype` is not integer or string. + """ + if (vocabulary_list is None) or (len(vocabulary_list) < 1): + raise ValueError( + 'vocabulary_list {} must be non-empty, column_name: {}'.format( + vocabulary_list, key)) + if len(set(vocabulary_list)) != len(vocabulary_list): + raise ValueError( + 'Duplicate keys in vocabulary_list {}, column_name: {}'.format( + vocabulary_list, key)) + vocabulary_dtype = dtypes.as_dtype(np.array(vocabulary_list).dtype) + _assert_string_or_int( + vocabulary_dtype, prefix='column_name: {} vocabulary'.format(key)) + if dtype is None: + dtype = vocabulary_dtype + elif dtype.is_integer != vocabulary_dtype.is_integer: + raise ValueError( + 'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format( + dtype, vocabulary_dtype, key)) + _assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) + + return _VocabularyListCategoricalColumn( + key=key, vocabulary_list=tuple(vocabulary_list), dtype=dtype, + default_value=default_value) + + class _FeatureColumn(object): """Represents a feature column abstraction. @@ -1170,11 +1251,9 @@ class _HashedCategoricalColumn( if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor): raise ValueError('SparseColumn input must be a SparseTensor.') - if (input_tensor.dtype != dtypes.string and - not input_tensor.dtype.is_integer): - raise ValueError('input tensors dtype must be string or integer. ' - 'dtype: {}, column_name: {}'.format( - input_tensor.dtype, self.key)) + _assert_string_or_int( + input_tensor.dtype, + prefix='column_name: {} input_tensor'.format(self.key)) if self.dtype.is_integer != input_tensor.dtype.is_integer: raise ValueError( @@ -1202,8 +1281,9 @@ class _HashedCategoricalColumn( return _CategoricalColumn.IdWeightPair(inputs.get(self), None) -class _VocabularyCategoricalColumn( - _CategoricalColumn, collections.namedtuple('_VocabularyCategoricalColumn', ( +class _VocabularyFileCategoricalColumn( + _CategoricalColumn, + collections.namedtuple('_VocabularyFileCategoricalColumn', ( 'key', 'vocabulary_file', 'vocabulary_size', 'num_oov_buckets', 'dtype', 'default_value' ))): @@ -1226,15 +1306,15 @@ class _VocabularyCategoricalColumn( 'key: {}, column dtype: {}, tensor dtype: {}'.format( self.key, self.dtype, input_tensor.dtype)) + _assert_string_or_int( + input_tensor.dtype, + prefix='column_name: {} input_tensor'.format(self.key)) + key_dtype = self.dtype if input_tensor.dtype.is_integer: # `index_table_from_file` requires 64-bit integer keys. key_dtype = dtypes.int64 input_tensor = math_ops.to_int64(input_tensor) - elif input_tensor.dtype != dtypes.string: - raise ValueError('input tensors dtype must be string or integer. ' - 'dtype: {}, column_name: {}'.format( - input_tensor.dtype, self.key)) return lookup_ops.index_table_from_file( vocabulary_file=self.vocabulary_file, @@ -1254,6 +1334,56 @@ class _VocabularyCategoricalColumn( return _CategoricalColumn.IdWeightPair(inputs.get(self), None) +class _VocabularyListCategoricalColumn( + _CategoricalColumn, + collections.namedtuple('_VocabularyListCategoricalColumn', ( + 'key', 'vocabulary_list', 'dtype', 'default_value' + ))): + """See `categorical_column_with_vocabulary_list`.""" + + @property + def name(self): + return self.key + + @property + def _parse_example_config(self): + return {self.key: parsing_ops.VarLenFeature(self.dtype)} + + def _transform_feature(self, inputs): + input_tensor = _to_sparse_input(inputs.get(self.key)) + + if self.dtype.is_integer != input_tensor.dtype.is_integer: + raise ValueError( + 'Column dtype and SparseTensors dtype must be compatible. ' + 'key: {}, column dtype: {}, tensor dtype: {}'.format( + self.key, self.dtype, input_tensor.dtype)) + + _assert_string_or_int( + input_tensor.dtype, + prefix='column_name: {} input_tensor'.format(self.key)) + + key_dtype = self.dtype + if input_tensor.dtype.is_integer: + # `index_table_from_tensor` requires 64-bit integer keys. + key_dtype = dtypes.int64 + input_tensor = math_ops.to_int64(input_tensor) + + return lookup_ops.index_table_from_tensor( + mapping=tuple(self.vocabulary_list), + default_value=self.default_value, + dtype=key_dtype, + name='{}_lookup'.format(self.key)).lookup(input_tensor) + + @property + def _num_buckets(self): + """Returns number of buckets in this sparse feature.""" + return len(self.vocabulary_list) + + def _get_sparse_tensors( + self, inputs, weight_collections=None, trainable=None): + return _CategoricalColumn.IdWeightPair(inputs.get(self), None) + + # TODO(zakaria): Move this to embedding_ops and make it public. def _safe_embedding_lookup_sparse(embedding_weights, sparse_ids, diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index ad67a082dc9..59aa39411f5 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -1193,10 +1193,22 @@ class MakeInputLayerTest(test.TestCase): self.assertAllClose([[1., 3.]], net2.eval()) -class VocabularyCategoricalColumnTest(test.TestCase): +def _assert_sparse_tensor_value(test_case, expected, actual): + test_case.assertEqual(np.int64, np.array(actual.indices).dtype) + test_case.assertAllEqual(expected.indices, actual.indices) + + test_case.assertEqual( + np.array(expected.values).dtype, np.array(actual.values).dtype) + test_case.assertAllEqual(expected.values, actual.values) + + test_case.assertEqual(np.int64, np.array(actual.dense_shape).dtype) + test_case.assertAllEqual(expected.dense_shape, actual.dense_shape) + + +class VocabularyFileCategoricalColumnTest(test.TestCase): def setUp(self): - super(VocabularyCategoricalColumnTest, self).setUp() + super(VocabularyFileCategoricalColumnTest, self).setUp() # Contains ints, Golden State Warriors jersey numbers: 30, 35, 11, 23, 22 self._warriors_vocabulary_file_name = test.test_src_dir_path( @@ -1208,17 +1220,6 @@ class VocabularyCategoricalColumnTest(test.TestCase): 'python/feature_column/testdata/wire_vocabulary.txt') self._wire_vocabulary_size = 3 - def _assert_sparse_tensor_value(self, expected, actual): - self.assertEqual(np.int64, np.array(actual.indices).dtype) - self.assertAllEqual(expected.indices, actual.indices) - - self.assertEqual( - np.array(expected.values).dtype, np.array(actual.values).dtype) - self.assertAllEqual(expected.values, actual.values) - - self.assertEqual(np.int64, np.array(actual.dense_shape).dtype) - self.assertAllEqual(expected.dense_shape, actual.dense_shape) - def test_defaults(self): column = fc.categorical_column_with_vocabulary_file( key='aaa', vocabulary_file='path_to_file', vocabulary_size=3) @@ -1316,7 +1317,7 @@ class VocabularyCategoricalColumnTest(test.TestCase): num_oov_buckets=-1) def test_invalid_dtype(self): - with self.assertRaisesRegexp(ValueError, 'Invalid dtype'): + with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'): fc.categorical_column_with_vocabulary_file( key='aaa', vocabulary_file='path', vocabulary_size=3, dtype=dtypes.float64) @@ -1331,6 +1332,36 @@ class VocabularyCategoricalColumnTest(test.TestCase): num_oov_buckets=100, default_value=2) + def test_invalid_input_dtype_int32(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size, + dtype=dtypes.string) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(12, 24, 36), + dense_shape=(2, 2)) + with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'): + # pylint: disable=protected-access + column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + + def test_invalid_input_dtype_string(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._warriors_vocabulary_file_name, + vocabulary_size=self._warriors_vocabulary_size, + dtype=dtypes.int32) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('omar', 'stringer', 'marlo'), + dense_shape=(2, 2)) + with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'): + # pylint: disable=protected-access + column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + def test_get_sparse_tensors(self): column = fc.categorical_column_with_vocabulary_file( key='aaa', @@ -1346,7 +1377,8 @@ class VocabularyCategoricalColumnTest(test.TestCase): # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): - self._assert_sparse_tensor_value( + _assert_sparse_tensor_value( + self, sparse_tensor.SparseTensorValue( indices=inputs.indices, values=np.array((2, -1, 0), dtype=np.int64), @@ -1365,7 +1397,8 @@ class VocabularyCategoricalColumnTest(test.TestCase): # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): - self._assert_sparse_tensor_value( + _assert_sparse_tensor_value( + self, sparse_tensor.SparseTensorValue( indices=((0, 0), (1, 0), (1, 1)), values=np.array((2, -1, 0), dtype=np.int64), @@ -1388,7 +1421,8 @@ class VocabularyCategoricalColumnTest(test.TestCase): # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): - self._assert_sparse_tensor_value( + _assert_sparse_tensor_value( + self, sparse_tensor.SparseTensorValue( indices=inputs.indices, values=np.array((2, 2, 0), dtype=np.int64), @@ -1411,7 +1445,8 @@ class VocabularyCategoricalColumnTest(test.TestCase): # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): - self._assert_sparse_tensor_value( + _assert_sparse_tensor_value( + self, sparse_tensor.SparseTensorValue( indices=inputs.indices, values=np.array((2, 33, 0, 62), dtype=np.int64), @@ -1436,7 +1471,8 @@ class VocabularyCategoricalColumnTest(test.TestCase): # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): - self._assert_sparse_tensor_value( + _assert_sparse_tensor_value( + self, sparse_tensor.SparseTensorValue( indices=inputs.indices, values=np.array((-1, -1, 0), dtype=np.int64), @@ -1459,7 +1495,8 @@ class VocabularyCategoricalColumnTest(test.TestCase): # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): - self._assert_sparse_tensor_value( + _assert_sparse_tensor_value( + self, sparse_tensor.SparseTensorValue( indices=inputs.indices, values=np.array((2, -1, 0, 4), dtype=np.int64), @@ -1481,7 +1518,8 @@ class VocabularyCategoricalColumnTest(test.TestCase): # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): - self._assert_sparse_tensor_value( + _assert_sparse_tensor_value( + self, sparse_tensor.SparseTensorValue( indices=((0, 0), (1, 0), (1, 1), (2, 2)), values=np.array((2, default_value, 0, 4), dtype=np.int64), @@ -1505,7 +1543,8 @@ class VocabularyCategoricalColumnTest(test.TestCase): # pylint: enable=protected-access self.assertIsNone(id_weight_pair.weight_tensor) with _initialized_session(): - self._assert_sparse_tensor_value( + _assert_sparse_tensor_value( + self, sparse_tensor.SparseTensorValue( indices=inputs.indices, values=np.array((2, 60, 0, 4), dtype=np.int64), @@ -1538,5 +1577,256 @@ class VocabularyCategoricalColumnTest(test.TestCase): self.assertAllClose(((3.,), (5.,)), predictions.eval()) +class VocabularyListCategoricalColumnTest(test.TestCase): + + def test_defaults_string(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=('omar', 'stringer', 'marlo')) + self.assertEqual('aaa', column.name) + # pylint: disable=protected-access + self.assertEqual(3, column._num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.string) + }, column._parse_example_config) + # pylint: enable=protected-access + + def test_defaults_int(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=(12, 24, 36)) + self.assertEqual('aaa', column.name) + # pylint: disable=protected-access + self.assertEqual(3, column._num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, column._parse_example_config) + # pylint: enable=protected-access + + def test_all_constructor_args(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32, + default_value=-99) + # pylint: disable=protected-access + self.assertEqual(3, column._num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int32) + }, column._parse_example_config) + # pylint: enable=protected-access + + def test_deep_copy(self): + """Tests deepcopy of categorical_column_with_hash_bucket.""" + original = fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32) + for column in (original, copy.deepcopy(original)): + self.assertEqual('aaa', column.name) + # pylint: disable=protected-access + self.assertEqual(3, column._num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int32) + }, column._parse_example_config) + # pylint: enable=protected-access + + def test_invalid_dtype(self): + with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'), + dtype=dtypes.float32) + + def test_invalid_mapping_dtype(self): + with self.assertRaisesRegexp( + ValueError, r'vocabulary dtype must be string or integer'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=(12., 24., 36.)) + + def test_mismatched_int_dtype(self): + with self.assertRaisesRegexp( + ValueError, r'dtype.*and vocabulary dtype.*do not match'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'), + dtype=dtypes.int32) + + def test_mismatched_string_dtype(self): + with self.assertRaisesRegexp( + ValueError, r'dtype.*and vocabulary dtype.*do not match'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.string) + + def test_none_mapping(self): + with self.assertRaisesRegexp( + ValueError, r'vocabulary_list.*must be non-empty'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=None) + + def test_empty_mapping(self): + with self.assertRaisesRegexp( + ValueError, r'vocabulary_list.*must be non-empty'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=tuple([])) + + def test_duplicate_mapping(self): + with self.assertRaisesRegexp(ValueError, 'Duplicate keys'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=(12, 24, 12)) + + def test_invalid_input_dtype_int32(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo')) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(12, 24, 36), + dense_shape=(2, 2)) + with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'): + # pylint: disable=protected-access + column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + + def test_invalid_input_dtype_string(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=(12, 24, 36)) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('omar', 'stringer', 'marlo'), + dense_shape=(2, 2)) + with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'): + # pylint: disable=protected-access + column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + + def test_get_sparse_tensors(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo')) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors( + fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, -1, 0), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_dense_input(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo')) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({ + 'aaa': (('marlo', ''), ('skywalker', 'omar')) + })) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=np.array((2, -1, 0), dtype=np.int64), + dense_shape=(2, 2)), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_default_value_in_vocabulary(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo'), + default_value=2) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors( + fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, 2, 0), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_int32(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=np.array((30, 35, 11, 23, 22), dtype=np.int32), + dtype=dtypes.int32) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1), (2, 2)), + values=np.array((11, 100, 30, 22), dtype=np.int32), + dense_shape=(3, 3)) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors( + fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, -1, 0, 4), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_int32_dense_input(self): + default_value = -100 + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=np.array((30, 35, 11, 23, 22), dtype=np.int32), + dtype=dtypes.int32, + default_value=default_value) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({ + 'aaa': np.array( + ((11, -1, -1), (100, 30, -1), (-1, -1, 22)), + dtype=np.int32) + })) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1), (2, 2)), + values=np.array((2, default_value, 0, 4), dtype=np.int64), + dense_shape=(3, 3)), + id_weight_pair.id_tensor.eval()) + + def test_make_linear_model(self): + wire_column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo')) + self.assertEqual(3, wire_column._num_buckets) + with ops.Graph().as_default(): + predictions = fc.make_linear_model({ + wire_column.name: sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + }, (wire_column,)) + bias = get_linear_model_bias() + wire_var = get_linear_model_column_var(wire_column) + with _initialized_session(): + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,)), wire_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + wire_var.assign(((1.,), (2.,), (3.,))).eval() + # 'marlo' -> 2: wire_var[2] = 3 + # 'skywalker' -> None, 'omar' -> 0: wire_var[0] = 1 + self.assertAllClose(((3.,), (1.,)), predictions.eval()) + + if __name__ == '__main__': test.main() From 44cf98028b635ff3dd4145df263b0706ba663924 Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Thu, 4 May 2017 19:11:23 -0800 Subject: [PATCH 41/43] RNN checkpoint migration tool Change: 155158477 --- tensorflow/contrib/rnn/BUILD | 25 ++ .../rnn/python/tools/checkpoint_convert.py | 231 ++++++++++++++++++ .../python/tools/checkpoint_convert_test.py | 108 ++++++++ .../tools/pip_package/pip_smoke_test.py | 2 +- 4 files changed, 365 insertions(+), 1 deletion(-) create mode 100644 tensorflow/contrib/rnn/python/tools/checkpoint_convert.py create mode 100644 tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index ab443eab6f6..9d67563eddd 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -304,6 +304,7 @@ filegroup( exclude = [ "**/METADATA", "**/OWNERS", + "tools/**", ], ), visibility = ["//tensorflow:__subpackages__"], @@ -351,3 +352,27 @@ tf_kernel_library( "//third_party/eigen3", ], ) + +py_binary( + name = "checkpoint_convert", + srcs = ["python/tools/checkpoint_convert.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:platform", + "//tensorflow/python:training", + "//tensorflow/python:variables", + ], +) + +py_test( + name = "checkpoint_convert_test", + size = "small", + srcs = ["python/tools/checkpoint_convert_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":checkpoint_convert", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py b/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py new file mode 100644 index 00000000000..1e29114b0cc --- /dev/null +++ b/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py @@ -0,0 +1,231 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Convert checkpoints using RNNCells to new name convention. + +Usage: + + python checkpoint_convert [--write_v1_checkpoint] \ + '/path/to/checkpoint' '/path/to/new_checkpoint' +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import collections +import re +import sys + +from tensorflow.core.protobuf import saver_pb2 +from tensorflow.python import pywrap_tensorflow +from tensorflow.python.client import session +from tensorflow.python.framework import ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import app +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import saver as saver_lib + +_RNN_NAME_REPLACEMENTS = collections.OrderedDict([ + ############################################################################ + # contrib/rnn/python/ops/core_rnn_cell_impl.py + # BasicRNNCell + ('basic_rnn_cell/weights', 'basic_rnn_cell/kernel'), + ('basic_rnn_cell/biases', 'basic_rnn_cell/bias'), + # GRUCell + ('gru_cell/weights', 'gru_cell/kernel'), + ('gru_cell/biases', 'gru_cell/bias'), + ('gru_cell/gates/weights', 'gru_cell/gates/kernel'), + ('gru_cell/gates/biases', 'gru_cell/gates/bias'), + ('gru_cell/candidate/weights', 'gru_cell/candidate/kernel'), + ('gru_cell/candidate/biases', 'gru_cell/candidate/bias'), + # BasicLSTMCell + ('basic_lstm_cell/weights', 'basic_lstm_cell/kernel'), + ('basic_lstm_cell/biases', 'basic_lstm_cell/bias'), + # LSTMCell + ('lstm_cell/weights', 'lstm_cell/kernel'), + ('lstm_cell/biases', 'lstm_cell/bias'), + ('lstm_cell/projection/weights', 'lstm_cell/projection/kernel'), + ('lstm_cell/projection/biases', 'lstm_cell/projection/bias'), + # OutputProjectionWrapper + ('output_projection_wrapper/weights', 'output_projection_wrapper/kernel'), + ('output_projection_wrapper/biases', 'output_projection_wrapper/bias'), + # InputProjectionWrapper + ('input_projection_wrapper/weights', 'input_projection_wrapper/kernel'), + ('input_projection_wrapper/biases', 'input_projection_wrapper/bias'), + ############################################################################ + # contrib/rnn/python/ops/lstm_ops.py + # LSTMBlockFusedCell ?? + ('lstm_block_wrapper/weights', 'lstm_block_wrapper/kernel'), + ('lstm_block_wrapper/biases', 'lstm_block_wrapper/bias'), + ############################################################################ + # contrib/rnn/python/ops/rnn_cell.py + # LayerNormBasicLSTMCell + ('layer_norm_basic_lstm_cell/weights', 'layer_norm_basic_lstm_cell/kernel'), + ('layer_norm_basic_lstm_cell/biases', 'layer_norm_basic_lstm_cell/bias'), + # UGRNNCell, not found in g3, but still need it? + ('ugrnn_cell/weights', 'ugrnn_cell/kernel'), + ('ugrnn_cell/biases', 'ugrnn_cell/bias'), + # NASCell + ('nas_rnn/weights', 'nas_rnn/kernel'), + ('nas_rnn/recurrent_weights', 'nas_rnn/recurrent_kernel'), + # IntersectionRNNCell + ('intersection_rnn_cell/weights', 'intersection_rnn_cell/kernel'), + ('intersection_rnn_cell/biases', 'intersection_rnn_cell/bias'), + ('intersection_rnn_cell/in_projection/weights', + 'intersection_rnn_cell/in_projection/kernel'), + ('intersection_rnn_cell/in_projection/biases', + 'intersection_rnn_cell/in_projection/bias'), + # PhasedLSTMCell + ('phased_lstm_cell/mask_gates/weights', + 'phased_lstm_cell/mask_gates/kernel'), + ('phased_lstm_cell/mask_gates/biases', 'phased_lstm_cell/mask_gates/bias'), + ('phased_lstm_cell/new_input/weights', 'phased_lstm_cell/new_input/kernel'), + ('phased_lstm_cell/new_input/biases', 'phased_lstm_cell/new_input/bias'), + ('phased_lstm_cell/output_gate/weights', + 'phased_lstm_cell/output_gate/kernel'), + ('phased_lstm_cell/output_gate/biases', + 'phased_lstm_cell/output_gate/bias'), + # AttentionCellWrapper + ('attention_cell_wrapper/weights', 'attention_cell_wrapper/kernel'), + ('attention_cell_wrapper/biases', 'attention_cell_wrapper/bias'), + ('attention_cell_wrapper/attn_output_projection/weights', + 'attention_cell_wrapper/attn_output_projection/kernel'), + ('attention_cell_wrapper/attn_output_projection/biases', + 'attention_cell_wrapper/attn_output_projection/bias'), + ('attention_cell_wrapper/attention/weights', + 'attention_cell_wrapper/attention/kernel'), + ('attention_cell_wrapper/attention/biases', + 'attention_cell_wrapper/attention/bias'), +]) + +_RNN_SHARDED_NAME_REPLACEMENTS = collections.OrderedDict([ + ('LSTMCell/W_', 'lstm_cell/weights/part_'), + ('BasicLSTMCell/Linear/Matrix_', 'basic_lstm_cell/weights/part_'), + ('GRUCell/W_', 'gru_cell/weights/part_'), + ('MultiRNNCell/Cell', 'multi_rnn_cell/cell_'), +]) + + +def _rnn_name_replacement(var_name): + for pattern in _RNN_NAME_REPLACEMENTS: + if pattern in var_name: + old_var_name = var_name + var_name = var_name.replace(pattern, _RNN_NAME_REPLACEMENTS[pattern]) + logging.info('Converted: %s --> %s' % (old_var_name, var_name)) + break + return var_name + + +def _rnn_name_replacement_sharded(var_name): + for pattern in _RNN_SHARDED_NAME_REPLACEMENTS: + if pattern in var_name: + old_var_name = var_name + var_name = var_name.replace(pattern, + _RNN_SHARDED_NAME_REPLACEMENTS[pattern]) + logging.info('Converted: %s --> %s' % (old_var_name, var_name)) + return var_name + + +def _split_sharded_vars(name_shape_map): + """Split shareded variables. + + Args: + name_shape_map: A dict from variable name to variable shape. + + Returns: + not_sharded: Names of the non-sharded variables. + sharded: Names of the sharded varibales. + """ + sharded = [] + not_sharded = [] + for name in name_shape_map: + if re.match(name, '_[0-9]+$'): + if re.sub('_[0-9]+$', '_1', name) in name_shape_map: + sharded.append(name) + else: + not_sharded.append(name) + else: + not_sharded.append(name) + return not_sharded, sharded + + +def convert_names(checkpoint_from_path, + checkpoint_to_path, + write_v1_checkpoint=False): + """Migrates the names of variables within a checkpoint. + + Args: + checkpoint_from_path: Path to source checkpoint to be read in. + checkpoint_to_path: Path to checkpoint to be written out. + write_v1_checkpoint: Whether the output checkpoint will be in V1 format. + + Returns: + A dictionary that maps the new variable names to the Variable objects. + A dictionary that maps the old variable names to the new variable names. + """ + with ops.Graph().as_default(): + logging.info('Reading checkpoint_from_path %s' % checkpoint_from_path) + reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_from_path) + name_shape_map = reader.get_variable_to_shape_map() + not_sharded, sharded = _split_sharded_vars(name_shape_map) + new_variable_map = {} + conversion_map = {} + for var_name in not_sharded: + new_var_name = _rnn_name_replacement(var_name) + tensor = reader.get_tensor(var_name) + var = variables.Variable(tensor, name=var_name) + new_variable_map[new_var_name] = var + if new_var_name != var_name: + conversion_map[var_name] = new_var_name + for var_name in sharded: + new_var_name = _rnn_name_replacement_sharded(var_name) + var = variables.Variable(tensor, name=var_name) + new_variable_map[new_var_name] = var + if new_var_name != var_name: + conversion_map[var_name] = new_var_name + + write_version = (saver_pb2.SaverDef.V1 + if write_v1_checkpoint else saver_pb2.SaverDef.V2) + saver = saver_lib.Saver(new_variable_map, write_version=write_version) + + with session.Session() as sess: + sess.run(variables.global_variables_initializer()) + logging.info('Writing checkpoint_to_path %s' % checkpoint_to_path) + saver.save(sess, checkpoint_to_path) + + logging.info('Summary:') + logging.info(' Converted %d variable name(s).' % len(new_variable_map)) + return new_variable_map, conversion_map + + +def main(_): + convert_names( + FLAGS.checkpoint_from_path, + FLAGS.checkpoint_to_path, + write_v1_checkpoint=FLAGS.write_v1_checkpoint) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.register('type', 'bool', lambda v: v.lower() == 'true') + parser.add_argument('checkpoint_from_path', type=str, + help='Path to source checkpoint to be read in.') + parser.add_argument('checkpoint_to_path', type=str, + help='Path to checkpoint to be written out.') + parser.add_argument('--write_v1_checkpoint', action='store_true', + help='Write v1 checkpoint') + FLAGS, unparsed = parser.parse_known_args() + + app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py b/tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py new file mode 100644 index 00000000000..e2fc2fa80ea --- /dev/null +++ b/tensorflow/contrib/rnn/python/tools/checkpoint_convert_test.py @@ -0,0 +1,108 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit tests for checkpoint converter.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import glob +import os +import tempfile + +from tensorflow.contrib.rnn.python.tools import checkpoint_convert +from tensorflow.python.client import session +from tensorflow.python.framework import ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib + + +class CheckpointConvertTest(test.TestCase): + + def setUp(self): + self._old_ckpt_path = tempfile.mktemp() + self._new_ckpt_path = tempfile.mktemp() + ops.reset_default_graph() + + def tearDown(self): + for file_name in glob.glob(self._old_ckpt_path + "*"): + os.remove(file_name) + for file_name in glob.glob(self._new_ckpt_path + "*"): + os.remove(file_name) + + def testReplacementDictsContainUniqueAndNonEmptyVariableNames(self): + for old_name in checkpoint_convert._RNN_NAME_REPLACEMENTS: + new_name = checkpoint_convert._RNN_NAME_REPLACEMENTS[old_name] + self.assertTrue(old_name) + self.assertTrue(new_name) + self.assertNotEqual(old_name, new_name) + for old_name in checkpoint_convert._RNN_SHARDED_NAME_REPLACEMENTS: + new_name = checkpoint_convert._RNN_SHARDED_NAME_REPLACEMENTS[old_name] + self.assertTrue(old_name) + self.assertTrue(new_name) + self.assertNotEqual(old_name, new_name) + + def testConversionFromV2WithConvertedVariableNamesSucceeds(self): + variables.Variable(10.0, name="a") + for old_name in checkpoint_convert._RNN_NAME_REPLACEMENTS: + variables.Variable(20.0, name=old_name) + with session.Session() as sess: + saver = saver_lib.Saver() + sess.run(variables.global_variables_initializer()) + saver.save(sess, self._old_ckpt_path) + + new_var_map, conversion_map = checkpoint_convert.convert_names( + self._old_ckpt_path, self._new_ckpt_path) + self.assertTrue(glob.glob(self._new_ckpt_path + "*")) + self.assertItemsEqual( + ["a"] + list(checkpoint_convert._RNN_NAME_REPLACEMENTS.values()), + new_var_map.keys()) + self.assertEqual(checkpoint_convert._RNN_NAME_REPLACEMENTS, conversion_map) + + def testConversionFromV2WithoutConvertedVariableNamesSucceeds(self): + variables.Variable(10.0, name="a") + with session.Session() as sess: + saver = saver_lib.Saver() + sess.run(variables.global_variables_initializer()) + saver.save(sess, self._old_ckpt_path) + + new_var_map, conversion_map = checkpoint_convert.convert_names( + self._old_ckpt_path, self._new_ckpt_path) + self.assertItemsEqual(["a"], new_var_map.keys()) + self.assertFalse(conversion_map) + + def testConversionToV1Succeeds(self): + variables.Variable(10.0, name="a") + variables.Variable( + 20.0, name=list(checkpoint_convert._RNN_NAME_REPLACEMENTS.keys())[-1]) + + with session.Session() as sess: + saver = saver_lib.Saver() + sess.run(variables.global_variables_initializer()) + saver.save(sess, self._old_ckpt_path) + + new_var_map, conversion_map = checkpoint_convert.convert_names( + self._old_ckpt_path, self._new_ckpt_path, write_v1_checkpoint=True) + self.assertItemsEqual( + ["a", list(checkpoint_convert._RNN_NAME_REPLACEMENTS.values())[-1]], + new_var_map.keys()) + self.assertEqual( + {list(checkpoint_convert._RNN_NAME_REPLACEMENTS.keys())[-1]: + list(checkpoint_convert._RNN_NAME_REPLACEMENTS.values())[-1]}, + conversion_map) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py index 4bb5c1b73c6..459d6ee3284 100644 --- a/tensorflow/tools/pip_package/pip_smoke_test.py +++ b/tensorflow/tools/pip_package/pip_smoke_test.py @@ -57,7 +57,7 @@ BLACKLIST = [ "//tensorflow/contrib/factorization/examples:mnist.py", "//tensorflow/contrib/factorization:factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO", # pylint:disable=line-too-long "//tensorflow/contrib/bayesflow:reinforce_simple_example", - "//tensorflow/contrib/bayesflow:examples/reinforce_simple/reinforce_simple_example.py" # pylint:disable=line-too-long + "//tensorflow/contrib/bayesflow:examples/reinforce_simple/reinforce_simple_example.py", # pylint:disable=line-too-long ] From efa08d80a53a95ce6b8beb61ec86a275aed6b6c7 Mon Sep 17 00:00:00 2001 From: Andrew Harp Date: Thu, 4 May 2017 19:30:36 -0800 Subject: [PATCH 42/43] Android demo: Allow DetectorActivity to gracefully degrade if no native ObjectTracker support is found. If libtensorflow_demo.so is not found in the APK, rendered boxes will simply be stationary and will be replaced whenever new results come in. Partially addresses #6385 Change: 155159326 --- .../org/tensorflow/demo/DetectorActivity.java | 2 +- .../demo/tracking/MultiBoxTracker.java | 54 +++++++++++++++---- .../demo/tracking/ObjectTracker.java | 26 ++++++--- 3 files changed, 64 insertions(+), 18 deletions(-) diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java index cdb6c3fed80..5800f80651b 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java @@ -124,7 +124,7 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable borderedText = new BorderedText(textSizePx); borderedText.setTypeface(Typeface.MONOSPACE); - tracker = new MultiBoxTracker(getResources().getDisplayMetrics()); + tracker = new MultiBoxTracker(this); if (USE_YOLO) { detector = diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java index 49c91d600da..91d1f9feb18 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java @@ -15,6 +15,7 @@ limitations under the License. package org.tensorflow.demo.tracking; +import android.content.Context; import android.graphics.Canvas; import android.graphics.Color; import android.graphics.Matrix; @@ -24,9 +25,9 @@ import android.graphics.Paint.Join; import android.graphics.Paint.Style; import android.graphics.RectF; import android.text.TextUtils; -import android.util.DisplayMetrics; import android.util.Pair; import android.util.TypedValue; +import android.widget.Toast; import java.util.LinkedList; import java.util.List; import java.util.Queue; @@ -69,6 +70,7 @@ public class MultiBoxTracker { private static class TrackedRecognition { ObjectTracker.TrackedObject trackedObject; + RectF location; float detectionConfidence; int color; String title; @@ -87,8 +89,10 @@ public class MultiBoxTracker { private int frameHeight; private int sensorOrientation; + private Context context; - public MultiBoxTracker(final DisplayMetrics metrics) { + public MultiBoxTracker(final Context context) { + this.context = context; for (final int color : COLORS) { availableColors.add(color); } @@ -100,7 +104,9 @@ public class MultiBoxTracker { boxPaint.setStrokeJoin(Join.ROUND); boxPaint.setStrokeMiter(100); - textSizePx = TypedValue.applyDimension(TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, metrics); + textSizePx = + TypedValue.applyDimension( + TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, context.getResources().getDisplayMetrics()); borderedText = new BorderedText(textSizePx); } @@ -152,10 +158,6 @@ public class MultiBoxTracker { } public synchronized void draw(final Canvas canvas) { - if (objectTracker == null) { - return; - } - // TODO(andrewharp): This may not work for non-90 deg rotations. final float multiplier = Math.min(canvas.getWidth() / (float) frameHeight, canvas.getHeight() / (float) frameWidth); @@ -168,9 +170,11 @@ public class MultiBoxTracker { sensorOrientation, false); for (final TrackedRecognition recognition : trackedObjects) { - final ObjectTracker.TrackedObject trackedObject = recognition.trackedObject; + final RectF trackedPos = + (objectTracker != null) + ? recognition.trackedObject.getTrackedPositionInPreviewFrame() + : new RectF(recognition.location); - final RectF trackedPos = trackedObject.getTrackedPositionInPreviewFrame(); getFrameToCanvasMatrix().mapRect(trackedPos); boxPaint.setColor(recognition.color); @@ -185,6 +189,8 @@ public class MultiBoxTracker { } } + private boolean initialized = false; + public synchronized void onFrame( final int w, final int h, @@ -192,7 +198,7 @@ public class MultiBoxTracker { final int sensorOrienation, final byte[] frame, final long timestamp) { - if (objectTracker == null) { + if (objectTracker == null && !initialized) { ObjectTracker.clearInstance(); logger.i("Initializing ObjectTracker: %dx%d", w, h); @@ -200,6 +206,19 @@ public class MultiBoxTracker { frameWidth = w; frameHeight = h; this.sensorOrientation = sensorOrienation; + initialized = true; + + if (objectTracker == null) { + String message = + "Object tracking support not found. " + + "See tensorflow/examples/android/README.md for details."; + Toast.makeText(context, message, Toast.LENGTH_LONG).show(); + logger.e(message); + } + } + + if (objectTracker == null) { + return; } objectTracker.nextFrame(frame, null, timestamp, null, true); @@ -255,7 +274,20 @@ public class MultiBoxTracker { } if (objectTracker == null) { - logger.w("No ObjectTracker, can't track anything!"); + trackedObjects.clear(); + for (final Pair potential : rectsToTrack) { + final TrackedRecognition trackedRecognition = new TrackedRecognition(); + trackedRecognition.detectionConfidence = potential.first; + trackedRecognition.location = new RectF(potential.second.getLocation()); + trackedRecognition.trackedObject = null; + trackedRecognition.title = potential.second.getTitle(); + trackedRecognition.color = COLORS[trackedObjects.size()]; + trackedObjects.add(trackedRecognition); + + if (trackedObjects.size() >= COLORS.length) { + break; + } + } return; } diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java index 82de634baff..69f202b5681 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java @@ -48,7 +48,18 @@ import org.tensorflow.demo.env.Size; * ObjectTracker still exists. */ public class ObjectTracker { - private final Logger logger = new Logger(); + private static final Logger LOGGER = new Logger(); + + private static boolean libraryFound = false; + + static { + try { + System.loadLibrary("tensorflow_demo"); + libraryFound = true; + } catch (UnsatisfiedLinkError e) { + LOGGER.e("libtensorflow_demo.so not found, tracking unavailable"); + } + } private static final boolean DRAW_TEXT = false; @@ -194,6 +205,13 @@ public class ObjectTracker { public static synchronized ObjectTracker getInstance( final int frameWidth, final int frameHeight, final int rowStride, final boolean alwaysTrack) { + if (!libraryFound) { + LOGGER.e( + "Native object tracking support not found. " + + "See tensorflow/examples/android/README.md for details."); + return null; + } + if (instance == null) { instance = new ObjectTracker(frameWidth, frameHeight, rowStride, alwaysTrack); instance.init(); @@ -519,7 +537,7 @@ public class ObjectTracker { checkValidObject(); synchronized (ObjectTracker.this) { if (lastExternalPositionTime > timestamp) { - logger.w("Tried to use older position time!"); + LOGGER.w("Tried to use older position time!"); return; } final RectF externalPosition = downscaleRect(position); @@ -640,8 +658,4 @@ public class ObjectTracker { protected static native void downsampleImageNative( int width, int height, int rowStride, byte[] input, int factor, byte[] output); - - static { - System.loadLibrary("tensorflow_demo"); - } } From f28935a7d280b6ba75fe93fe35783d87b9cc2ec9 Mon Sep 17 00:00:00 2001 From: Brennan Saeta Date: Thu, 4 May 2017 19:43:48 -0800 Subject: [PATCH 43/43] Implement ClusterSpec Propagation in TF Master ClusterSpec propagation is a capability upgrade for TensorFlow that should make it much easier to (1) build distributed TensorFlow clusters, and (2) handle node failures. The ClusterSpec propagation capability allows TensorFlow workers to be booted independently of each other, and with no knowledge about others. The client can then construct a ClusterDef (ClusterSpec), and then send it to the TF master at session creation. The master in turn then propagates the ClusterDef along to all of the workers. Change: 155159972 --- tensorflow/compiler/jit/xla_device.cc | 2 +- .../compiler/tf2xla/xla_compilation_device.cc | 3 +- .../contrib/cmake/tf_core_framework.cmake | 1 + .../makefile/proto_text_pb_cc_files.txt | 1 + .../makefile/proto_text_pb_h_files.txt | 1 + .../contrib/makefile/tf_pb_text_files.txt | 1 + .../contrib/makefile/tf_proto_files.txt | 1 + tensorflow/core/BUILD | 1 + tensorflow/core/common_runtime/device.cc | 3 +- tensorflow/core/common_runtime/device.h | 3 +- tensorflow/core/common_runtime/device_mgr.cc | 19 +- tensorflow/core/common_runtime/device_mgr.h | 2 + tensorflow/core/common_runtime/device_set.h | 5 +- .../core/common_runtime/device_set_test.cc | 3 +- .../core/common_runtime/gpu/gpu_device.cc | 7 +- .../core/common_runtime/local_device.cc | 6 +- tensorflow/core/common_runtime/local_device.h | 4 +- .../core/common_runtime/renamed_device.cc | 54 ++++ .../core/common_runtime/renamed_device.h | 119 ++++++++ .../core/common_runtime/simple_placer_test.cc | 2 +- .../core/common_runtime/threadpool_device.cc | 6 +- tensorflow/core/distributed_runtime/BUILD | 4 +- .../base_rendezvous_mgr.cc | 102 +++++-- .../distributed_runtime/base_rendezvous_mgr.h | 43 ++- .../core/distributed_runtime/graph_mgr.cc | 37 +-- .../core/distributed_runtime/graph_mgr.h | 11 +- tensorflow/core/distributed_runtime/master.cc | 121 +++++++- .../core/distributed_runtime/master_env.h | 34 ++- .../distributed_runtime/master_session.cc | 151 ++++++++-- .../core/distributed_runtime/master_session.h | 21 +- .../distributed_runtime/message_wrappers.cc | 21 ++ .../distributed_runtime/message_wrappers.h | 11 + .../core/distributed_runtime/remote_device.cc | 49 +++- .../rendezvous_mgr_interface.h | 22 +- .../rpc/grpc_server_lib.cc | 83 +++--- .../distributed_runtime/rpc/grpc_server_lib.h | 15 +- .../distributed_runtime/rpc/grpc_session.cc | 7 +- .../rpc/grpc_worker_service.cc | 14 +- .../rpc/rpc_rendezvous_mgr.cc | 100 ++----- .../rpc/rpc_rendezvous_mgr.h | 13 +- .../rpc/rpc_rendezvous_mgr_test.cc | 19 +- .../core/distributed_runtime/session_mgr.cc | 108 ++----- .../core/distributed_runtime/session_mgr.h | 44 ++- .../distributed_runtime/session_mgr_test.cc | 81 +----- tensorflow/core/distributed_runtime/worker.cc | 32 +-- .../core/distributed_runtime/worker_env.h | 11 + .../distributed_runtime/worker_interface.h | 5 + .../distributed_runtime/worker_session.cc | 84 +++++- .../core/distributed_runtime/worker_session.h | 12 +- tensorflow/core/framework/device_base.h | 8 +- tensorflow/core/protobuf/cluster.proto | 82 ++++++ tensorflow/core/protobuf/config.proto | 6 + tensorflow/core/protobuf/master.proto | 3 + .../core/protobuf/tensorflow_server.proto | 64 +---- tensorflow/core/protobuf/worker.proto | 12 + tensorflow/python/__init__.py | 2 + tensorflow/python/client/session_test.py | 267 +++++++++++++++++- tensorflow/python/training/server_lib.py | 7 +- tensorflow/python/training/training.py | 30 +- .../api/golden/tensorflow.-config-proto.pbtxt | 4 + .../tensorflow.train.-cluster-def.pbtxt | 2 +- ...nsorflow.train.-job-def.-tasks-entry.pbtxt | 2 +- .../golden/tensorflow.train.-job-def.pbtxt | 2 +- 63 files changed, 1396 insertions(+), 594 deletions(-) create mode 100644 tensorflow/core/common_runtime/renamed_device.cc create mode 100644 tensorflow/core/common_runtime/renamed_device.h create mode 100644 tensorflow/core/protobuf/cluster.proto diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 93f487c36ca..5e336c5287b 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -125,7 +125,7 @@ XlaDevice::XlaDevice(const SessionOptions& options, const DeviceType& jit_device_name, perftools::gputools::Platform* platform, Allocator* xla_allocator) - : LocalDevice(options, attrs, xla_allocator), + : LocalDevice(options, attrs), device_ordinal_(device_ordinal), jit_device_name_(jit_device_name), xla_allocator_(xla_allocator), diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index d86e741b69e..362a1018955 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -76,8 +76,7 @@ XlaCompilationDevice::XlaCompilationDevice(const SessionOptions& options, options, Device::BuildDeviceAttributes( "", type, Bytes(256 << 20), DeviceLocality(), - strings::StrCat("device: XLA compilation device ", type.type())), - cpu_allocator()), + strings::StrCat("device: XLA compilation device ", type.type()))), allocator_(new XlaCompilationAllocator()) {} XlaCompilationDevice::~XlaCompilationDevice() {} diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index 6fd1ae08149..560e45fc135 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -118,6 +118,7 @@ set(tf_proto_text_srcs "tensorflow/core/framework/types.proto" "tensorflow/core/framework/versions.proto" "tensorflow/core/lib/core/error_codes.proto" + "tensorflow/core/protobuf/cluster.proto" "tensorflow/core/protobuf/config.proto" "tensorflow/core/protobuf/debug.proto" "tensorflow/core/protobuf/rewriter_config.proto" diff --git a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt index c0969e6dee2..2f1fcb149e1 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt @@ -7,6 +7,7 @@ tensorflow/core/protobuf/saver.pb.cc tensorflow/core/protobuf/queue_runner.pb.cc tensorflow/core/protobuf/named_tensor.pb.cc tensorflow/core/protobuf/meta_graph.pb.cc +tensorflow/core/protobuf/cluster.pb.cc tensorflow/core/protobuf/config.pb.cc tensorflow/core/protobuf/rewriter_config.pb.cc tensorflow/core/protobuf/debug.pb.cc diff --git a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt index 132b4775962..6087a45168d 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt @@ -7,6 +7,7 @@ tensorflow/core/protobuf/saver.pb.h tensorflow/core/protobuf/queue_runner.pb.h tensorflow/core/protobuf/named_tensor.pb.h tensorflow/core/protobuf/meta_graph.pb.h +tensorflow/core/protobuf/cluster.pb.h tensorflow/core/protobuf/config.pb.h tensorflow/core/protobuf/debug.pb.h tensorflow/core/protobuf/rewriter_config.pb.h diff --git a/tensorflow/contrib/makefile/tf_pb_text_files.txt b/tensorflow/contrib/makefile/tf_pb_text_files.txt index f1da05e4c6e..c39257ffa91 100644 --- a/tensorflow/contrib/makefile/tf_pb_text_files.txt +++ b/tensorflow/contrib/makefile/tf_pb_text_files.txt @@ -1,6 +1,7 @@ tensorflow/core/util/saved_tensor_slice.pb_text.cc tensorflow/core/util/memmapped_file_system.pb_text.cc tensorflow/core/protobuf/saver.pb_text.cc +tensorflow/core/protobuf/cluster.pb_text.cc tensorflow/core/protobuf/config.pb_text.cc tensorflow/core/protobuf/debug.pb_text.cc tensorflow/core/protobuf/rewriter_config.pb_text.cc diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt index 2a78ea61016..5eadf5d55b6 100644 --- a/tensorflow/contrib/makefile/tf_proto_files.txt +++ b/tensorflow/contrib/makefile/tf_proto_files.txt @@ -7,6 +7,7 @@ tensorflow/core/protobuf/saver.proto tensorflow/core/protobuf/queue_runner.proto tensorflow/core/protobuf/named_tensor.proto tensorflow/core/protobuf/meta_graph.proto +tensorflow/core/protobuf/cluster.proto tensorflow/core/protobuf/config.proto tensorflow/core/protobuf/debug.proto tensorflow/core/protobuf/rewriter_config.proto diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 435618ace7a..9d0c6a6c3eb 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -154,6 +154,7 @@ CORE_PROTO_SRCS = [ "framework/versions.proto", "lib/core/error_codes.proto", "protobuf/config.proto", + "protobuf/cluster.proto", "protobuf/debug.proto", "protobuf/queue_runner.proto", "protobuf/rewriter_config.proto", diff --git a/tensorflow/core/common_runtime/device.cc b/tensorflow/core/common_runtime/device.cc index 78649afeb93..aa8a2d989bf 100644 --- a/tensorflow/core/common_runtime/device.cc +++ b/tensorflow/core/common_runtime/device.cc @@ -23,8 +23,7 @@ limitations under the License. namespace tensorflow { -Device::Device(Env* env, const DeviceAttributes& device_attributes, - Allocator* device_allocator) +Device::Device(Env* env, const DeviceAttributes& device_attributes) : DeviceBase(env), device_attributes_(device_attributes) { CHECK(DeviceNameUtils::ParseFullName(name(), &parsed_name_)) << "Invalid device name: " << name(); diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h index 07c6bdd6831..c0e58f143e3 100644 --- a/tensorflow/core/common_runtime/device.h +++ b/tensorflow/core/common_runtime/device.h @@ -53,8 +53,7 @@ namespace tensorflow { class Device : public DeviceBase { public: - Device(Env* env, const DeviceAttributes& device_attributes, - Allocator* device_allocator); + Device(Env* env, const DeviceAttributes& device_attributes); ~Device() override; // Full name of this device (see top comment). diff --git a/tensorflow/core/common_runtime/device_mgr.cc b/tensorflow/core/common_runtime/device_mgr.cc index 7807656cb25..31f12d48337 100644 --- a/tensorflow/core/common_runtime/device_mgr.cc +++ b/tensorflow/core/common_runtime/device_mgr.cc @@ -29,10 +29,18 @@ DeviceMgr::DeviceMgr(const std::vector& devices) for (Device* d : devices) { devices_.push_back(d); - // Register under both the full name and the local name. + // Register under the (1) full name, (2) canonical name, and (3) local name. string full_name = d->name(); device_map_[CopyToBackingStore(full_name)] = d; + DeviceNameUtils::ParsedName parsed_name = d->parsed_name(); + if (parsed_name.has_job && parsed_name.has_replica && + parsed_name.has_task && parsed_name.has_type && parsed_name.has_id) { + string canonical_name = DeviceNameUtils::FullName( + parsed_name.job, parsed_name.replica, parsed_name.task, + parsed_name.type, parsed_name.id); + device_map_[CopyToBackingStore(canonical_name)] = d; + } string lname = DeviceNameUtils::LocalName(d->name()); device_map_[CopyToBackingStore(lname)] = d; device_type_counts_[d->device_type()]++; @@ -40,7 +48,8 @@ DeviceMgr::DeviceMgr(const std::vector& devices) } DeviceMgr::~DeviceMgr() { - for (auto p : devices_) delete p; + // TODO(b/37437134): Remove destructor after converting to std::unique_ptr. + for (Device* p : devices_) delete p; } StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) { @@ -85,6 +94,12 @@ Status DeviceMgr::LookupDevice(StringPiece name, Device** device) const { Status s; auto iter = device_map_.find(name); if (iter == device_map_.end()) { + std::vector device_names; + for (auto&& itr : device_map_) { + device_names.push_back(itr.first); + } + LOG(WARNING) << "Unknown device: " << name + << " all devices: " << str_util::Join(device_names, ", "); return errors::InvalidArgument(name, " unknown device."); } *device = iter->second; diff --git a/tensorflow/core/common_runtime/device_mgr.h b/tensorflow/core/common_runtime/device_mgr.h index bb1ed726408..d16681ac59d 100644 --- a/tensorflow/core/common_runtime/device_mgr.h +++ b/tensorflow/core/common_runtime/device_mgr.h @@ -36,6 +36,7 @@ class DeviceMgr { public: // Takes ownership of each device in 'devices'. // TODO(zhifengc): Other initialization information. + // TODO(b/37437134): Use std::unique_ptr's to track ownership. explicit DeviceMgr(const std::vector& devices); ~DeviceMgr(); @@ -61,6 +62,7 @@ class DeviceMgr { int NumDeviceType(const string& type) const; private: + // TODO(b/37437134): Use std::unique_ptr's to track ownership. typedef gtl::InlinedVector DeviceVec; DeviceVec devices_; diff --git a/tensorflow/core/common_runtime/device_set.h b/tensorflow/core/common_runtime/device_set.h index b0540dfa95b..4cd56e583c0 100644 --- a/tensorflow/core/common_runtime/device_set.h +++ b/tensorflow/core/common_runtime/device_set.h @@ -39,7 +39,10 @@ class DeviceSet { // Set the device designated as the "client". This device // must also be registered via AddDevice(). - void set_client_device(Device* device) { client_device_ = device; } + void set_client_device(Device* device) { + DCHECK(client_device_ == nullptr); + client_device_ = device; + } // Returns a pointer to the device designated as the "client". Device* client_device() const { return client_device_; } diff --git a/tensorflow/core/common_runtime/device_set_test.cc b/tensorflow/core/common_runtime/device_set_test.cc index ff20ee94a7d..0507076c8c3 100644 --- a/tensorflow/core/common_runtime/device_set_test.cc +++ b/tensorflow/core/common_runtime/device_set_test.cc @@ -27,8 +27,7 @@ namespace { static Device* Dev(const char* type, const char* name) { class FakeDevice : public Device { public: - explicit FakeDevice(const DeviceAttributes& attr) - : Device(nullptr, attr, nullptr) {} + explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {} Status Sync() override { return Status::OK(); } Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } }; diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index 0e2343cfe3f..02f70d835d5 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -179,10 +179,9 @@ BaseGPUDevice::BaseGPUDevice(const SessionOptions& options, const string& name, int gpu_id, const string& physical_device_desc, Allocator* gpu_allocator, Allocator* cpu_allocator, bool sync_every_op, int32 max_streams) - : LocalDevice(options, - Device::BuildDeviceAttributes(name, DEVICE_GPU, memory_limit, - locality, physical_device_desc), - gpu_allocator), + : LocalDevice(options, Device::BuildDeviceAttributes(name, DEVICE_GPU, + memory_limit, locality, + physical_device_desc)), gpu_allocator_(gpu_allocator), cpu_allocator_(cpu_allocator), gpu_id_(gpu_id), diff --git a/tensorflow/core/common_runtime/local_device.cc b/tensorflow/core/common_runtime/local_device.cc index 0a6342ed736..3f7c9f68dba 100644 --- a/tensorflow/core/common_runtime/local_device.cc +++ b/tensorflow/core/common_runtime/local_device.cc @@ -60,10 +60,8 @@ struct LocalDevice::EigenThreadPoolInfo { }; LocalDevice::LocalDevice(const SessionOptions& options, - const DeviceAttributes& attributes, - Allocator* device_allocator) - : Device(options.env, attributes, device_allocator), - owned_tp_info_(nullptr) { + const DeviceAttributes& attributes) + : Device(options.env, attributes), owned_tp_info_(nullptr) { // If we're running on the CPU, log warnings if we're not compiled using the // best flags for performance. port::WarnAboutUnusedCPUFeatures(); diff --git a/tensorflow/core/common_runtime/local_device.h b/tensorflow/core/common_runtime/local_device.h index d1c27c62481..84a4f66db4a 100644 --- a/tensorflow/core/common_runtime/local_device.h +++ b/tensorflow/core/common_runtime/local_device.h @@ -33,8 +33,8 @@ struct SessionOptions; // GPUDevice into more 'process-wide' abstractions. class LocalDevice : public Device { public: - LocalDevice(const SessionOptions& options, const DeviceAttributes& attributes, - Allocator* device_allocator); + LocalDevice(const SessionOptions& options, + const DeviceAttributes& attributes); ~LocalDevice() override; private: diff --git a/tensorflow/core/common_runtime/renamed_device.cc b/tensorflow/core/common_runtime/renamed_device.cc new file mode 100644 index 00000000000..fa9713735ed --- /dev/null +++ b/tensorflow/core/common_runtime/renamed_device.cc @@ -0,0 +1,54 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/renamed_device.h" + +namespace tensorflow { + +// TODO(saeta): Convert to returning a std::unique_ptr? +/* static */ +Device* RenamedDevice::NewRenamedDevice(const string& new_base, + Device* underlying, + bool owns_underlying) { + DeviceNameUtils::ParsedName parsed_name; + CHECK(DeviceNameUtils::ParseFullName(new_base, &parsed_name)); + DeviceNameUtils::ParsedName underlying_parsed_name = + underlying->parsed_name(); + CHECK(underlying_parsed_name.has_type); + CHECK(underlying_parsed_name.has_id); + parsed_name.type = underlying_parsed_name.type; + parsed_name.id = underlying_parsed_name.id; + string name = DeviceNameUtils::FullName(parsed_name.job, parsed_name.replica, + parsed_name.task, parsed_name.type, + parsed_name.id); + DeviceAttributes attributes(underlying->attributes()); + attributes.set_name(name); + return new RenamedDevice(underlying, attributes, owns_underlying); +} + +RenamedDevice::RenamedDevice(Device* underlying, + const DeviceAttributes& attributes, + bool owns_underlying) + : Device(underlying->env(), attributes), + underlying_(underlying), + owns_underlying_(owns_underlying) {} + +RenamedDevice::~RenamedDevice() { + if (owns_underlying_) { + delete underlying_; + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/renamed_device.h b/tensorflow/core/common_runtime/renamed_device.h new file mode 100644 index 00000000000..0158e18cedc --- /dev/null +++ b/tensorflow/core/common_runtime/renamed_device.h @@ -0,0 +1,119 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_ + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +// Wraps a device with a new name, delegating work to the wrapped device. +// +// This class is used to wrap local devices when using clusterspec propagation +// where the name of a particular device may change in the context of a given +// session. +class RenamedDevice : public Device { + public: + static Device* NewRenamedDevice(const string& new_base, Device* underlying, + bool owns_underlying); + ~RenamedDevice() override; + + // Below are virtual methods defined on DeviceBase + bool RequiresRecordingAccessedTensors() const override { + return underlying_->RequiresRecordingAccessedTensors(); + } + + const CpuWorkerThreads* tensorflow_cpu_worker_threads() const override { + return underlying_->tensorflow_cpu_worker_threads(); + } + + const GpuDeviceInfo* tensorflow_gpu_device_info() const override { + return underlying_->tensorflow_gpu_device_info(); + } + + Allocator* GetAllocator(AllocatorAttributes attr) override { + return underlying_->GetAllocator(attr); + } + + Allocator* GetStepAllocator(AllocatorAttributes attr, + ResourceMgr* step_resource_manager) override { + return underlying_->GetStepAllocator(attr, step_resource_manager); + } + + const Eigen::ThreadPoolDevice* eigen_cpu_device() override { + return underlying_->eigen_cpu_device(); + } + +#ifdef TENSORFLOW_USE_SYCL + const Eigen::SyclDevice* eigen_sycl_device() const override { + return underlying_->eigen_sycl_device(); + } +#endif + + PerOpGpuDevice* MakeGpuDevice() override { + return underlying_->MakeGpuDevice(); + } + + void ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device, + DeviceContext* dc, Allocator* allocator) override { + underlying_->ReinitializeGpuDevice(context, device, dc, allocator); + } + + Status MakeTensorFromProto(const TensorProto& tensor_proto, + const AllocatorAttributes alloc_attrs, + Tensor* tensor) override { + return underlying_->MakeTensorFromProto(tensor_proto, alloc_attrs, tensor); + } + + // Below are virtual methods defined on Device + + void Compute(OpKernel* op_kernel, OpKernelContext* context) override { + underlying_->Compute(op_kernel, context); + } + + void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, + AsyncOpKernel::DoneCallback done) override { + underlying_->ComputeAsync(op_kernel, context, std::move(done)); + } + + void ConsumeListOfAccessedTensors( + DeviceContext* context, const TensorReferenceVector& tensors) override { + underlying_->ConsumeListOfAccessedTensors(context, tensors); + } + + Status Sync() override { return underlying_->Sync(); } + + Status MaybeRewriteGraph(const FunctionDefLibrary& library, + std::unique_ptr* graph) override { + return underlying_->MaybeRewriteGraph(library, graph); + } + + Status FillContextMap(const Graph* graph, + DeviceContextMap* device_context_map) override { + return underlying_->FillContextMap(graph, device_context_map); + } + + private: + RenamedDevice(Device* underlying, const DeviceAttributes& attributes, + bool owns_underlying); + Device* const underlying_; + const bool owns_underlying_; +}; + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_ diff --git a/tensorflow/core/common_runtime/simple_placer_test.cc b/tensorflow/core/common_runtime/simple_placer_test.cc index bd84417b105..24f27af5f1a 100644 --- a/tensorflow/core/common_runtime/simple_placer_test.cc +++ b/tensorflow/core/common_runtime/simple_placer_test.cc @@ -66,7 +66,7 @@ class DummyOp : public OpKernel { class FakeDevice : public Device { private: explicit FakeDevice(const DeviceAttributes& device_attributes) - : Device(nullptr, device_attributes, nullptr) {} + : Device(nullptr, device_attributes) {} public: Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); } diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc index 60348e885f5..f5f8aab6946 100644 --- a/tensorflow/core/common_runtime/threadpool_device.cc +++ b/tensorflow/core/common_runtime/threadpool_device.cc @@ -38,10 +38,8 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options, const string& name, Bytes memory_limit, const DeviceLocality& locality, Allocator* allocator) - : LocalDevice(options, - Device::BuildDeviceAttributes(name, DEVICE_CPU, memory_limit, - locality), - allocator), + : LocalDevice(options, Device::BuildDeviceAttributes( + name, DEVICE_CPU, memory_limit, locality)), allocator_(allocator) {} ThreadPoolDevice::~ThreadPoolDevice() {} diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index 0f5eb0cb320..d2a828f39f2 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -77,7 +77,6 @@ cc_library( ], deps = [ ":graph_mgr", - ":rendezvous_mgr_interface", ":worker_cache", "//tensorflow/core:master_proto_cc", "//tensorflow/core:protos_all_cc", @@ -92,9 +91,9 @@ cc_library( deps = [ ":graph_mgr", ":worker_session", + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", ], ) @@ -237,6 +236,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:master_proto_cc", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:worker_proto_cc", ], ) diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc index 5863727f19b..e68aea46ecd 100644 --- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc @@ -35,9 +35,8 @@ limitations under the License. namespace tensorflow { -BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env, - const string& worker_name) - : worker_env_(worker_env), worker_name_(worker_name) {} +BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env) + : worker_env_(worker_env) {} BaseRendezvousMgr::~BaseRendezvousMgr() { for (auto& p : table_) { @@ -47,7 +46,7 @@ BaseRendezvousMgr::~BaseRendezvousMgr() { } } -Rendezvous* BaseRendezvousMgr::Find(int64 step_id) { +RemoteRendezvous* BaseRendezvousMgr::Find(int64 step_id) { return FindOrCreate(step_id); } @@ -55,7 +54,7 @@ BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64 step_id) { mutex_lock l(mu_); Table::iterator iter = table_.find(step_id); if (iter == table_.end()) { - auto rr = Create(step_id, worker_env_, worker_name_); + auto rr = Create(step_id, worker_env_); iter = table_.insert({step_id, rr}).first; } iter->second->Ref(); @@ -128,14 +127,12 @@ void BaseRendezvousMgr::CleanupAll() { } } -BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, - const string& worker_name, - int64 step_id, +BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id, bool tolerate_dup_recv) : env_(env), - worker_name_(worker_name), step_id_(step_id), - local_(NewLocalRendezvous(tolerate_dup_recv)) {} + local_(NewLocalRendezvous(tolerate_dup_recv)), + session_(nullptr) {} BaseRemoteRendezvous::~BaseRemoteRendezvous() { CHECK(active_.empty()); @@ -150,6 +147,41 @@ static bool IsLocalDevice(const string& worker_name, return device_name.starts_with(worker_name); } +Status BaseRemoteRendezvous::Initialize(WorkerSession* session) { + CHECK_NE(session, nullptr) << "session must not be null!"; + std::vector deferred_calls; + { + mutex_lock l(mu_); + if (session_ != nullptr) { + if (session_->worker_name == session->worker_name) { + LOG(INFO) << "Skipping rendezvous re-initialization."; + return Status::OK(); + } + Status s = errors::Internal( + "Double init! Worker names would have changed from: ", + session_->worker_name, " -> ", session->worker_name); + LOG(WARNING) << s; + return s; + } + session_ = session; + std::swap(deferred_calls, deferred_calls_); + } + for (DeferredCall& call : deferred_calls) { + RecvLocalAsyncInternal(call.parsed, std::move(call.done)); + } + return Status::OK(); +} + +WorkerSession* BaseRemoteRendezvous::session() { + mutex_lock l(mu_); + return session_; +} + +bool BaseRemoteRendezvous::is_initialized() { + mutex_lock l(mu_); + return is_initialized_locked(); +} + Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& args, const Tensor& val, const bool is_dead) { @@ -157,10 +189,12 @@ Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed, { mutex_lock l(mu_); if (!status_.ok()) return status_; - } - if (!IsLocalDevice(worker_name_, parsed.src_device)) { - return errors::InvalidArgument("Invalid rendezvous key (src): ", - parsed.FullKey(), " @ ", worker_name_); + DCHECK(is_initialized_locked()); + if (!IsLocalDevice(session_->worker_name, parsed.src_device)) { + return errors::InvalidArgument( + "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ", + session_->worker_name); + } } // Buffers "val" and "device_context" in local_. return local_->Send(parsed, args, val, is_dead); @@ -168,17 +202,24 @@ Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed, Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed, bool is_src) { + // Cache session pointer to avoid repeatedly taking & releasing the lock + // (e.g. calling session()) + WorkerSession* sess = nullptr; { mutex_lock l(mu_); if (!status_.ok()) return status_; + if (!is_initialized_locked()) { + return errors::Internal("ValidateDevices called before initialization."); + } + sess = session_; } - if (is_src && !IsLocalDevice(worker_name_, parsed.src_device)) { + if (is_src && !IsLocalDevice(sess->worker_name, parsed.src_device)) { return errors::InvalidArgument("Invalid rendezvous key (src): ", - parsed.FullKey(), " @ ", worker_name_); + parsed.FullKey(), " @ ", sess->worker_name); } - if (!is_src && !IsLocalDevice(worker_name_, parsed.dst_device)) { + if (!is_src && !IsLocalDevice(sess->worker_name, parsed.dst_device)) { return errors::InvalidArgument("Invalid rendezvous key (dst): ", - parsed.FullKey(), " @ ", worker_name_); + parsed.FullKey(), " @ ", sess->worker_name); } return Status::OK(); } @@ -244,6 +285,7 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed, const Rendezvous::Args& recv_args, DoneCallback done) { VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey(); + CHECK(is_initialized()) << "RecvAsync called when uninitialized."; Status s = ValidateDevices(parsed, false /*!is_src*/); if (!s.ok()) { done(s, Args(), recv_args, Tensor(), false); @@ -280,6 +322,26 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed, void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed, DoneCallback done) { + { + mutex_lock l(mu_); + if (!is_initialized_locked()) { + // RecvLocalAsync can be called (due to an incoming RecvTensor RPC from a + // remote worker) before the RunStep (or PartialRunStep) RPC from the + // master arrives. RecvLocalAsync thus buffers the arguments until after + // the RemoteRendezvous is Initialize()'d, when it completes the + // rendezvous logic. At some point after Initialize() is called, a Tensor + // is produced locally that will then be sent in response to the incoming + // RPC. + DeferredCall call(parsed, std::move(done)); + deferred_calls_.push_back(call); + return; + } + } + RecvLocalAsyncInternal(parsed, std::move(done)); +} + +void BaseRemoteRendezvous::RecvLocalAsyncInternal(const ParsedKey& parsed, + DoneCallback done) { Status s = ValidateDevices(parsed, true /* is_src */); if (!s.ok()) { done(s, Args(), Args(), Tensor(), false); @@ -318,4 +380,8 @@ void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call) { active_.erase(call); } +BaseRemoteRendezvous::DeferredCall::DeferredCall(const ParsedKey& parsed, + DoneCallback done) + : parsed(parsed), done(std::move(done)) {} + } // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h index 447a75913d6..b252f45fe96 100644 --- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h +++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h @@ -59,15 +59,17 @@ class BaseRecvTensorCall; // RendezvousMgr must have keys generated by Rendezvous::CreateKey(). class BaseRendezvousMgr : public RendezvousMgrInterface { public: - explicit BaseRendezvousMgr(const WorkerEnv* worker_env, - const string& worker_name); + explicit BaseRendezvousMgr(const WorkerEnv* worker_env); ~BaseRendezvousMgr() override; // Returns Rendezvous supporting send and recv among workers in the // "step_id". The caller takes ownership of one reference on the // returned Rendezvous instance. - Rendezvous* Find(int64 step_id) override; + // + // Note: the caller must guarantee to eventually call Initialize on the + // returned RemoteRendezvous + RemoteRendezvous* Find(int64 step_id) override; // Finds the local rendezvous instance for the "step_id". Runs // "done" when the tensor for "key" is produced or an error occurs. @@ -91,8 +93,7 @@ class BaseRendezvousMgr : public RendezvousMgrInterface { protected: virtual BaseRemoteRendezvous* Create(int64 step_id, - const WorkerEnv* worker_env, - const string& worker_name) = 0; + const WorkerEnv* worker_env) = 0; private: // Maps step_id to rendezvous. @@ -100,7 +101,6 @@ class BaseRendezvousMgr : public RendezvousMgrInterface { // Not owned. const WorkerEnv* const worker_env_; - const string worker_name_; mutex mu_; Table table_ GUARDED_BY(mu_); @@ -116,10 +116,13 @@ class BaseRendezvousMgr : public RendezvousMgrInterface { // Buffering of Tensor values is delegated to a "local" Rendezvous // obtained from NewLocalRendezvous(). This class just adds // functionality to coordinate with remote workers. -class BaseRemoteRendezvous : public Rendezvous { +class BaseRemoteRendezvous : public RemoteRendezvous { public: - BaseRemoteRendezvous(const WorkerEnv* env, const string& worker_name, - int64 step_id, bool tolerate_dup_recv); + BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id, + bool tolerate_dup_recv); + + // Upgrades the BaseRemoteRendezvous to full initialization. + Status Initialize(WorkerSession* session) override; // Forwards to local_, where the Tensor "val" will be buffered and // any waiting callback stored. @@ -163,10 +166,13 @@ class BaseRemoteRendezvous : public Rendezvous { // Removes "call" from active_ if "call" is in active_. void DeregisterCall(BaseRecvTensorCall* call); + WorkerSession* session(); + + bool is_initialized(); + ~BaseRemoteRendezvous() override; const WorkerEnv* const env_; // Not owned. - const string worker_name_; const int64 step_id_; private: @@ -176,10 +182,24 @@ class BaseRemoteRendezvous : public Rendezvous { // Status given by StartAbort() if any. Status status_ GUARDED_BY(mu_); + WorkerSession* session_ GUARDED_BY(mu_); // Not owned. + + // Data structures to handle calls when partially initialized. + struct DeferredCall { + const ParsedKey parsed; + DoneCallback done; + + DeferredCall(const ParsedKey& parsed, DoneCallback done); + }; + std::vector deferred_calls_ GUARDED_BY(mu_); // Active outstanding RecvTensor calls. gtl::FlatSet active_ GUARDED_BY(mu_); + bool is_initialized_locked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return session_ != nullptr; + } + // If "is_src" is true, checks that the rendezvous key "parsed"'s // source is in this process. If "is_src" is false, checks that the // rendezvous key "parsed"'s destination is in this process. @@ -194,6 +214,9 @@ class BaseRemoteRendezvous : public Rendezvous { const Rendezvous::Args& out_args, const Tensor& in, Tensor* out, StatusCallback done); + // Must be called only if fully initialized. + void RecvLocalAsyncInternal(const ParsedKey& parsed, DoneCallback done); + TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous); }; diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index ce7ce372e85..5bde771e8de 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -46,10 +46,8 @@ limitations under the License. namespace tensorflow { -GraphMgr::GraphMgr(const WorkerEnv* worker_env, - RendezvousMgrInterface* rendezvous_mgr) - : worker_env_(worker_env), rendezvous_mgr_(rendezvous_mgr), table_(5) { - CHECK(rendezvous_mgr) << "Rendezvous mgr was null"; +GraphMgr::GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr) + : worker_env_(worker_env), device_mgr_(device_mgr), table_(5) { // The default value of sync_on_finish will be flipped soon and this // environment variable will be removed as well. Status status = @@ -148,7 +146,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, }; popts.get_incarnation = [this](const string& name) -> int64 { Device* device = nullptr; - Status s = worker_env_->device_mgr->LookupDevice(name, &device); + Status s = device_mgr_->LookupDevice(name, &device); if (s.ok()) { return device->attributes().incarnation(); } else { @@ -193,8 +191,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, ExecutionUnit* unit = &(item->units.back()); // Find the device. - Status s = - worker_env_->device_mgr->LookupDevice(device_name, &unit->device); + Status s = device_mgr_->LookupDevice(device_name, &unit->device); if (!s.ok()) { // Remove the empty unit from the item as the item destructor wants all // units to have valid devices. @@ -214,7 +211,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, // Function library runtime. unit->lib = NewFunctionLibraryRuntime( - worker_env_->device_mgr, worker_env_->env, unit->device, + device_mgr_, worker_env_->env, unit->device, subgraph->versions().producer(), item->lib_def, graph_options.optimizer_options()); @@ -419,14 +416,14 @@ void GraphMgr::RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous, } Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) { - Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id); + Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); Status s = SendInputsToRendezvous(rendezvous, in); rendezvous->Unref(); return s; } Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) { - Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id); + Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); Status s = RecvOutputsFromRendezvous(rendezvous, out); rendezvous->Unref(); return s; @@ -434,7 +431,7 @@ Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) { void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out, StatusCallback done) { - Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id); + Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); RecvOutputsFromRendezvousAsync(rendezvous, out, [done, rendezvous](const Status s) { rendezvous->Unref(); @@ -443,7 +440,8 @@ void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out, } void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id, - const ExecutorOpts& opts, + WorkerSession* session, + const ExecutorOpts& /*opts*/, StepStatsCollector* collector, CostGraphDef* cost_graph, CancellationManager* cancellation_manager, @@ -464,10 +462,14 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id, return; } - Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id); + RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); + Status s = rendezvous->Initialize(session); // Sends values specified by the caller. - Status s = SendInputsToRendezvous(rendezvous, in); + if (s.ok()) { + s = SendInputsToRendezvous(rendezvous, in); + } + if (!s.ok()) { done(s); item->Unref(); @@ -492,10 +494,9 @@ void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id, StatusCallback done) { const int num_units = item->units.size(); CHECK_GE(num_units, 1); - ScopedStepContainer* step_container = - new ScopedStepContainer(step_id, [this](const string& name) { - worker_env_->device_mgr->ClearContainers({name}); - }); + ScopedStepContainer* step_container = new ScopedStepContainer( + step_id, + [this](const string& name) { device_mgr_->ClearContainers({name}); }); // NOTE: Transfer one ref of rendezvous and item. ExecutorBarrier* barrier = new ExecutorBarrier(num_units, rendezvous, diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h index 349af6c54e5..50391f47e4d 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.h +++ b/tensorflow/core/distributed_runtime/graph_mgr.h @@ -37,6 +37,8 @@ namespace tensorflow { class ExecutorOpts; class StepStatsCollector; class RendezvousMgrInterface; +class DeviceMgr; +struct WorkerSession; // GraphMgr keeps track of a set of graphs that are registered with a // TensorFlow worker. Each registered graph is identified by a handle @@ -62,8 +64,7 @@ class RendezvousMgrInterface; // EXPECT_EQ(out["c"], Tensor({4, 6})); class GraphMgr { public: - explicit GraphMgr(const WorkerEnv* worker_env, - RendezvousMgrInterface* rendezvous_mgr); + explicit GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr); ~GraphMgr(); // Registers a graph. Fills in "handle" @@ -78,8 +79,8 @@ class GraphMgr { typedef std::map NamedTensors; typedef std::function StatusCallback; void ExecuteAsync(const string& handle, const int64 step_id, - const ExecutorOpts& opts, StepStatsCollector* collector, - CostGraphDef* cost_graph, + WorkerSession* session, const ExecutorOpts& opts, + StepStatsCollector* collector, CostGraphDef* cost_graph, CancellationManager* cancellation_manager, const NamedTensors& in, StatusCallback done); @@ -131,7 +132,7 @@ class GraphMgr { }; const WorkerEnv* worker_env_; // Not owned. - RendezvousMgrInterface* rendezvous_mgr_; // Not owned. + DeviceMgr* device_mgr_; CostModelManager cost_model_manager_; diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc index b4adee3bf6c..e860c99d953 100644 --- a/tensorflow/core/distributed_runtime/master.cc +++ b/tensorflow/core/distributed_runtime/master.cc @@ -34,6 +34,7 @@ limitations under the License. #include #include +#include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/distributed_runtime/remote_device.h" #include "tensorflow/core/distributed_runtime/worker_cache.h" @@ -48,12 +49,17 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/cluster.pb.h" #include "tensorflow/core/protobuf/master.pb.h" #include "tensorflow/core/protobuf/worker.pb.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { +namespace { +const char* const kGrpcProtocol = "grpc://"; +} // namespace + Master::Master(MasterEnv* env, double session_gc_seconds) : env_(env), last_1000_steps_(1000), @@ -290,25 +296,122 @@ void Master::CreateSession(const CreateSessionRequest* req, CreateSessionResponse* resp, MyClosure done) { SchedClosure([this, req, resp, done]() { Status status; + WorkerCacheFactoryOptions worker_cache_factory_options; + string grpc_protocol("grpc"); + worker_cache_factory_options.protocol = &grpc_protocol; auto call_done = gtl::MakeCleanup([&status, &done] { done(status); }); status = ValidateExternalGraphDefSyntax(req->graph_def()); if (!status.ok()) return; - // Ping all the workers and build the list of devices that the - // session will use. + + // The following 4 variables are set differently, depending on whether this + // session uses a client-provided clusterspec or not. + WorkerCacheInterface* worker_cache = nullptr; + // Note: worker_cache_ptr will be null except if this session is using a + // client-supplied ClusterDef (ClusterSpec propagation). + std::unique_ptr worker_cache_ptr; + std::unique_ptr device_set; // TODO(saeta): Convert to std::make_unique when available. std::unique_ptr>> remote_devices( new std::vector>()); - status = DeviceFinder::GetRemoteDevices(req->config().device_filters(), - env_, env_->worker_cache, - remote_devices.get()); - if (!status.ok()) return; + + if (req->config().has_cluster_def()) { + worker_cache_factory_options.cluster_def = &req->config().cluster_def(); + + // Set the server_def's job_name and task_index fields. + string normalized_string; + string grpc_protocol(kGrpcProtocol); + if (req->target().compare(0, grpc_protocol.length(), grpc_protocol) == + 0) { + normalized_string = + req->target().substr(grpc_protocol.length(), string::npos); + } else { + normalized_string = req->target(); + } + for (auto&& job : req->config().cluster_def().job()) { + for (auto&& task : job.tasks()) { + if (task.second == normalized_string) { + if (worker_cache_factory_options.job_name != nullptr) { + status = errors::InvalidArgument( + "Found multiple matching tasks that correspond to " + "to the master. Master target: '", + req->target(), "'. ClusterDef: ", + req->config().cluster_def().ShortDebugString()); + LOG(ERROR) << status; + return; + } + if (env_->local_devices[0]->parsed_name().job == job.name() && + env_->local_devices[0]->parsed_name().task == task.first) { + // TODO(b/37868888): Remove this limitation when resolved + status = errors::InvalidArgument( + "The ClusterSpec names the job and task index to be the same " + "names that were provided when the server booted. This is " + "currently not allowed. Job: ", + job.name(), ", task index: ", task.first); + return; + } + worker_cache_factory_options.job_name = &job.name(); + worker_cache_factory_options.task_index = task.first; + } + } + } + + // Create the worker cache from the computed server_def. + status = env_->worker_cache_factory(worker_cache_factory_options, + &worker_cache); + if (!status.ok()) return; + worker_cache_ptr = std::unique_ptr(worker_cache); + // Ping all the workers and build the list of devices that the + // session will use. + status = + DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_, + worker_cache, remote_devices.get()); + if (!status.ok()) return; + device_set.reset(new DeviceSet); + for (auto&& d : *remote_devices) { + device_set->AddDevice(d.get()); + DeviceNameUtils::ParsedName name = d->parsed_name(); + if (name.job == *worker_cache_factory_options.job_name && + name.task == worker_cache_factory_options.task_index && + name.type == "CPU") { + device_set->set_client_device(d.get()); + } + } + } else { + worker_cache = env_->worker_cache; + // Ping all the workers and build the list of devices that the + // session will use. + status = + DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_, + worker_cache, remote_devices.get()); + if (!status.ok()) return; + device_set.reset(new DeviceSet); + for (auto&& d : *remote_devices) { + device_set->AddDevice(d.get()); + } + int num_local_devices = 0; + for (Device* d : env_->local_devices) { + device_set->AddDevice(d); + if (num_local_devices == 0) { + // Uses the first local device as the client device. + device_set->set_client_device(d); + } + num_local_devices++; + } + } + + CHECK(device_set->client_device()); + SessionOptions options; options.config = req->config(); - MasterSession* session = - env_->master_session_factory(options, env_, std::move(remote_devices)); + + MasterSession* session = env_->master_session_factory( + options, env_, std::move(remote_devices), std::move(worker_cache_ptr), + std::move(device_set)); + GraphDef* gdef = const_cast(req)->mutable_graph_def(); - status = session->Create(gdef); + + status = session->Create(gdef, worker_cache_factory_options); if (!status.ok()) { session->Close().IgnoreError(); session->Unref(); diff --git a/tensorflow/core/distributed_runtime/master_env.h b/tensorflow/core/distributed_runtime/master_env.h index a155bd384d8..bb548adda15 100644 --- a/tensorflow/core/distributed_runtime/master_env.h +++ b/tensorflow/core/distributed_runtime/master_env.h @@ -19,17 +19,41 @@ limitations under the License. #include #include -#include "tensorflow/core/distributed_runtime/master_session.h" +#include "tensorflow/core/protobuf/cluster.pb.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { class Device; +class DeviceSet; class Env; class MasterSession; class OpRegistryInterface; class WorkerCacheInterface; +// Options passed to the worker_cache_factory function. +struct WorkerCacheFactoryOptions { + const ClusterDef* cluster_def = nullptr; + const string* job_name = nullptr; + int task_index; + const string* protocol = nullptr; + + WorkerCacheFactoryOptions() {} + + // Construct from a ServerDef proto. + // + // Note: server_def must outlive WorkerCacheFactoryOptions! + WorkerCacheFactoryOptions(const ServerDef& server_def) { + if (server_def.has_cluster() && !server_def.job_name().empty()) { + cluster_def = &server_def.cluster(); + job_name = &server_def.job_name(); + task_index = server_def.task_index(); + protocol = &server_def.protocol(); + } + } +}; + // The master environment class, which holds a bag of pointers to // per-master state. // @@ -57,8 +81,14 @@ struct MasterEnv { // `MasterEnv*` is retained by the caller. std::function>>)> + std::unique_ptr>>, + std::unique_ptr, + std::unique_ptr device_set)> master_session_factory; + + std::function + worker_cache_factory; }; } // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 5257aea1e3a..50c5d90fc98 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -36,11 +36,13 @@ limitations under the License. #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -528,6 +530,7 @@ Status MasterSession::ReffedClientGraph::RunPartitions( c->req->set_is_partial(is_partial_); c->req->set_is_last_partial_run(is_last_partial_run); } + c->req->set_session_handle(session_handle_); c->req->set_graph_handle(part.graph_handle); c->req->set_step_id(step_id); *c->req->mutable_exec_opts() = exec_opts; @@ -871,6 +874,7 @@ void MasterSession::ReffedClientGraph::DeregisterPartitions() { // The graph handle may be empty if we failed during partition registration. if (!part.graph_handle.empty()) { Call* c = new Call; + c->req.set_session_handle(session_handle_); c->req.set_graph_handle(part.graph_handle); // NOTE(mrry): We must capture `worker_cache_` since `this` // could be deleted before the callback is called. @@ -973,31 +977,25 @@ string BuildGraphOptionsString(const BuildGraphOptions& opts) { MasterSession::MasterSession( const SessionOptions& opt, const MasterEnv* env, std::unique_ptr>> remote_devs, + std::unique_ptr worker_cache, + std::unique_ptr device_set, StatsPublisherFactory stats_publisher_factory) : session_opts_(opt), env_(env), handle_(strings::FpToString(random::New64())), remote_devs_(std::move(remote_devs)), + worker_cache_(std::move(worker_cache)), + devices_(std::move(device_set)), stats_publisher_factory_(std::move(stats_publisher_factory)), graph_version_(0), run_graphs_(5), partial_run_graphs_(5) { UpdateLastAccessTime(); + CHECK(devices_) << "device_set was null!"; VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size() << " #remote " << remote_devs_->size(); - for (auto&& d : *remote_devs_) { - devices_.AddDevice(d.get()); - } - int num_local_devices = 0; - for (Device* d : env->local_devices) { - devices_.AddDevice(d); - if (num_local_devices == 0) { - // Uses the first local device as the client device. - devices_.set_client_device(d); - } - num_local_devices++; - } + LOG(INFO) << "Start master session " << handle_ << " with config: " << std::endl << session_opts_.config.DebugString(); @@ -1012,7 +1010,8 @@ void MasterSession::UpdateLastAccessTime() { last_access_time_usec_.store(Env::Default()->NowMicros()); } -Status MasterSession::Create(GraphDef* graph_def) { +Status MasterSession::Create(GraphDef* graph_def, + const WorkerCacheFactoryOptions& options) { if (session_opts_.config.graph_options().place_pruned_graph()) { // TODO(b/29900832): Fix this or remove the option. LOG(WARNING) << "Distributed session does not support the " @@ -1020,17 +1019,93 @@ Status MasterSession::Create(GraphDef* graph_def) { session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false); } - SimpleGraphExecutionStateOptions options; - options.device_set = &devices_; - options.session_options = &session_opts_; + SimpleGraphExecutionStateOptions execution_options; + execution_options.device_set = devices_.get(); + execution_options.session_options = &session_opts_; { mutex_lock l(mu_); TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph( - graph_def, options, &execution_state_)); + graph_def, execution_options, &execution_state_)); + } + if (options.cluster_def != nullptr) { + return CreateWorkerSessions(options); } return Status::OK(); } +Status MasterSession::CreateWorkerSessions( + const WorkerCacheFactoryOptions& options) { + CHECK(worker_cache_) << "CreateWorkerSessions should be called only with " + << "dynamic cluster membership."; + std::vector worker_names; + worker_cache_->ListWorkers(&worker_names); + + struct WorkerGroup { + // The worker name. (Not owned.) + const string* name; + + // The worker referenced by name. (Not owned.) + WorkerInterface* worker = nullptr; + + // Request and responses used for a given worker. + CreateWorkerSessionRequest request; + CreateWorkerSessionResponse response; + Status status = Status::OK(); + }; + BlockingCounter done(worker_names.size()); + std::vector workers(worker_names.size()); + + // Release the workers. + auto cleanup = gtl::MakeCleanup([this, &workers] { + for (auto&& worker_group : workers) { + if (worker_group.worker != nullptr) { + worker_cache_->ReleaseWorker(*worker_group.name, worker_group.worker); + } + } + }); + + Status status = Status::OK(); + // Create all the workers & kick off the computations. + for (size_t i = 0; i < worker_names.size(); ++i) { + workers[i].name = &worker_names[i]; + workers[i].worker = worker_cache_->CreateWorker(worker_names[i]); + workers[i].request.set_session_handle(handle_); + *workers[i].request.mutable_server_def()->mutable_cluster() = + *options.cluster_def; + workers[i].request.mutable_server_def()->set_protocol(*options.protocol); + + DeviceNameUtils::ParsedName name; + if (!DeviceNameUtils::ParseFullName(worker_names[i], &name)) { + status = errors::Internal("Could not parse name ", worker_names[i]); + LOG(WARNING) << status; + return status; + } + if (!name.has_job || !name.has_task) { + status = errors::Internal("Incomplete worker name ", worker_names[i]); + LOG(WARNING) << status; + return status; + } + + workers[i].request.mutable_server_def()->set_job_name(name.job); + workers[i].request.mutable_server_def()->set_task_index(name.task); + } + + for (size_t i = 0; i < worker_names.size(); ++i) { + auto cb = [i, &workers, &done](const Status& s) { + workers[i].status = s; + done.DecrementCount(); + }; + workers[i].worker->CreateWorkerSessionAsync(&workers[i].request, + &workers[i].response, cb); + } + + done.Wait(); + for (size_t i = 0; i < workers.size(); ++i) { + status.Update(workers[i].status); + } + return status; +} + Status MasterSession::Extend(const ExtendSessionRequest* req, ExtendSessionResponse* resp) { UpdateLastAccessTime(); @@ -1060,6 +1135,13 @@ Status MasterSession::Extend(const ExtendSessionRequest* req, return Status::OK(); } +WorkerCacheInterface* MasterSession::get_worker_cache() const { + if (worker_cache_) { + return worker_cache_.get(); + } + return env_->worker_cache; +} + Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count, ReffedClientGraph** rcg, bool is_partial) { const uint64 hash = HashBuildGraphOptions(opts); @@ -1083,11 +1165,11 @@ Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count, << "\n"; std::unique_ptr client_graph; TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph)); + WorkerCacheInterface* worker_cache = get_worker_cache(); auto entry = new ReffedClientGraph( handle_, opts, std::move(client_graph), session_opts_, stats_publisher_factory_, execution_state_.get(), is_partial, - env_->worker_cache); - + worker_cache); iter = m->insert({hash, entry}).first; VLOG(1) << "Preparing to execute new graph"; } @@ -1162,6 +1244,8 @@ Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req, return errors::FailedPrecondition("Session is closed."); } ++num_running_; + // Note: all code paths must eventually call MarkRunCompletion() + // in order to appropriate decrement the num_running_ counter. } Status status; if (!req.partial_run_handle().empty()) { @@ -1169,16 +1253,18 @@ Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req, } else { status = DoRunWithLocalExecution(opts, req, resp); } - { - mutex_lock l(mu_); - --num_running_; - if (num_running_ == 0) { - num_running_is_zero_.notify_all(); - } - } return status; } +// Decrements num_running_ and broadcasts if num_running_ is zero. +void MasterSession::MarkRunCompletion() { + mutex_lock l(mu_); + --num_running_; + if (num_running_ == 0) { + num_running_is_zero_.notify_all(); + } +} + Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) { // Registers subgraphs if haven't done so. PartitionOptions popts; @@ -1188,7 +1274,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) { return strings::StrCat(prefix, "_S", next_node_id_++); }; popts.get_incarnation = [this](const string& name) -> int64 { - Device* d = devices_.FindDeviceByName(name); + Device* d = devices_->FindDeviceByName(name); if (d == nullptr) { return PartitionOptions::kIllegalIncarnation; } else { @@ -1223,6 +1309,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) { Status MasterSession::DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req, MutableRunStepResponseWrapper* resp) { + auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); }); const string& prun_handle = req.partial_run_handle(); RunState* run_state = nullptr; { @@ -1321,12 +1408,14 @@ Status MasterSession::DoPartialRun(CallOptions* opts, rcg->Ref(); rcg->ProcessStats(run_state->step_id, &run_state->pss, run_state->ph.get(), req.options(), resp->mutable_metadata()); + cleanup.release(); // MarkRunCompletion called in done closure. rcg->CleanupPartitionsAsync( run_state->step_id, [this, rcg, prun_handle](const Status& s) { if (!s.ok()) { LOG(ERROR) << "Cleanup partition error: " << s; } rcg->Unref(); + MarkRunCompletion(); }); mutex_lock l(mu_); partial_runs_.erase(prun_handle); @@ -1368,10 +1457,10 @@ Status MasterSession::CreateDebuggerState( Status MasterSession::DoRunWithLocalExecution( CallOptions* opts, const RunStepRequestWrapper& req, MutableRunStepResponseWrapper* resp) { - VLOG(2) << "DoRunWithLocalExecution " - << "req: " << req.DebugString(); + VLOG(2) << "DoRunWithLocalExecution req: " << req.DebugString(); PerStepState pss; pss.start_micros = Env::Default()->NowMicros(); + auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); }); // Prepare. BuildGraphOptions bgopts; @@ -1438,11 +1527,13 @@ Status MasterSession::DoRunWithLocalExecution( } } rcg->Ref(); - rcg->CleanupPartitionsAsync(step_id, [rcg](const Status& s) { + cleanup.release(); // MarkRunCompletion called in done closure. + rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) { if (!s.ok()) { LOG(ERROR) << "Cleanup partition error: " << s; } rcg->Unref(); + MarkRunCompletion(); }); return s; } diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h index d47125be992..3acc5bc5f0a 100644 --- a/tensorflow/core/distributed_runtime/master_session.h +++ b/tensorflow/core/distributed_runtime/master_session.h @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/call_options.h" #include "tensorflow/core/distributed_runtime/master_env.h" #include "tensorflow/core/distributed_runtime/message_wrappers.h" +#include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/master.pb.h" @@ -49,13 +50,15 @@ class MasterSession : public core::RefCounted { MasterSession( const SessionOptions& options, const MasterEnv* env, std::unique_ptr>> remote_devs, + std::unique_ptr worker_cache, + std::unique_ptr device_set, StatsPublisherFactory stats_publisher_factory); // Initialize the MasterSession for "def". Must be called before Extend(), // Run(), or Close(). // // After this method returns, `def` will no longer be valid. - Status Create(GraphDef* def); + Status Create(GraphDef* def, const WorkerCacheFactoryOptions& options); // Returns the session handle. const string& handle() const { return handle_; } @@ -107,8 +110,14 @@ class MasterSession : public core::RefCounted { std::unique_ptr>> remote_devs_; + // The optional session-specific worker cluster. + // TODO(saeta): Convert to std::optional when available. + std::unique_ptr worker_cache_; + // Retrieves either worker_cache_ or the env_->worker_cache as appropriate. + WorkerCacheInterface* get_worker_cache() const; + // The device set used by this session. - DeviceSet devices_; + std::unique_ptr devices_; StatsPublisherFactory stats_publisher_factory_; @@ -181,6 +190,13 @@ class MasterSession : public core::RefCounted { // Private dtor. The client must call Close(). virtual ~MasterSession(); + // Creates sessions on all workers. + // + // If this session is operating using the new ClusterSpec propagation behavior + // call this method in order to propagate the cluster membership to all + // workers. + Status CreateWorkerSessions(const WorkerCacheFactoryOptions& server_def); + Status StartStep(const BuildGraphOptions& opts, int64* count, ReffedClientGraph** graph, bool is_partial); void ClearRunsTable(std::vector* to_unref, @@ -190,6 +206,7 @@ class MasterSession : public core::RefCounted { MutableRunStepResponseWrapper* resp); Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req, MutableRunStepResponseWrapper* resp); + void MarkRunCompletion(); void UpdateLastAccessTime(); Status BuildAndRegisterPartitions(ReffedClientGraph* rcg); diff --git a/tensorflow/core/distributed_runtime/message_wrappers.cc b/tensorflow/core/distributed_runtime/message_wrappers.cc index 7b58feb93cc..b077975ea50 100644 --- a/tensorflow/core/distributed_runtime/message_wrappers.cc +++ b/tensorflow/core/distributed_runtime/message_wrappers.cc @@ -252,6 +252,14 @@ string ProtoRunStepRequest::DebugString() const { const RunStepRequest& ProtoRunStepRequest::ToProto() const { return *request_; } +const string& InMemoryRunGraphRequest::session_handle() const { + return session_handle_; +} + +void InMemoryRunGraphRequest::set_session_handle(const string& handle) { + session_handle_ = handle; +} + const string& InMemoryRunGraphRequest::graph_handle() const { return graph_handle_; } @@ -320,6 +328,7 @@ void InMemoryRunGraphRequest::set_is_last_partial_run( const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const { if (!proto_version_) { proto_version_.reset(new RunGraphRequest); + proto_version_->set_session_handle(session_handle()); proto_version_->set_graph_handle(graph_handle()); proto_version_->set_step_id(step_id()); *proto_version_->mutable_exec_opts() = exec_opts(); @@ -337,6 +346,14 @@ const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const { return *proto_version_; } +const string& MutableProtoRunGraphRequest::session_handle() const { + return request_.session_handle(); +} + +void MutableProtoRunGraphRequest::set_session_handle(const string& handle) { + request_.set_session_handle(handle); +} + const string& MutableProtoRunGraphRequest::graph_handle() const { return request_.graph_handle(); } @@ -423,6 +440,10 @@ const RunGraphRequest& MutableProtoRunGraphRequest::ToProto() const { ProtoRunGraphRequest::ProtoRunGraphRequest(const RunGraphRequest* request) : request_(request) {} +const string& ProtoRunGraphRequest::session_handle() const { + return request_->session_handle(); +} + const string& ProtoRunGraphRequest::graph_handle() const { return request_->graph_handle(); } diff --git a/tensorflow/core/distributed_runtime/message_wrappers.h b/tensorflow/core/distributed_runtime/message_wrappers.h index 02516eabb4a..795a6add0e7 100644 --- a/tensorflow/core/distributed_runtime/message_wrappers.h +++ b/tensorflow/core/distributed_runtime/message_wrappers.h @@ -223,6 +223,10 @@ class RunGraphRequestWrapper { public: virtual ~RunGraphRequestWrapper() {} + // The session handle used to register the graph. If empty, a single global + // namespace is used. + virtual const string& session_handle() const = 0; + // REQUIRED: graph_handle must be returned by a RegisterGraph call // to the same WorkerService. virtual const string& graph_handle() const = 0; @@ -262,6 +266,7 @@ class RunGraphRequestWrapper { // See `RunGraphRequestWrapper` above for a description of the fields. class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper { public: + virtual void set_session_handle(const string& handle) = 0; virtual void set_graph_handle(const string& handle) = 0; virtual void set_step_id(int64 step_id) = 0; virtual ExecutorOpts* mutable_exec_opts() = 0; @@ -280,6 +285,7 @@ class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper { class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { public: // RunGraphRequestWrapper methods. + const string& session_handle() const override; const string& graph_handle() const override; int64 step_id() const override; const ExecutorOpts& exec_opts() const override; @@ -293,6 +299,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { const RunGraphRequest& ToProto() const override; // MutableRunGraphRequestWrapper methods. + void set_session_handle(const string& handle) override; void set_graph_handle(const string& handle) override; void set_step_id(int64 step_id) override; ExecutorOpts* mutable_exec_opts() override; @@ -304,6 +311,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { void set_is_last_partial_run(bool is_last_partial_run) override; private: + string session_handle_; string graph_handle_; int64 step_id_; ExecutorOpts exec_opts_; @@ -325,6 +333,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper { public: // RunGraphRequestWrapper methods. + const string& session_handle() const override; const string& graph_handle() const override; int64 step_id() const override; const ExecutorOpts& exec_opts() const override; @@ -338,6 +347,7 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper { const RunGraphRequest& ToProto() const override; // MutableRunGraphRequestWrapper methods. + void set_session_handle(const string& handle) override; void set_graph_handle(const string& handle) override; void set_step_id(int64 step_id) override; ExecutorOpts* mutable_exec_opts() override; @@ -357,6 +367,7 @@ class ProtoRunGraphRequest : public RunGraphRequestWrapper { ProtoRunGraphRequest(const RunGraphRequest* request); // RunGraphRequestWrapper methods. + const string& session_handle() const override; const string& graph_handle() const override; int64 step_id() const override; const ExecutorOpts& exec_opts() const override; diff --git a/tensorflow/core/distributed_runtime/remote_device.cc b/tensorflow/core/distributed_runtime/remote_device.cc index 9632e9c4398..91c1fb99fef 100644 --- a/tensorflow/core/distributed_runtime/remote_device.cc +++ b/tensorflow/core/distributed_runtime/remote_device.cc @@ -16,11 +16,13 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/remote_device.h" #include + #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_interface.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/protobuf/worker.pb.h" @@ -43,8 +45,7 @@ string GetLocalDeviceName(StringPiece fullname) { class RemoteDevice : public Device { public: RemoteDevice(Env* env, const DeviceAttributes& da) - : Device(env, da, nullptr), - local_dev_name_(GetLocalDeviceName(da.name())) {} + : Device(env, da), local_dev_name_(GetLocalDeviceName(da.name())) {} Status Sync() override { return Status::OK(); } Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; } @@ -68,18 +69,50 @@ void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache, GetStatusResponse resp; }; Call* call = new Call; - auto cb = [env, worker_cache, worker_name, done, wi, call](const Status& s) { + auto cb = [env, worker_cache, worker_name, done, wi, + call](const Status& status) { + Status s = status; std::vector remote_devices; + auto cleanup = gtl::MakeCleanup( + [&worker_cache, &worker_name, &wi, &done, &remote_devices, &s, call] { + worker_cache->ReleaseWorker(worker_name, wi); + done(s, &remote_devices); + delete call; + }); if (s.ok()) { + DeviceNameUtils::ParsedName worker_name_parsed; + if (!DeviceNameUtils::ParseFullName(worker_name, &worker_name_parsed) || + !worker_name_parsed.has_job || !worker_name_parsed.has_replica || + !worker_name_parsed.has_task) { + s = errors::InvalidArgument("Could not parse worker name: ", + worker_name); + LOG(WARNING) << s; + return; + } remote_devices.reserve(call->resp.device_attributes_size()); for (const DeviceAttributes& da : call->resp.device_attributes()) { - auto d = new RemoteDevice(env, da); - remote_devices.push_back(d); + DeviceNameUtils::ParsedName device_name_parsed; + CHECK(DeviceNameUtils::ParseFullName(da.name(), &device_name_parsed)) + << "Device attribute name '" << da.name() << "' could not be " + << "parsed. Device Attribute: " << da.DebugString(); + // Preserve the exact name, if possible. + // TODO(b/37868888): Simplify when legacy device name formats removed. + if (device_name_parsed.job == worker_name_parsed.job && + device_name_parsed.replica == worker_name_parsed.replica && + device_name_parsed.task == worker_name_parsed.task) { + auto d = new RemoteDevice(env, da); + remote_devices.push_back(d); + } else { + DeviceAttributes da_rewritten = da; + da_rewritten.set_name(DeviceNameUtils::FullName( + worker_name_parsed.job, worker_name_parsed.replica, + worker_name_parsed.task, device_name_parsed.type, + device_name_parsed.id)); + auto d = new RemoteDevice(env, da_rewritten); + remote_devices.push_back(d); + } } } - worker_cache->ReleaseWorker(worker_name, wi); - done(s, &remote_devices); - delete call; }; wi->GetStatusAsync(&call->req, &call->resp, cb); } diff --git a/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h b/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h index 04c1fc248ef..43267d4362f 100644 --- a/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h +++ b/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h @@ -25,6 +25,23 @@ limitations under the License. namespace tensorflow { +struct WorkerSession; + +// RemoteRendezvous follow a 2-part initialization. First the objects are +// constructed. Eventually, they will be initialized. Clients of the +// RendezvousMgrInterface must guarantee to call Initialize on the returned +// RemoteRendezvous eventually. +// +// Partially initialized RemoteRendezvous must respect the Rendezvous interface +// (i.e. Send() must never block), however implementations are not expected to +// actually perform the underlying operations until after the RemoteRendezvous +// has been Initialize'd. +class RemoteRendezvous : public Rendezvous { + public: + // Fully construct the RemoteRendezvous. + virtual Status Initialize(WorkerSession* session) = 0; +}; + // RendezvousMgr keeps track of a set of local rendezvous instances. // All tensors sent by this worker are buffered in a RendezvousMgr // until the tensor is received. Each global unique "step_id" @@ -51,7 +68,10 @@ class RendezvousMgrInterface { // Returns Rendezvous supporting send and recv among workers in the // "step_id". The caller takes ownership of one reference on the // returned Rendezvous instance. - virtual Rendezvous* Find(int64 step_id) = 0; + // + // Note: the caller must guarantee to eventually call Initialize on the + // returned RemoteRendezvous + virtual RemoteRendezvous* Find(int64 step_id) = 0; // Finds the local rendezvous instance for the "step_id". Runs // "done" when the tensor for "key" is produced or an error occurs. diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index 7160962b168..3867dd1f4d0 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -63,10 +63,8 @@ class NoReusePortOption : public ::grpc::ServerBuilderOption { }; // static utility function -RendezvousMgrInterface* NewRpcRendezvousMgr( - const WorkerEnv* env, const string& worker_name, - WorkerCacheInterface* worker_cache) { - return new RpcRendezvousMgr(env, worker_name, worker_cache); +RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) { + return new RpcRendezvousMgr(env); } } // namespace @@ -84,6 +82,9 @@ GrpcServer::~GrpcServer() { // TODO(mrry): Refactor the *Env classes so that it is less fiddly // to destroy them. + // Shut down all outstanding rendezvous. + delete worker_env_.rendezvous_mgr; + // We must delete graph_mgr before device_mgr, due to shared // ownership of OpKernels in the executors. (The graph_mgr will // free all stateless OpKernels, and pass over borrowed stateful @@ -91,8 +92,10 @@ GrpcServer::~GrpcServer() { // OpSegments.) if (worker_env_.session_mgr != nullptr) { delete worker_env_.session_mgr; // Deletes graph_mgr's. + } else { + // Note: session_mgr's legacy_session_ deletes device_mgr now. + delete worker_env_.device_mgr; } - delete worker_env_.device_mgr; // Do not delete (as these are not owned by the server): // - master_env_.env @@ -100,8 +103,9 @@ GrpcServer::~GrpcServer() { // - worker_env_.compute_pool } -Status GrpcServer::Init(ServiceInitFunction service_func, - RendezvousMgrCreationFunction rendevous_mgr_func) { +Status GrpcServer::Init( + ServiceInitFunction service_func, + const RendezvousMgrCreationFunction& rendezvous_mgr_func) { mutex_lock l(mu_); CHECK_EQ(state_, NEW); master_env_.env = env_; @@ -117,7 +121,11 @@ Status GrpcServer::Init(ServiceInitFunction service_func, "/task:", server_def_.task_index()); TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(sess_opts, name_prefix, &master_env_.local_devices)); - worker_env_.device_mgr = new DeviceMgr(master_env_.local_devices); + worker_env_.local_devices = master_env_.local_devices; + worker_env_.device_mgr = new DeviceMgr(worker_env_.local_devices); + worker_env_.rendezvous_mgr = rendezvous_mgr_func == nullptr + ? new RpcRendezvousMgr(&worker_env_) + : rendezvous_mgr_func(&worker_env_); string unused; string default_worker_name; if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(), @@ -189,20 +197,18 @@ Status GrpcServer::Init(ServiceInitFunction service_func, } WorkerCacheInterface* worker_cache; - TF_RETURN_IF_ERROR(WorkerCacheFactory(server_def_, &worker_cache)); + WorkerCacheFactoryOptions worker_cache_factory_options(server_def_); + TF_RETURN_IF_ERROR( + WorkerCacheFactory(worker_cache_factory_options, &worker_cache)); CHECK_NE(nullptr, worker_cache); // Set up worker environment. - std::unique_ptr rendezvous_mgr( - rendevous_mgr_func == nullptr ? - new RpcRendezvousMgr(&worker_env_, name_prefix, worker_cache) : - rendevous_mgr_func(&worker_env_, name_prefix, worker_cache)); worker_env_.session_mgr = new SessionMgr( &worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_), std::unique_ptr(worker_cache), - std::move(rendezvous_mgr), [this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) { - return WorkerCacheFactory(server_def, worker_cache); + WorkerCacheFactoryOptions options(server_def); + return WorkerCacheFactory(options, worker_cache); }); worker_env_.compute_pool = ComputePool(sess_opts); @@ -212,11 +218,19 @@ Status GrpcServer::Init(ServiceInitFunction service_func, master_env_.master_session_factory = [config]( SessionOptions options, const MasterEnv* env, - std::unique_ptr>> remote_devs) { + std::unique_ptr>> remote_devs, + std::unique_ptr worker_cache, + std::unique_ptr device_set) { options.config.MergeFrom(config); return new MasterSession(options, env, std::move(remote_devs), + std::move(worker_cache), std::move(device_set), CreateNoOpStatsPublisher); }; + master_env_.worker_cache_factory = + [this](const WorkerCacheFactoryOptions& options, + WorkerCacheInterface** worker_cache) { + return WorkerCacheFactory(options, worker_cache); + }; // Provide direct access to the master from in-process clients. LocalMaster::Register(target(), master_impl_.get(), @@ -225,13 +239,11 @@ Status GrpcServer::Init(ServiceInitFunction service_func, return Status::OK(); } -Status GrpcServer::Init() { - return Init(nullptr, nullptr); -} +Status GrpcServer::Init() { return Init(nullptr, nullptr); } -Status GrpcServer::ParseChannelSpec(const ServerDef& server_def, +Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options, GrpcChannelSpec* channel_spec) { - for (const auto& job : server_def.cluster().job()) { + for (const auto& job : options.cluster_def->job()) { std::map host_ports; for (const auto& task : job.tasks()) { string& host_port = host_ports[task.first]; @@ -241,8 +253,7 @@ Status GrpcServer::ParseChannelSpec(const ServerDef& server_def, task.first, "\": ", host_port, " and ", task.second); } - if (job.name() == server_def.job_name() && - task.first == server_def.task_index()) { + if (job.name() == *options.job_name && task.first == options.task_index) { host_port = strings::StrCat("localhost:", bound_port_); } else { host_port = task.second; @@ -253,17 +264,26 @@ Status GrpcServer::ParseChannelSpec(const ServerDef& server_def, return Status::OK(); } -Status GrpcServer::WorkerCacheFactory(const ServerDef& server_def, +Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options, WorkerCacheInterface** worker_cache) { - string name_prefix = - strings::StrCat("/job:", server_def.job_name(), "/replica:0", - "/task:", server_def.task_index()); + if (options.job_name == nullptr || options.job_name->empty()) { + Status s = errors::InvalidArgument( + "The master (current machine) is not included in the provided " + "cluster_def. ", + options.cluster_def->DebugString()); + LOG(WARNING) << s; + return s; + } GrpcChannelSpec channel_spec; - TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec)); + TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec)); + + std::unique_ptr channel_cache( + NewGrpcChannelCache(channel_spec, GetChannelCreationFunction())); + + string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0", + "/task:", options.task_index); - std::unique_ptr channel_cache(NewGrpcChannelCache( - channel_spec, GetChannelCreationFunction(server_def))); const string host_port = channel_cache->TranslateTask(name_prefix); int requested_port; @@ -349,8 +369,7 @@ std::shared_ptr<::grpc::ServerCredentials> GrpcServer::GetServerCredentials( return ::grpc::InsecureServerCredentials(); } -ChannelCreationFunction GrpcServer::GetChannelCreationFunction( - const ServerDef& server_def) const { +ChannelCreationFunction GrpcServer::GetChannelCreationFunction() const { // We can do this because SparseGrpcChannelCache is robust to nullptr being // returned by the channel creation function return ConvertToChannelCreationFunction(NewHostPortGrpcChannel); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h index 3b66291a9ab..7b54bb84c88 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h @@ -37,9 +37,7 @@ class GrpcWorker; class Master; // function that creates a RendezvousMgr. -typedef std::function +typedef std::function RendezvousMgrCreationFunction; // function that registers a service to the server. The service needs to @@ -67,7 +65,7 @@ class GrpcServer : public ServerInterface { protected: Status Init(ServiceInitFunction service_func, - RendezvousMgrCreationFunction rendezvous_mgr_func); + const RendezvousMgrCreationFunction& rendezvous_mgr_func); Status Init(); @@ -75,17 +73,16 @@ class GrpcServer : public ServerInterface { virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials( const ServerDef& server_def) const; - virtual ChannelCreationFunction GetChannelCreationFunction( - const ServerDef& server_def) const; + virtual ChannelCreationFunction GetChannelCreationFunction() const; virtual std::unique_ptr CreateMaster(MasterEnv* master_env); // Creates a WorkerCacheInterface for a session. - Status WorkerCacheFactory(const ServerDef& server_def, + Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options, WorkerCacheInterface** worker_cache); - // Parses a ServerDef into a GrpcChannelSpec. - Status ParseChannelSpec(const ServerDef& server_def, + // Parses a WorkerCacheFactoryOptions into a GrpcChannelSpec. + Status ParseChannelSpec(const WorkerCacheFactoryOptions& options, GrpcChannelSpec* channel_spec); // Returns the port to which this server is bound. diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc index 1aacef8a26a..38d59d5bb59 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc @@ -43,7 +43,7 @@ const size_t kSchemePrefixLength = strlen(kSchemePrefix); /* static */ Status GrpcSession::Create(const SessionOptions& options, std::unique_ptr* out_session) { - std::unique_ptr ret(new GrpcSession(options)); + std::unique_ptr session(new GrpcSession(options)); std::unique_ptr master; // For testing, we enable the client to disable the use of the local // master registry, so that the RPC stack is exercised. @@ -56,8 +56,8 @@ Status GrpcSession::Create(const SessionOptions& options, options.target.substr(kSchemePrefixLength), &master_channel)); master.reset(NewGrpcMaster(master_channel)); } - ret->SetRemoteMaster(std::move(master)); - *out_session = std::move(ret); + session->SetRemoteMaster(std::move(master)); + *out_session = std::move(session); return Status::OK(); } @@ -102,6 +102,7 @@ Status GrpcSession::CreateImpl(CallOptions* call_options, CreateSessionRequest req; *req.mutable_config() = options_.config; *req.mutable_graph_def() = graph; + req.set_target(options_.target); ReEncodeConsts(req.mutable_graph_def()); CreateSessionResponse resp; Status s = master_->CreateSession(call_options, &req, &resp); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index c11266587d8..873ef8588f4 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -113,6 +113,7 @@ class GrpcWorkerService : public AsyncServiceInterface { // completes, and we may decide to bound some of the request // types. ENQUEUE_REQUEST(GetStatus, false); + ENQUEUE_REQUEST(CreateWorkerSession, false); ENQUEUE_REQUEST(CleanupAll, false); ENQUEUE_REQUEST(RegisterGraph, false); ENQUEUE_REQUEST(DeregisterGraph, false); @@ -181,6 +182,16 @@ class GrpcWorkerService : public AsyncServiceInterface { ENQUEUE_REQUEST(GetStatus, false); } + void CreateWorkerSessionHandler( + WorkerCall* + call) { + Schedule([this, call]() { + Status s = worker_->CreateWorkerSession(&call->request, &call->response); + call->SendResponse(ToGrpcStatus(s)); + }); + ENQUEUE_REQUEST(CreateWorkerSession, false); + } + void CleanupAllHandler( WorkerCall* call) { Schedule([this, call]() { @@ -298,7 +309,6 @@ void GrpcWorker::RecvTensorAsync(CallOptions* opts, ::grpc::ByteBuffer* response, StatusCallback done) { const int64 step_id = request->step_id(); - WorkerSession* session = env_->session_mgr->WorkerSessionForStepId(step_id); const string& key = request->rendezvous_key(); TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str()); Rendezvous::ParsedKey parsed; @@ -317,7 +327,7 @@ void GrpcWorker::RecvTensorAsync(CallOptions* opts, // of execution of the callback lambda body below, an RPC // cancellation should abort the rendezvous. opts->SetCancelCallback([this, step_id]() { AbortStep(step_id); }); - session->rendezvous_mgr->RecvLocalAsync( + env_->rendezvous_mgr->RecvLocalAsync( step_id, parsed, [opts, response, done, src_dev](const Status& status, const Rendezvous::Args& send_args, diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc index 7518a289fdb..8265100061e 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc @@ -38,9 +38,8 @@ namespace { class RpcRemoteRendezvous : public BaseRemoteRendezvous { public: - RpcRemoteRendezvous(const WorkerEnv* env, const string& worker_name, - WorkerCacheInterface* cache, int64 step_id) - : BaseRemoteRendezvous(env, worker_name, step_id, false), cache_(cache) {} + RpcRemoteRendezvous(const WorkerEnv* env, int64 step_id) + : BaseRemoteRendezvous(env, step_id, false) {} protected: void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, @@ -50,7 +49,6 @@ class RpcRemoteRendezvous : public BaseRemoteRendezvous { private: ~RpcRemoteRendezvous() override {} - WorkerCacheInterface* const cache_; // Not owned. TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous); }; @@ -204,75 +202,10 @@ static RpcRecvTensorFreeList* get_call_freelist() { return call_freelist; } -// A private cache that wraps worker_cache and allows reuse of -// WorkerInterface objects. -class WorkerFreeListCache : public WorkerCacheInterface { - public: - explicit WorkerFreeListCache(WorkerCacheInterface* w) : wrapped_(w) {} - - ~WorkerFreeListCache() { - for (auto p : workers_) { - wrapped_->ReleaseWorker(p.first, p.second.worker); - } - } - - void ListWorkers(std::vector* workers) const override { - wrapped_->ListWorkers(workers); - } - - WorkerInterface* CreateWorker(const string& target) override { - mutex_lock l(mu_); - auto p = workers_.find(target); - if (p != workers_.end()) { - return p->second.worker; - } - WorkerState state; - state.worker = wrapped_->CreateWorker(target); - if (state.worker != nullptr) { - workers_.insert(std::make_pair(target, state)); - } - return state.worker; - } - - void ReleaseWorker(const string& target, WorkerInterface* worker) override { - // TODO(jeff,sanjay): Should decrement ref-count when we implement eviction. - } - - bool GetDeviceLocalityNonBlocking(const string& device, - DeviceLocality* locality) override { - return wrapped_->GetDeviceLocalityNonBlocking(device, locality); - } - - void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality, - StatusCallback done) override { - wrapped_->GetDeviceLocalityAsync(device, locality, done); - } - - void SetLogging(bool active) override { wrapped_->SetLogging(active); } - - void ClearLogs() override { wrapped_->ClearLogs(); } - - bool RetrieveLogs(int64 step_id, StepStats* ss) override { - return wrapped_->RetrieveLogs(step_id, ss); - } - - private: - WorkerCacheInterface* wrapped_; - - // Information kept per created WorkerInterface. - struct WorkerState { - WorkerInterface* worker; - // TODO(jeff,sanjay): Add reference count if we support eviction. - }; - - // TODO(jeff,sanjay): Eviction when the map becomes too big. - mutex mu_; - std::unordered_map workers_ GUARDED_BY(mu_); -}; - void RpcRemoteRendezvous::RecvFromRemoteAsync( const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args, DoneCallback done) { + CHECK(is_initialized()); Status s; // Prepare a RecvTensor call that can handle being aborted. @@ -284,17 +217,21 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync( s = errors::Internal(parsed.src_device, " is invalid remote source device."); } - WorkerInterface* rwi = cache_->CreateWorker(call->src_worker_); + WorkerSession* sess = session(); + WorkerInterface* rwi = sess->worker_cache->CreateWorker(call->src_worker_); if (s.ok() && rwi == nullptr) { s = errors::Internal("No worker known as ", call->src_worker_); } Device* dst_device; if (s.ok()) { - s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device); + s = sess->device_mgr->LookupDevice(parsed.dst_device, &dst_device); } if (!s.ok()) { - get_call_freelist()->Release(call, cache_); + if (rwi != nullptr) { + sess->worker_cache->ReleaseWorker(call->src_worker_, rwi); + } + get_call_freelist()->Release(call, sess->worker_cache.get()); done(s, Args(), recv_args, Tensor{}, false); return; } @@ -314,26 +251,21 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync( // current status should be bad. Status s = call->status(); call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead()); - cache_->ReleaseWorker(call->src_worker_, call->wi_); + session()->worker_cache->ReleaseWorker(call->src_worker_, call->wi_); call->wi_ = nullptr; - get_call_freelist()->Release(call, cache_); + get_call_freelist()->Release(call, session()->worker_cache.get()); Unref(); }); } } // namespace -RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env, - const string& worker_name, - WorkerCacheInterface* worker_cache) - : BaseRendezvousMgr(env, worker_name), - cache_(new WorkerFreeListCache(worker_cache)) {} +RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env) + : BaseRendezvousMgr(env) {} BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id, - const WorkerEnv* worker_env, - const string& worker_name) { - return new RpcRemoteRendezvous(worker_env, worker_name, cache_.get(), - step_id); + const WorkerEnv* worker_env) { + return new RpcRemoteRendezvous(worker_env, step_id); } } // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h index 75dc62d98fd..34c48a79177 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h @@ -17,13 +17,13 @@ limitations under the License. #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_ #include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h" -#include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_env.h" -#include "tensorflow/core/distributed_runtime/worker_session.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { +class DeviceMgr; + // RendezvousMgr keeps track of a set of local rendezvous instances. // All tensors sent by this worker are buffered in a RendezvousMgr // until the tensor is received. Each global unique "step_id" @@ -44,17 +44,12 @@ namespace tensorflow { // RendezvousMgr must have keys generated by Rendezvous::CreateKey. class RpcRendezvousMgr : public BaseRendezvousMgr { public: - explicit RpcRendezvousMgr(const WorkerEnv* env, const string& worker_name, - WorkerCacheInterface* worker_cache); + explicit RpcRendezvousMgr(const WorkerEnv* env); protected: - BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env, - const string& session_name) override; + BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env); private: - // Private cache_ that allows us to reuse WorkerInterface objects. - std::unique_ptr cache_; - TF_DISALLOW_COPY_AND_ASSIGN(RpcRendezvousMgr); }; diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc index 9b778eab3a5..2d0d76623d4 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc @@ -68,9 +68,9 @@ class RpcRendezvousMgrTest : public ::testing::Test { : cache_(new DummyWorkerCache), worker_session_("/job:mnist/replica:1/task:2", std::unique_ptr(cache_), - std::unique_ptr(), + std::unique_ptr(), std::unique_ptr()), - rmgr_(&env, worker_session_.worker_name, cache_) { + rmgr_(&env) { env.env = Env::Default(); } @@ -87,7 +87,8 @@ TEST_F(RpcRendezvousMgrTest, LocalSendRecv) { "/job:mnist/replica:1/task:2/cpu:0", 7890, "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); { - Rendezvous* rendez = rmgr_.Find(step_id); + RemoteRendezvous* rendez = rmgr_.Find(step_id); + TF_ASSERT_OK(rendez->Initialize(&worker_session_)); core::ScopedUnref unref(rendez); Rendezvous::Args args; TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false)); @@ -107,7 +108,7 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) { "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); { // Explicit Abort(). const int64 step_id = 123; - Rendezvous* rendez = rmgr_.Find(step_id); + RemoteRendezvous* rendez = rmgr_.Find(step_id); core::ScopedUnref unref(rendez); SchedClosure([this, rendez]() { env.env->SleepForMicroseconds(100 * 1000); @@ -116,11 +117,12 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) { Tensor val(DT_STRING); bool val_dead = false; Rendezvous::Args args; + TF_ASSERT_OK(rendez->Initialize(&worker_session_)); EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead))); } { // Cleanup causes Abort(). const int64 step_id = 321; - Rendezvous* rendez = rmgr_.Find(step_id); + RemoteRendezvous* rendez = rmgr_.Find(step_id); core::ScopedUnref unref(rendez); SchedClosure([this, step_id]() { env.env->SleepForMicroseconds(100 * 1000); @@ -129,6 +131,7 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) { Tensor val(DT_STRING); bool val_dead = false; Rendezvous::Args args; + TF_ASSERT_OK(rendez->Initialize(&worker_session_)); EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead))); } } @@ -139,7 +142,8 @@ TEST_F(RpcRendezvousMgrTest, CleanupAll) { "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); { const int64 step_id = 123; - Rendezvous* rendez = rmgr_.Find(step_id); + RemoteRendezvous* rendez = rmgr_.Find(step_id); + TF_ASSERT_OK(rendez->Initialize(&worker_session_)); core::ScopedUnref unref(rendez); Rendezvous::Args args; TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false)); @@ -168,10 +172,11 @@ TEST_F(RpcRendezvousMgrTest, TransferDummyDeviceContext) { "/job:mnist/replica:1/task:2/cpu:0", 7890, "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); { - Rendezvous* rendez = rmgr_.Find(step_id); + RemoteRendezvous* rendez = rmgr_.Find(step_id); core::ScopedUnref unref(rendez); Rendezvous::Args args; args.device_context = dc; + TF_ASSERT_OK(rendez->Initialize(&worker_session_)); TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false)); } { diff --git a/tensorflow/core/distributed_runtime/session_mgr.cc b/tensorflow/core/distributed_runtime/session_mgr.cc index e2be62f816c..22551d54821 100644 --- a/tensorflow/core/distributed_runtime/session_mgr.cc +++ b/tensorflow/core/distributed_runtime/session_mgr.cc @@ -17,8 +17,9 @@ limitations under the License. #include +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/renamed_device.h" #include "tensorflow/core/distributed_runtime/graph_mgr.h" -#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" #include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { @@ -26,23 +27,12 @@ namespace tensorflow { SessionMgr::SessionMgr( WorkerEnv* worker_env, const string& default_worker_name, std::unique_ptr default_worker_cache, - std::unique_ptr default_rendezvous_mgr, - WorkerCacheFactory worker_cache_factory) - : SessionMgr( - worker_env, default_worker_name, std::move(default_worker_cache), - default_rendezvous_mgr.release(), std::move(worker_cache_factory)) {} - -SessionMgr::SessionMgr( - WorkerEnv* worker_env, const string& default_worker_name, - std::unique_ptr default_worker_cache, - RendezvousMgrInterface* default_rendezvous_mgr, WorkerCacheFactory worker_cache_factory) : worker_env_(worker_env), - legacy_session_( - default_worker_name, std::move(default_worker_cache), - std::unique_ptr(default_rendezvous_mgr), - std::unique_ptr( - new GraphMgr(worker_env, default_rendezvous_mgr))), + legacy_session_(default_worker_name, std::move(default_worker_cache), + std::unique_ptr(worker_env->device_mgr), + std::unique_ptr( + new GraphMgr(worker_env, worker_env->device_mgr))), worker_cache_factory_(std::move(worker_cache_factory)) {} string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) { @@ -53,20 +43,28 @@ string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) { Status SessionMgr::CreateSession(const string& session, const ServerDef& server_def) { mutex_lock l(mu_); + if (session.empty()) { + return errors::InvalidArgument("Session must be non-empty."); + } + const string worker_name = WorkerNameFromServerDef(server_def); WorkerCacheInterface* worker_cache = nullptr; TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache)); - std::unique_ptr rendezvous_mgr( - new RpcRendezvousMgr(worker_env_, worker_name, worker_cache)); + std::vector renamed_devices; + for (Device* d : worker_env_->local_devices) { + renamed_devices.push_back( + RenamedDevice::NewRenamedDevice(worker_name, d, false)); + } + std::unique_ptr device_mgr(new DeviceMgr(renamed_devices)); std::unique_ptr graph_mgr( - new GraphMgr(worker_env_, rendezvous_mgr.get())); + new GraphMgr(worker_env_, device_mgr.get())); std::unique_ptr worker_session(new WorkerSession( worker_name, std::unique_ptr(worker_cache), - std::move(rendezvous_mgr), std::move(graph_mgr))); + std::move(device_mgr), std::move(graph_mgr))); sessions_.insert(std::make_pair(session, std::move(worker_session))); return Status::OK(); @@ -78,22 +76,6 @@ Status SessionMgr::DeleteSession(const string& session) { if (it != sessions_.end()) { sessions_.erase(it); } - std::set graph_handles; - for (auto graph_handle_it = sessions_by_graph_handle_.begin(); - graph_handle_it != sessions_by_graph_handle_.end(); ++graph_handle_it) { - if (graph_handle_it->second == session) { - graph_handles.insert(graph_handle_it->first); - graph_handle_it = sessions_by_graph_handle_.erase(graph_handle_it); - if (graph_handle_it == sessions_by_graph_handle_.end()) break; - } - } - for (auto step_id_it = graphs_by_step_id_.begin(); - step_id_it != graphs_by_step_id_.end(); ++step_id_it) { - if (graph_handles.find(step_id_it->second) != graph_handles.end()) { - step_id_it = graphs_by_step_id_.erase(step_id_it); - if (step_id_it == graphs_by_step_id_.end()) break; - } - } return Status::OK(); } @@ -114,58 +96,4 @@ WorkerSession* SessionMgr::WorkerSessionForSession(const string& session) { WorkerSession* SessionMgr::LegacySession() { return &legacy_session_; } -WorkerSession* SessionMgr::WorkerSessionForGraphHandleUnlocked( - const string& graph_handle) { - auto it = sessions_by_graph_handle_.find(graph_handle); - if (it == sessions_by_graph_handle_.end()) { - return &legacy_session_; - } else { - return WorkerSessionForSessionUnlocked(it->second); - } -} - -WorkerSession* SessionMgr::WorkerSessionForGraphHandle( - const string& graph_handle) { - mutex_lock l(mu_); - return WorkerSessionForGraphHandleUnlocked(graph_handle); -} - -WorkerSession* SessionMgr::WorkerSessionForStepId(const int64 step_id) { - mutex_lock l(mu_); - auto it = graphs_by_step_id_.find(step_id); - if (it == graphs_by_step_id_.end()) { - return &legacy_session_; - } else { - return WorkerSessionForGraphHandleUnlocked(it->second); - } -} - -void SessionMgr::AssociateGraphWithSession(const string& session, - const string& graph_handle) { - mutex_lock l(mu_); - sessions_by_graph_handle_[graph_handle] = session; -} - -void SessionMgr::DisassociateGraphFromSession(const string& graph_handle) { - mutex_lock l(mu_); - auto it = sessions_by_graph_handle_.find(graph_handle); - if (it != sessions_by_graph_handle_.end()) { - sessions_by_graph_handle_.erase(it); - } -} - -void SessionMgr::AssociateStepIdWithGraph(const string& graph_handle, - const int64 step_id) { - mutex_lock l(mu_); - graphs_by_step_id_[step_id] = graph_handle; -} - -void SessionMgr::DisassociateStepIdFromGraph(const int64 step_id) { - mutex_lock l(mu_); - auto it = graphs_by_step_id_.find(step_id); - if (it != graphs_by_step_id_.end()) { - graphs_by_step_id_.erase(it); - } -} - } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/session_mgr.h b/tensorflow/core/distributed_runtime/session_mgr.h index 455b5c8d9d9..c44bca7b7a4 100644 --- a/tensorflow/core/distributed_runtime/session_mgr.h +++ b/tensorflow/core/distributed_runtime/session_mgr.h @@ -30,6 +30,8 @@ struct WorkerEnv; // SessionMgr keeps track of information related to a given session. // +// SessionMgr runs on the workers. +// // SessionMgr is threadsafe. class SessionMgr { public: @@ -39,7 +41,6 @@ class SessionMgr { explicit SessionMgr( WorkerEnv* worker_env, const string& default_worker_name, std::unique_ptr default_worker_cache, - std::unique_ptr default_rendezvous_mgr, WorkerCacheFactory worker_cache_factory); ~SessionMgr() {} @@ -50,49 +51,36 @@ class SessionMgr { WorkerSession* WorkerSessionForSession(const string& session); WorkerSession* LegacySession(); - // Locates the worker session for a given graph handle - WorkerSession* WorkerSessionForGraphHandle(const string& graph_handle); - void AssociateGraphWithSession(const string& session, - const string& graph_handle); - void DisassociateGraphFromSession(const string& graph_handle); - - // Locates a worker session for a given step id - WorkerSession* WorkerSessionForStepId(const int64 step_id); - void AssociateStepIdWithGraph(const string& graph_handle, - const int64 step_id); - void DisassociateStepIdFromGraph(const int64 step_id); - Status DeleteSession(const string& session); static string WorkerNameFromServerDef(const ServerDef& server_def); private: - // Private constructor to work around std::unique_ptr ownership issues. - explicit SessionMgr( - WorkerEnv* worker_env, const string& default_worker_name, - std::unique_ptr default_worker_cache, - RendezvousMgrInterface* default_rendezvous_mgr, - WorkerCacheFactory worker_cache_factory); - const WorkerEnv* const worker_env_; // Not owned. + + // A note about destruction: + // We must delete graph_mgr before device_mgr, due to shared + // ownership of OpKernels in the executors. (The graph_mgr will + // free all stateless OpKernels, and pass over borrowed stateful + // OpKernels, which are also held in their respective devices' + // OpSegments.) + // + // legacy_session_ owns the worker_env_.device_mgr, and so we must ensure + // that sessions_'s WorkerSessions are deleted (which do not own the + // underlying devices, but instead own RenamedDevices) before + // legacy_session_ is deleted. Further, we must ensure that WorkerSession's + // device_mgr is deleted after WorkerSession's graph_mgr. + WorkerSession legacy_session_; const WorkerCacheFactory worker_cache_factory_; WorkerSession* WorkerSessionForSessionUnlocked(const string& session) EXCLUSIVE_LOCKS_REQUIRED(mu_); - WorkerSession* WorkerSessionForGraphHandleUnlocked(const string& graph_handle) - EXCLUSIVE_LOCKS_REQUIRED(mu_); mutex mu_; // A map from session identifier to internal session structure. std::map> sessions_ GUARDED_BY(mu_); - - // A map from graph handles to the session that they belong to. - std::map sessions_by_graph_handle_ GUARDED_BY(mu_); - - // A map from globally-unique step id's to the corresponding graph handles. - std::map graphs_by_step_id_ GUARDED_BY(mu_); }; } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/session_mgr_test.cc b/tensorflow/core/distributed_runtime/session_mgr_test.cc index d3f3fa83958..7132f123a59 100644 --- a/tensorflow/core/distributed_runtime/session_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/session_mgr_test.cc @@ -27,8 +27,6 @@ class SessionMgrTest : public ::testing::Test { SessionMgrTest() : mgr_(&env_, "/job:mnist/replica:0/task:0", std::unique_ptr(), - std::unique_ptr(new RpcRendezvousMgr( - &env_, "/job:mnist/replica:0/task:0", nullptr)), factory_), legacy_session_(mgr_.WorkerSessionForSession("novel_session_id")) {} @@ -48,90 +46,19 @@ TEST_F(SessionMgrTest, CreateSessionSimple) { TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def)); WorkerSession* session = mgr_.WorkerSessionForSession(session_handle); EXPECT_NE(nullptr, session) << "Session for " << session_handle << "was null"; - + EXPECT_NE(mgr_.LegacySession(), session); TF_EXPECT_OK(mgr_.DeleteSession(session_handle)); } -TEST_F(SessionMgrTest, AssociateGraphWithSession) { +TEST_F(SessionMgrTest, LegacySession) { ServerDef server_def; - string session_handle = "test_session_handle"; - TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def)); + string session_handle = ""; WorkerSession* session = mgr_.WorkerSessionForSession(session_handle); - ASSERT_NE(nullptr, session) << "Session for " << session_handle << "was null"; - - string graph_handle = "test_graph_handle"; - mgr_.AssociateGraphWithSession(session_handle, graph_handle); - WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle); - ASSERT_EQ(session, graph_session); + EXPECT_EQ(mgr_.LegacySession(), session); TF_EXPECT_OK(mgr_.DeleteSession(session_handle)); } -TEST_F(SessionMgrTest, AssociateStepWithGraph) { - ServerDef server_def; - string session_handle = "test_session_handle"; - TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def)); - WorkerSession* session = mgr_.WorkerSessionForSession(session_handle); - ASSERT_NE(nullptr, session) << "Session for " << session_handle << "was null"; - - string graph_handle = "test_graph_handle"; - mgr_.AssociateGraphWithSession(session_handle, graph_handle); - WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle); - ASSERT_EQ(session, graph_session); - - int64 step_id = 1234567890L; - mgr_.AssociateStepIdWithGraph(graph_handle, step_id); - WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id); - ASSERT_EQ(session, step_session); - ASSERT_EQ(graph_session, step_session); - - TF_EXPECT_OK(mgr_.DeleteSession(session_handle)); -} - -TEST_F(SessionMgrTest, AssociateGraphWithSession_MissingSession) { - string session_handle = "test_session_handle"; - string graph_handle = "test_graph_handle"; - mgr_.AssociateGraphWithSession(session_handle, graph_handle); - WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle); - ASSERT_EQ(legacy_session_, graph_session); -} - -TEST_F(SessionMgrTest, AssociateStepWithGraph_MissingGraph) { - ServerDef server_def; - string session_handle = "test_session_handle"; - TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def)); - WorkerSession* session = mgr_.WorkerSessionForSession(session_handle); - ASSERT_NE(nullptr, session) << "Session for " << session_handle << "was null"; - - string graph_handle = "test_graph_handle"; - int64 step_id = 1234567890L; - mgr_.AssociateStepIdWithGraph(graph_handle, step_id); - WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id); - ASSERT_EQ(legacy_session_, step_session); -} - -TEST_F(SessionMgrTest, AssociateStepWithGraph_MissingSession) { - string session_handle = "test_session_handle"; - string graph_handle = "test_graph_handle"; - mgr_.AssociateGraphWithSession(session_handle, graph_handle); - WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle); - ASSERT_EQ(legacy_session_, graph_session); - - int64 step_id = 1234567890L; - mgr_.AssociateStepIdWithGraph(graph_handle, step_id); - WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id); - ASSERT_EQ(legacy_session_, step_session); -} - -TEST_F(SessionMgrTest, AssociateStepWithGraph_MissingSessionAndGraph) { - string session_handle = "test_session_handle"; - string graph_handle = "test_graph_handle"; - int64 step_id = 1234567890L; - mgr_.AssociateStepIdWithGraph(graph_handle, step_id); - WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id); - ASSERT_EQ(legacy_session_, step_session); -} - TEST_F(SessionMgrTest, WorkerNameFromServerDef) { ServerDef server_def; server_def.set_job_name("worker"); diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc index 89639e21b5d..07bb17981d3 100644 --- a/tensorflow/core/distributed_runtime/worker.cc +++ b/tensorflow/core/distributed_runtime/worker.cc @@ -56,10 +56,6 @@ void Worker::RegisterGraphAsync(const RegisterGraphRequest* request, Status s = session->graph_mgr->Register( request->session_handle(), request->graph_def(), request->graph_options(), request->debug_options(), response->mutable_graph_handle()); - if (s.ok()) { - env_->session_mgr->AssociateGraphWithSession(request->session_handle(), - response->graph_handle()); - } done(s); } @@ -67,9 +63,8 @@ void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request, DeregisterGraphResponse* response, StatusCallback done) { WorkerSession* session = - env_->session_mgr->WorkerSessionForGraphHandle(request->graph_handle()); + env_->session_mgr->WorkerSessionForSession(request->session_handle()); Status s = session->graph_mgr->Deregister(request->graph_handle()); - env_->session_mgr->DisassociateGraphFromSession(request->graph_handle()); done(s); } @@ -141,8 +136,7 @@ void Worker::SetOrCallFinalCallback(const string& graph_handle, int step_id, } void Worker::AbortStep(int64 step_id) { - WorkerSession* session = env_->session_mgr->WorkerSessionForStepId(step_id); - Rendezvous* rendez = session->rendezvous_mgr->Find(step_id); + Rendezvous* rendez = env_->rendezvous_mgr->Find(step_id); SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() { // Delay a bit before aborting the step. This way, the root // cause may return first back to the client instead of this @@ -193,8 +187,7 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request, const int64 step_id = request->step_id(); TRACEPRINTF("RunGraph: %lld", step_id); WorkerSession* session = - env_->session_mgr->WorkerSessionForGraphHandle(request->graph_handle()); - env_->session_mgr->AssociateStepIdWithGraph(request->graph_handle(), step_id); + env_->session_mgr->WorkerSessionForSession(request->session_handle()); GraphMgr::NamedTensors in; GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors; Status s = PrepareRunGraph(request, &in, out); @@ -231,8 +224,8 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request, } CostGraphDef* cost_graph = response->mutable_cost_graph(); session->graph_mgr->ExecuteAsync( - request->graph_handle(), step_id, request->exec_opts(), collector, - cost_graph, cm, in, + request->graph_handle(), step_id, session, request->exec_opts(), + collector, cost_graph, cm, in, [this, step_id, response, session, cm, out, token, collector, opts, done](Status s) { if (s.ok()) { @@ -267,8 +260,8 @@ void Worker::DoPartialRunGraph(CallOptions* opts, const string& graph_handle = request->graph_handle(); TRACEPRINTF("PartialRunGraph: %lld", step_id); WorkerSession* session = - env_->session_mgr->WorkerSessionForGraphHandle(graph_handle); - env_->session_mgr->AssociateStepIdWithGraph(graph_handle, step_id); + env_->session_mgr->WorkerSessionForSession(request->session_handle()); + GraphMgr::NamedTensors in; GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors; Status s = PrepareRunGraph(request, &in, out); @@ -315,8 +308,8 @@ void Worker::DoPartialRunGraph(CallOptions* opts, [cm]() { cm->StartCancel(); }); } session->graph_mgr->ExecuteAsync( - graph_handle, step_id, request->exec_opts(), nullptr /* collector */, - nullptr /* cost_graph */, cm, in, + graph_handle, step_id, session, request->exec_opts(), + nullptr /* collector */, nullptr /* cost_graph */, cm, in, [this, token, graph_handle, step_id, cm](Status s) { { mutex_lock l(mu_); @@ -365,8 +358,7 @@ void Worker::CleanupGraphAsync(const CleanupGraphRequest* request, CleanupGraphResponse* response, StatusCallback done) { const int64 step_id = request->step_id(); - WorkerSession* session = env_->session_mgr->WorkerSessionForStepId(step_id); - session->rendezvous_mgr->Cleanup(step_id); + env_->rendezvous_mgr->Cleanup(step_id); done(Status::OK()); } @@ -394,8 +386,8 @@ void Worker::TracingAsync(const TracingRequest* request, Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed, Device** src_dev) { // Figures out which device the tensor is hosted on. - TF_RETURN_IF_ERROR( - env_->device_mgr->LookupDevice(parsed.src_device, src_dev)); + string local_name = DeviceNameUtils::LocalName(parsed.src_device); + TF_RETURN_IF_ERROR(env_->device_mgr->LookupDevice(local_name, src_dev)); // Does the device have the right incarnation number we expect? if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) { diff --git a/tensorflow/core/distributed_runtime/worker_env.h b/tensorflow/core/distributed_runtime/worker_env.h index 24fb5948a71..f09bea328fd 100644 --- a/tensorflow/core/distributed_runtime/worker_env.h +++ b/tensorflow/core/distributed_runtime/worker_env.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_ #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_ +#include #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -24,8 +25,10 @@ namespace thread { class ThreadPool; } // namespace thread +class Device; class DeviceMgr; class Env; +class RendezvousMgrInterface; class SessionMgr; // The worker environment class, which holds a bag of pointers to @@ -38,10 +41,18 @@ struct WorkerEnv { // session_mgr encapsulates state for each session. SessionMgr* session_mgr = nullptr; + // The local devices of this worker. Devices are owned by the device_mgr. + // + // REQUIRES: !local_devices.empty(). + std::vector local_devices; + // device_mgr manages local devices (cpu and gpu). The WorkerService // is the network interface for managed devices. DeviceMgr* device_mgr = nullptr; + // A set of rendezvous keyed by step ids. + RendezvousMgrInterface* rendezvous_mgr = nullptr; + // A pool of threads for scheduling compute work. thread::ThreadPool* compute_pool = nullptr; }; diff --git a/tensorflow/core/distributed_runtime/worker_interface.h b/tensorflow/core/distributed_runtime/worker_interface.h index 508bc7f4680..c9db28ec67f 100644 --- a/tensorflow/core/distributed_runtime/worker_interface.h +++ b/tensorflow/core/distributed_runtime/worker_interface.h @@ -113,6 +113,11 @@ class WorkerInterface { return CallAndWait(&ME::GetStatusAsync, request, response); } + Status CreateWorkerSession(const CreateWorkerSessionRequest* request, + CreateWorkerSessionResponse* response) { + return CallAndWait(&ME::CreateWorkerSessionAsync, request, response); + } + Status RegisterGraph(const RegisterGraphRequest* request, RegisterGraphResponse* response) { return CallAndWait(&ME::RegisterGraphAsync, request, response); diff --git a/tensorflow/core/distributed_runtime/worker_session.cc b/tensorflow/core/distributed_runtime/worker_session.cc index 8298e169595..8691450e9bc 100644 --- a/tensorflow/core/distributed_runtime/worker_session.cc +++ b/tensorflow/core/distributed_runtime/worker_session.cc @@ -17,14 +17,84 @@ limitations under the License. namespace tensorflow { -WorkerSession::WorkerSession( - const string& worker_name, - std::unique_ptr worker_cache, - std::unique_ptr rendezvous_mgr, - std::unique_ptr graph_mgr) +namespace { + +// A private cache that wraps worker_cache and allows reuse of +// WorkerInterface objects. +class WorkerFreeListCache : public WorkerCacheInterface { + public: + explicit WorkerFreeListCache(std::unique_ptr w) + : wrapped_(std::move(w)) {} + + ~WorkerFreeListCache() final { + for (auto p : workers_) { + wrapped_->ReleaseWorker(p.first, p.second.worker); + } + } + + void ListWorkers(std::vector* workers) const override { + wrapped_->ListWorkers(workers); + } + + WorkerInterface* CreateWorker(const string& target) override { + mutex_lock l(mu_); + auto p = workers_.find(target); + if (p != workers_.end()) { + return p->second.worker; + } + WorkerState state; + state.worker = wrapped_->CreateWorker(target); + if (state.worker != nullptr) { + workers_.insert(std::make_pair(target, state)); + } + return state.worker; + } + + void ReleaseWorker(const string& target, WorkerInterface* worker) override { + // TODO(jeff,sanjay): Should decrement ref-count when we implement eviction. + } + + bool GetDeviceLocalityNonBlocking(const string& device, + DeviceLocality* locality) override { + return wrapped_->GetDeviceLocalityNonBlocking(device, locality); + } + + void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality, + StatusCallback done) override { + wrapped_->GetDeviceLocalityAsync(device, locality, done); + } + + void SetLogging(bool active) override { wrapped_->SetLogging(active); } + + void ClearLogs() override { wrapped_->ClearLogs(); } + + bool RetrieveLogs(int64 step_id, StepStats* ss) override { + return wrapped_->RetrieveLogs(step_id, ss); + } + + private: + std::unique_ptr wrapped_; + + // Information kept per created WorkerInterface. + struct WorkerState { + WorkerInterface* worker; + // TODO(jeff,sanjay): Add reference count if we support eviction. + }; + + // TODO(jeff,sanjay): Eviction when the map becomes too big. + mutex mu_; + std::unordered_map workers_ GUARDED_BY(mu_); +}; + +} // namespace + +WorkerSession::WorkerSession(const string& worker_name, + std::unique_ptr worker_cache, + std::unique_ptr device_mgr, + std::unique_ptr graph_mgr) : worker_name(worker_name), - worker_cache(std::move(worker_cache)), - rendezvous_mgr(std::move(rendezvous_mgr)), + worker_cache(new WorkerFreeListCache(std::move(worker_cache))), + device_mgr(std::move(device_mgr)), graph_mgr(std::move(graph_mgr)) {} } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/worker_session.h b/tensorflow/core/distributed_runtime/worker_session.h index e6ebe883298..77cf4de8f74 100644 --- a/tensorflow/core/distributed_runtime/worker_session.h +++ b/tensorflow/core/distributed_runtime/worker_session.h @@ -18,14 +18,13 @@ limitations under the License. #include +#include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/distributed_runtime/graph_mgr.h" -#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" #include "tensorflow/core/distributed_runtime/worker_cache.h" namespace tensorflow { class GraphMgr; -class RendezvousMgrInterface; class WorkerCacheInterface; // WorkerSession encapsulates all of the state relating to a given session. @@ -36,17 +35,20 @@ struct WorkerSession { // Object from which WorkerInterface instances can be obtained. const std::unique_ptr worker_cache; - // A set of rendezvous keyed by step ids. - const std::unique_ptr rendezvous_mgr; + // Collection of local devices. These devices are typically RenamedDevices + // in all except the SessionMgr.legacy_session_. legacy_session_.device_mgr + // == worker_env_.device_mgr, which holds the true devices. + const std::unique_ptr device_mgr; // graph_mgr keeps track of the registered graphs of this session. // // Note: graph_mgr must be deleted before rendezvous_mgr! + // Note: graph_mgr must be deleted before device_mgr! const std::unique_ptr graph_mgr; WorkerSession(const string& worker_name, std::unique_ptr worker_cache, - std::unique_ptr rendezvous_mgr, + std::unique_ptr device_mgr, std::unique_ptr graph_mgr); }; diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h index 8894671fdf3..27fe28fe60a 100644 --- a/tensorflow/core/framework/device_base.h +++ b/tensorflow/core/framework/device_base.h @@ -115,7 +115,7 @@ class DeviceBase { cpu_worker_threads_ = t; } - const CpuWorkerThreads* tensorflow_cpu_worker_threads() const { + virtual const CpuWorkerThreads* tensorflow_cpu_worker_threads() const { CHECK(cpu_worker_threads_ != nullptr); return cpu_worker_threads_; } @@ -140,7 +140,7 @@ class DeviceBase { gpu_device_info_ = g; } - const GpuDeviceInfo* tensorflow_gpu_device_info() const { + virtual const GpuDeviceInfo* tensorflow_gpu_device_info() const { return gpu_device_info_; } @@ -170,13 +170,13 @@ class DeviceBase { return GetAllocator(attr); } - const Eigen::ThreadPoolDevice* eigen_cpu_device() { + virtual const Eigen::ThreadPoolDevice* eigen_cpu_device() { CHECK(eigen_cpu_device_ != nullptr); return eigen_cpu_device_; } #ifdef TENSORFLOW_USE_SYCL - const Eigen::SyclDevice* eigen_sycl_device() const { + virtual const Eigen::SyclDevice* eigen_sycl_device() const { CHECK(eigen_sycl_device_ != nullptr); return eigen_sycl_device_; } diff --git a/tensorflow/core/protobuf/cluster.proto b/tensorflow/core/protobuf/cluster.proto new file mode 100644 index 00000000000..33c87eefe02 --- /dev/null +++ b/tensorflow/core/protobuf/cluster.proto @@ -0,0 +1,82 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "ClusterProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.distruntime"; + +// This file contains protos to be used when defining a TensorFlow +// cluster. +// +// EXAMPLES +// -------- +// +// 1. A single-process cluster, containing "/job:local/task:0". +// +// Cluster: +// job { name: 'local' tasks { key: 0 value: 'localhost:2222' } } +// +// Server: +// cluster { $CLUSTER } job_name: 'local' task_index: 0 +// +// 2. A two-process cluster, containing "/job:local/task:{0,1}". +// +// Cluster: +// job { name: 'local' tasks { key: 0 value: 'localhost:2222' } +// tasks { key: 1 value: 'localhost:2223' } } +// +// Servers: +// cluster { $CLUSTER } job_name: 'local' task_index: 0 +// cluster { $CLUSTER } job_name: 'local' task_index: 1 +// +// 3. A two-job cluster, containing "/job:worker/task:{0,1,2}" and +// "/job:ps/task:{0,1}". +// +// Cluster: +// job { name: 'worker' tasks { key: 0 value: 'worker1:2222' } +// tasks { key: 1 value: 'worker2:2222' } +// tasks { key: 2 value: 'worker3:2222' } } +// job { name: 'ps' tasks { key: 0 value: 'ps0:2222' } +// tasks { key: 1 value: 'ps1:2222' } } +// +// Servers: +// cluster { $CLUSTER } job_name: 'worker' task_index: 0 +// cluster { $CLUSTER } job_name: 'worker' task_index: 1 +// cluster { $CLUSTER } job_name: 'worker' task_index: 2 +// cluster { $CLUSTER } job_name: 'ps' task_index: 0 +// cluster { $CLUSTER } job_name: 'ps' task_index: 1 + +// Defines a single job in a TensorFlow cluster. +message JobDef { + // The name of this job. + string name = 1; + + // Mapping from task ID to "hostname:port" string. + // + // If the `name` field contains "worker", and the `tasks` map contains a + // mapping from 7 to "example.org:2222", then the device prefix + // "/job:worker/task:7" will be assigned to "example.org:2222". + map tasks = 2; +} + +// Defines a TensorFlow cluster as a set of jobs. +message ClusterDef { + // The jobs that comprise the cluster. + repeated JobDef job = 1; +} diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto index 5c0f7232ebd..630f47633f8 100644 --- a/tensorflow/core/protobuf/config.proto +++ b/tensorflow/core/protobuf/config.proto @@ -10,6 +10,7 @@ import "tensorflow/core/framework/cost_graph.proto"; import "tensorflow/core/framework/graph.proto"; import "tensorflow/core/framework/step_stats.proto"; import "tensorflow/core/protobuf/debug.proto"; +import "tensorflow/core/protobuf/cluster.proto"; import "tensorflow/core/protobuf/rewriter_config.proto"; message GPUOptions { @@ -259,6 +260,11 @@ message ConfigProto { // Options that apply when this session uses the distributed runtime. RPCOptions rpc_options = 13; + + // Optional list of all workers to use in this session. + ClusterDef cluster_def = 14; + + // Next: 15 }; // Options for a single Run() call. diff --git a/tensorflow/core/protobuf/master.proto b/tensorflow/core/protobuf/master.proto index de91b6133e4..e607b1c42a5 100644 --- a/tensorflow/core/protobuf/master.proto +++ b/tensorflow/core/protobuf/master.proto @@ -38,6 +38,9 @@ message CreateSessionRequest { // Configuration options. ConfigProto config = 2; + + // The target string used from the client's perspective. + string target = 3; } message CreateSessionResponse { diff --git a/tensorflow/core/protobuf/tensorflow_server.proto b/tensorflow/core/protobuf/tensorflow_server.proto index c4077bd98e4..6199e707e5a 100644 --- a/tensorflow/core/protobuf/tensorflow_server.proto +++ b/tensorflow/core/protobuf/tensorflow_server.proto @@ -16,6 +16,7 @@ limitations under the License. syntax = "proto3"; import "tensorflow/core/protobuf/config.proto"; +import "tensorflow/core/protobuf/cluster.proto"; package tensorflow; option cc_enable_arenas = true; @@ -23,69 +24,6 @@ option java_outer_classname = "ServerProtos"; option java_multiple_files = true; option java_package = "org.tensorflow.distruntime"; -// This file contains protos to be used when defining a TensorFlow -// cluster, and a server within that cluster. -// -// EXAMPLES -// -------- -// -// 1. A single-process cluster, containing "/job:local/task:0". -// -// Cluster: -// job { name: 'local' tasks { key: 0 value: 'localhost:2222' } } -// -// Server: -// cluster { $CLUSTER } job_name: 'local' task_index: 0 -// -// 2. A two-process cluster, containing "/job:local/task:{0,1}". -// -// Cluster: -// job { name: 'local' tasks { key: 0 value: 'localhost:2222' } -// tasks { key: 1 value: 'localhost:2223' } } -// -// Servers: -// cluster { $CLUSTER } job_name: 'local' task_index: 0 -// cluster { $CLUSTER } job_name: 'local' task_index: 1 -// -// 3. A two-job cluster, containing "/job:worker/task:{0,1,2}" and -// "/job:ps/task:{0,1}". -// -// Cluster: -// job { name: 'worker' tasks { key: 0 value: 'worker1:2222' } -// tasks { key: 1 value: 'worker2:2222' } -// tasks { key: 2 value: 'worker3:2222' } } -// job { name: 'ps' tasks { key: 0 value: 'ps0:2222' } -// tasks { key: 1 value: 'ps1:2222' } } -// -// Servers: -// cluster { $CLUSTER } job_name: 'worker' task_index: 0 -// cluster { $CLUSTER } job_name: 'worker' task_index: 1 -// cluster { $CLUSTER } job_name: 'worker' task_index: 2 -// cluster { $CLUSTER } job_name: 'ps' task_index: 0 -// cluster { $CLUSTER } job_name: 'ps' task_index: 1 - -// Defines a single job in a TensorFlow cluster. -message JobDef { - // The name of this job. - string name = 1; - - // Mapping from task ID to "hostname:port" string. - // - // If the `name` field contains "worker", and the `tasks` map contains a - // mapping from 7 to "example.org:2222", then the device prefix - // "/job:worker/task:7" will be assigned to "example.org:2222". - // - // NOTE(mrry): Currently, only a dense task ID space starting at 0 is - // supported. - map tasks = 2; -} - -// Defines a TensorFlow cluster as a set of jobs. -message ClusterDef { - // The jobs that comprise the cluster. - repeated JobDef job = 1; -} - // Defines the configuration of a single TensorFlow server. message ServerDef { // The cluster of which this server is a member. diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto index 661327847c1..cf05aece39a 100644 --- a/tensorflow/core/protobuf/worker.proto +++ b/tensorflow/core/protobuf/worker.proto @@ -119,6 +119,10 @@ message RegisterGraphResponse { //////////////////////////////////////////////////////////////////////////////// message DeregisterGraphRequest { + // The session_handle used when registering the graph. If session_handle is + // empty, a single global namespace is used. + string session_handle = 2; + // REQUIRED: graph_handle must be returned by a RegisterGraph call // to the same WorkerService. string graph_handle = 1; @@ -167,6 +171,12 @@ message ExecutorOpts { }; message RunGraphRequest { + // session_handle is the the master-generated unique id for this session. + // If session_handle is non-empty, it must be the same as used when + // registering the graph. If it is empty, a single global namespace is used to + // search for the graph_handle. + string session_handle = 8; + // REQUIRED: graph_handle must be returned by a RegisterGraph call // to the same WorkerService. string graph_handle = 1; @@ -193,6 +203,8 @@ message RunGraphRequest { bool is_partial = 6; // True if this is the last partial run request in a sequence of requests. bool is_last_partial_run = 7; + + // Next: 9 } message RunGraphResponse { diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 864a96ef348..6336ca23105 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -55,6 +55,7 @@ from tensorflow.core.framework.summary_pb2 import * from tensorflow.core.framework.attr_value_pb2 import * from tensorflow.core.protobuf.meta_graph_pb2 import TensorInfo from tensorflow.core.protobuf.config_pb2 import * +from tensorflow.core.protobuf.tensorflow_server_pb2 import * from tensorflow.core.protobuf.rewriter_config_pb2 import * from tensorflow.core.util.event_pb2 import * @@ -131,6 +132,7 @@ _allowed_symbols = [ 'AttrValue', 'AutoParallelOptions', 'ConfigProto', + 'ClusterDef', 'DeviceSpec', 'Event', 'GPUOptions', diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 9add5bd3cde..040cc333158 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -29,6 +29,7 @@ import six from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.core.lib.core import error_codes_pb2 +from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.framework import common_shapes @@ -1789,7 +1790,7 @@ class SessionTest(test_util.TensorFlowTestCase): with CaptureStderr() as log: sess.run(c) # Ensure that we did log device placement. - self.assertTrue('/job:local/replica:0/task:0/cpu:0' in str(log)) + self.assertTrue('/job:local/replica:0/task:0/cpu:0' in str(log), str(log)) def testLocalMasterSessionTimeout(self): # Test that the timeout passed in a config to the session works correctly. @@ -1834,6 +1835,270 @@ class SessionTest(test_util.TensorFlowTestCase): server = server_lib.Server.create_local_server() self.runTestBuildGraphError(session.Session(server.target)) + def testClusterSpecPropagationSimple(self): + server1 = server_lib.Server.create_local_server() + server2 = server_lib.Server.create_local_server() + cluster_def = cluster_pb2.ClusterDef() + job = cluster_def.job.add() + job.name = 'worker' + job.tasks[0] = server1.target[len('grpc://'):] + job.tasks[1] = server2.target[len('grpc://'):] + config = config_pb2.ConfigProto(cluster_def=cluster_def) + + const = constant_op.constant(17) + sess = session.Session(server1.target, config=config) + output = sess.run(const) + self.assertEqual(17, output) + + def testClusterSpecPropagationWorker2Placement(self): + server1 = server_lib.Server.create_local_server() + server2 = server_lib.Server.create_local_server() + cluster_def = cluster_pb2.ClusterDef() + job = cluster_def.job.add() + job.name = 'worker' + job.tasks[0] = server1.target[len('grpc://'):] + job.tasks[1] = server2.target[len('grpc://'):] + config = config_pb2.ConfigProto(cluster_def=cluster_def) + + with ops.Graph().as_default() as g, ops.device('/job:worker/task:1'): + const = constant_op.constant(17) + sess = session.Session(server1.target, config=config, graph=g) + run_options = config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE) + run_metadata = config_pb2.RunMetadata() + output = sess.run(const, options=run_options, run_metadata=run_metadata) + self.assertEqual(17, output) + self.assertEqual(1, + len([ + node_stats + for dev_stats in run_metadata.step_stats.dev_stats + for node_stats in dev_stats.node_stats + if '/job:worker/replica:0/task:1/device:CPU:0' == + dev_stats.device and 'Const' == node_stats.node_name + ])) + + def testClusterSpecPropagationWorker1Placement(self): + server1 = server_lib.Server.create_local_server() + server2 = server_lib.Server.create_local_server() + cluster_def = cluster_pb2.ClusterDef() + job = cluster_def.job.add() + job.name = 'worker' + job.tasks[0] = server1.target[len('grpc://'):] + job.tasks[1] = server2.target[len('grpc://'):] + config = config_pb2.ConfigProto(cluster_def=cluster_def) + + with ops.Graph().as_default() as g, ops.device('/job:worker/task:0'): + const = constant_op.constant(17) + sess = session.Session(server1.target, config=config, graph=g) + output = sess.run(const) + self.assertEqual(17, output) + + def testClusterSpecPropagationThreeServers2Graphs(self): + """Boots 3 servers, creates 2 sessions, ensures appropriate operations. + + We create 2 clusterspecs: + 1. server2 as the master, server1 as a worker + 2. server2 as the master, server3 as a worker + + We ensure that variables on the workers are independent. + """ + server1 = server_lib.Server.create_local_server() + server2 = server_lib.Server.create_local_server() + server3 = server_lib.Server.create_local_server() + cluster_def1 = cluster_pb2.ClusterDef() + job1 = cluster_def1.job.add() + job1.name = 'worker1' + job1.tasks[0] = server2.target[len('grpc://'):] + job1.tasks[1] = server1.target[len('grpc://'):] + + cluster_def2 = cluster_pb2.ClusterDef() + job2 = cluster_def2.job.add() + job2.name = 'worker2' + job2.tasks[0] = server2.target[len('grpc://'):] + job2.tasks[1] = server3.target[len('grpc://'):] + + config1 = config_pb2.ConfigProto(cluster_def=cluster_def1) + config2 = config_pb2.ConfigProto(cluster_def=cluster_def2) + + with ops.Graph().as_default() as g1: + with ops.device('/job:worker1/task:1'): + var1 = variables.Variable(array_ops.zeros([2]), name='var1') + update_op1 = state_ops.assign_add( + var1, array_ops.ones([2]), name='var1_assign_add') + init1 = variables.global_variables_initializer() + + with ops.Graph().as_default() as g2: + with ops.device('/job:worker2/task:1'): + var2 = variables.Variable(array_ops.zeros([2]), name='var2') + update_op2 = state_ops.assign_add( + var2, array_ops.ones([2]), name='var2_assign_add') + init2 = variables.global_variables_initializer() + + sess1 = session.Session(server2.target, graph=g1, config=config1) + sess2 = session.Session(server2.target, graph=g2, config=config2) + + init1.run(session=sess1) + init2.run(session=sess2) + + expected_zeros = np.zeros([2]) + expected_ones = np.ones([2]) + + self.assertAllEqual(expected_zeros, sess1.run(var1)) + self.assertAllEqual(expected_zeros, sess2.run(var2)) + + self.assertAllEqual(expected_ones, sess1.run(update_op1)) + self.assertAllEqual(expected_ones, sess1.run(var1)) + self.assertAllEqual(expected_zeros, sess2.run(var2)) + self.assertAllEqual(expected_ones, sess2.run(update_op2)) + self.assertAllEqual(expected_ones + expected_ones, sess1.run(update_op1)) + self.assertAllEqual(expected_ones, sess2.run(var2)) + self.assertAllEqual(expected_ones + expected_ones, sess1.run(var1)) + + def testClusterSpecPropagationThreeServers(self): + """Boots 3 servers, creates 2 sessions, ensures appropriate operations. + + We create 2 clusterspecs: + 1. server2 as the master, server1 as a worker + 2. server2 as the master, server3 as a worker + + We ensure that variables on the workers are independent. + """ + server1 = server_lib.Server.create_local_server() + server2 = server_lib.Server.create_local_server() + server3 = server_lib.Server.create_local_server() + cluster_def1 = cluster_pb2.ClusterDef() + job1 = cluster_def1.job.add() + job1.name = 'worker' + job1.tasks[0] = server2.target[len('grpc://'):] + job1.tasks[1] = server1.target[len('grpc://'):] + + cluster_def2 = cluster_pb2.ClusterDef() + job2 = cluster_def2.job.add() + job2.name = 'worker' + job2.tasks[0] = server2.target[len('grpc://'):] + job2.tasks[1] = server3.target[len('grpc://'):] + + config1 = config_pb2.ConfigProto(cluster_def=cluster_def1) + config2 = config_pb2.ConfigProto(cluster_def=cluster_def2) + + with ops.device('/job:worker/task:1'): + var = variables.Variable(array_ops.zeros([2]), name='var') + feed = array_ops.placeholder(dtypes.float32, shape=(2)) + update_op = var.assign_add(feed) + + sess1 = session.Session(server2.target, config=config1) + sess2 = session.Session(server2.target, config=config2) + + variables.global_variables_initializer().run(session=sess1) + variables.global_variables_initializer().run(session=sess2) + + expected_zeros = np.zeros([2]) + expected_ones = np.ones([2]) + + self.assertAllEqual(expected_zeros, sess1.run(var)) + self.assertAllEqual(expected_zeros, sess2.run(var)) + self.assertAllEqual(expected_ones, + sess1.run(update_op, feed_dict={feed: expected_ones})) + self.assertAllEqual(expected_ones, sess1.run(var)) + self.assertAllEqual(expected_zeros, sess2.run(var)) + self.assertAllEqual(expected_ones, + sess2.run(update_op, feed_dict={feed: expected_ones})) + self.assertAllEqual(expected_ones + expected_ones, + sess1.run(update_op, feed_dict={feed: expected_ones})) + self.assertAllEqual(expected_ones, sess2.run(var)) + self.assertAllEqual(expected_ones + expected_ones, sess1.run(var)) + + def testClusterSpecPropagationThreeServersOneCluster(self): + """Boots 3 servers, ensures appropriate communication across workers. + + Additionally, in this cluster, we ensure the master is not the 0-th worker. + + Note: this test only uses one session. + """ + server1 = server_lib.Server.create_local_server() + server2 = server_lib.Server.create_local_server() + server3 = server_lib.Server.create_local_server() + cluster_def = cluster_pb2.ClusterDef() + job = cluster_def.job.add() + job.name = 'worker' + job.tasks[0] = server3.target[len('grpc://'):] + job.tasks[1] = server2.target[len('grpc://'):] + job.tasks[2] = server1.target[len('grpc://'):] + config = config_pb2.ConfigProto(cluster_def=cluster_def) + + # Add ops to the devices in non-linear order. + + with ops.device('/job:worker/task:1'): + feed1 = array_ops.placeholder(dtypes.float32, shape=(2)) + const1 = constant_op.constant(2.0) + mul1 = const1 * feed1 + + with ops.device('/job:worker/task:2'): + feed2 = array_ops.placeholder(dtypes.float32, shape=(2)) + const2 = constant_op.constant(2.0) + mul2 = const2 * feed2 + + with ops.device('/job:worker/task:0'): + feed0 = array_ops.placeholder(dtypes.float32, shape=(2)) + const0 = constant_op.constant(2.0) + mul0 = const0 * feed0 + + sum_op = mul0 + mul1 + mul2 + + ones = np.ones([2]) + run_options = config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE) + run_metadata = config_pb2.RunMetadata() + + # Run! + with session.Session(server1.target, config=config) as sess: + output = sess.run( + sum_op, + options=run_options, + run_metadata=run_metadata, + feed_dict={feed1: ones, + feed2: ones, + feed0: ones}) + self.assertAllEqual(6 * ones, output) + + self.assertEqual( + 3, + len([ + dev_stats.device + for dev_stats in run_metadata.step_stats.dev_stats + for node_stats in dev_stats.node_stats + if '/job:worker/replica:0/task:' in dev_stats.device and + node_stats.node_name.startswith('Const') + ]), run_metadata) + + def testClusterSpecPropagationPartialRun(self): + """Test successful partial run with ClusterSpec propagation.""" + server1 = server_lib.Server.create_local_server() + server2 = server_lib.Server.create_local_server() + + cluster_def = cluster_pb2.ClusterDef() + job = cluster_def.job.add() + job.name = 'worker' + job.tasks[0] = server1.target[len('grpc://'):] + job.tasks[1] = server2.target[len('grpc://'):] + config = config_pb2.ConfigProto(cluster_def=cluster_def) + + with ops.device('/job:worker/task:0'): + a = array_ops.placeholder(dtypes.float32, shape=[]) + with ops.device('/job:worker/task:1'): + b = array_ops.placeholder(dtypes.float32, shape=[]) + c = array_ops.placeholder(dtypes.float32, shape=[]) + r1 = math_ops.add(a, b) + with ops.device('/job:worker/task:0'): + r2 = math_ops.multiply(r1, c) + + with session.Session(server1.target, config=config) as sess: + h = sess.partial_run_setup([r1, r2], [a, b, c]) + res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2}) + self.assertEqual(3, res) + res = sess.partial_run(h, r2, feed_dict={c: 3}) + self.assertEqual(9, res) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/python/training/server_lib.py b/tensorflow/python/training/server_lib.py index d2ccf37d885..2091eca0b9c 100644 --- a/tensorflow/python/training/server_lib.py +++ b/tensorflow/python/training/server_lib.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import tensorflow_server_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python.framework import errors @@ -276,14 +277,14 @@ class ClusterSpec(object): "from integers to strings." % job_name) self._cluster_spec[job_name] = job_tasks self._make_cluster_def() - elif isinstance(cluster, tensorflow_server_pb2.ClusterDef): + elif isinstance(cluster, cluster_pb2.ClusterDef): self._cluster_def = cluster self._cluster_spec = {} for job_def in self._cluster_def.job: self._cluster_spec[job_def.name] = { i: t for i, t in job_def.tasks.items()} elif isinstance(cluster, ClusterSpec): - self._cluster_def = tensorflow_server_pb2.ClusterDef() + self._cluster_def = cluster_pb2.ClusterDef() self._cluster_def.MergeFrom(cluster.as_cluster_def()) self._cluster_spec = {} for job_def in self._cluster_def.job: @@ -440,7 +441,7 @@ class ClusterSpec(object): TypeError: If `cluster_spec` is not a dictionary mapping strings to lists of strings. """ - self._cluster_def = tensorflow_server_pb2.ClusterDef() + self._cluster_def = cluster_pb2.ClusterDef() # NOTE(mrry): Sort by job_name to produce deterministic protobufs. for job_name, tasks in sorted(self._cluster_spec.items()): diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py index bdf3d9c0175..f4ac3c97587 100644 --- a/tensorflow/python/training/training.py +++ b/tensorflow/python/training/training.py @@ -186,8 +186,8 @@ from tensorflow.python.training.learning_rate_decay import * # pylint: enable=wildcard-import # Distributed computing support. -from tensorflow.core.protobuf.tensorflow_server_pb2 import ClusterDef -from tensorflow.core.protobuf.tensorflow_server_pb2 import JobDef +from tensorflow.core.protobuf.cluster_pb2 import ClusterDef +from tensorflow.core.protobuf.cluster_pb2 import JobDef from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef from tensorflow.python.training.server_lib import ClusterSpec from tensorflow.python.training.server_lib import Server @@ -196,32 +196,32 @@ from tensorflow.python.training.server_lib import Server _allowed_symbols = [ # TODO(cwhipkey): review these and move to contrib or expose through # documentation. - "generate_checkpoint_state_proto", # Used internally by saver. + "generate_checkpoint_state_proto", # Used internally by saver. "checkpoint_exists", # Only used in test? "get_checkpoint_mtimes", # Only used in test? # Legacy: remove. "do_quantize_training_on_graphdef", # At least use grah_def, not graphdef. - # No uses within tensorflow. + # No uses within tensorflow. "queue_runner", # Use tf.train.start_queue_runner etc directly. - # This is also imported internally. + # This is also imported internally. # TODO(drpng): document these. The reference in howtos/distributed does # not link. "SyncReplicasOptimizer", # Protobufs: - "BytesList", # from example_pb2. + "BytesList", # from example_pb2. "ClusterDef", - "Example", # from example_pb2 - "Feature", # from example_pb2 - "Features", # from example_pb2 - "FeatureList", # from example_pb2 - "FeatureLists", # from example_pb2 - "FloatList", # from example_pb2. - "Int64List", # from example_pb2. + "Example", # from example_pb2 + "Feature", # from example_pb2 + "Features", # from example_pb2 + "FeatureList", # from example_pb2 + "FeatureLists", # from example_pb2 + "FloatList", # from example_pb2. + "Int64List", # from example_pb2. "JobDef", - "SaverDef", # From saver_pb2. - "SequenceExample", # from example_pb2. + "SaverDef", # From saver_pb2. + "SequenceExample", # from example_pb2. "ServerDef", ] # Include extra modules for docstrings because: diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt index 805a9bdd4f1..da6af3919e9 100644 --- a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt @@ -6,6 +6,10 @@ tf_class { name: "ALLOW_SOFT_PLACEMENT_FIELD_NUMBER" mtype: "" } + member { + name: "CLUSTER_DEF_FIELD_NUMBER" + mtype: "" + } member { name: "DESCRIPTOR" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt index feb73bd7d4f..93ff856b09d 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.train.ClusterDef" tf_class { - is_instance: "" + is_instance: "" is_instance: "" member { name: "DESCRIPTOR" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt index 2d7fcbe5456..ac6d81541a4 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.train.JobDef.TasksEntry" tf_class { - is_instance: "" + is_instance: "" is_instance: "" member { name: "DESCRIPTOR" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt index fc5b76341d2..ce34537fa13 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.train.JobDef" tf_class { - is_instance: "" + is_instance: "" is_instance: "" member { name: "DESCRIPTOR"