From 9b9e5989d247d274c4137db533e43b95d825acfc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 17 Aug 2017 17:03:26 -0700 Subject: [PATCH 01/70] Add a call_logit_fn utility for logit_fn's, similar to Estimator's _call_model_fn. PiperOrigin-RevId: 165649388 --- tensorflow/contrib/learn/BUILD | 12 ++++ .../learn/python/learn/estimators/__init__.py | 1 + .../python/learn/estimators/logit_fns.py | 39 +++++++++++- .../python/learn/estimators/logit_fns_test.py | 60 +++++++++++++++++++ 4 files changed, 111 insertions(+), 1 deletion(-) create mode 100644 tensorflow/contrib/learn/python/learn/estimators/logit_fns_test.py diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 978ebfef77f..c2e74d1cc2e 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -119,6 +119,18 @@ py_test( ], ) +py_test( + name = "logit_fns_test", + size = "small", + srcs = ["python/learn/estimators/logit_fns_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":learn", + "//tensorflow/python:client_testlib", + "//tensorflow/python/estimator:model_fn", + ], +) + py_test( name = "estimators_test", size = "small", diff --git a/tensorflow/contrib/learn/python/learn/estimators/__init__.py b/tensorflow/contrib/learn/python/learn/estimators/__init__.py index 42943fdd3ac..9d63d7dcd0b 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/__init__.py +++ b/tensorflow/contrib/learn/python/learn/estimators/__init__.py @@ -321,6 +321,7 @@ from tensorflow.contrib.learn.python.learn.estimators.linear import LinearClassi from tensorflow.contrib.learn.python.learn.estimators.linear import LinearEstimator from tensorflow.contrib.learn.python.learn.estimators.linear import LinearRegressor from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import LogisticRegressor +from tensorflow.contrib.learn.python.learn.estimators.logit_fns import call_logit_fn from tensorflow.contrib.learn.python.learn.estimators.logit_fns import dnn_logit_fn_builder from tensorflow.contrib.learn.python.learn.estimators.logit_fns import linear_logit_fn_builder from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey diff --git a/tensorflow/contrib/learn/python/learn/estimators/logit_fns.py b/tensorflow/contrib/learn/python/learn/estimators/logit_fns.py index f04a47b29af..110ea0302e7 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/logit_fns.py +++ b/tensorflow/contrib/learn/python/learn/estimators/logit_fns.py @@ -21,7 +21,7 @@ should follow the following signature: Args: `features`: This is the first item returned from the `input_fn` passed to `train`, `evaluate`, and `predict`. This should be a single - `Tensor` or `dict` of same. + `Tensor` or `dict` of same, and is the only required argument. `mode`: Optional. Specifies if this training, evaluation or prediction. See `ModeKeys`. `params`: Optional `dict` of hyperparameters. 
Will receive what is passed to @@ -39,10 +39,47 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.estimator import util from tensorflow.python.estimator.canned import dnn as dnn_core from tensorflow.python.estimator.canned import linear as linear_core +from tensorflow.python.framework import ops # pylint: disable=protected-access dnn_logit_fn_builder = dnn_core._dnn_logit_fn_builder linear_logit_fn_builder = linear_core._linear_logit_fn_builder # pylint: enable=protected-access + + +def call_logit_fn(logit_fn, features, mode, params, config): + """Calls logit_fn. + + A utility function that calls the provided logit_fn with the relevant subset + of provided arguments. Similar to tf.estimator._call_model_fn(). + + Args: + logit_fn: A logit_fn as defined above. + features: The features dict. + mode: TRAIN / EVAL / PREDICT ModeKeys. + params: The hyperparameter dict. + config: The configuration object. + + Returns: + A logit Tensor, the output of logit_fn. + + Raises: + ValueError: if logit_fn does not return a Tensor. + """ + logit_fn_args = util.fn_args(logit_fn) + kwargs = {} + if 'mode' in logit_fn_args: + kwargs['mode'] = mode + if 'params' in logit_fn_args: + kwargs['params'] = params + if 'config' in logit_fn_args: + kwargs['config'] = config + logit_fn_results = logit_fn(features=features, **kwargs) + + if not isinstance(logit_fn_results, ops.Tensor): + raise ValueError('model_fn should return a Tensor.') + + return logit_fn_results diff --git a/tensorflow/contrib/learn/python/learn/estimators/logit_fns_test.py b/tensorflow/contrib/learn/python/learn/estimators/logit_fns_test.py new file mode 100644 index 00000000000..01616d1a7ff --- /dev/null +++ b/tensorflow/contrib/learn/python/learn/estimators/logit_fns_test.py @@ -0,0 +1,60 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""logit_fn tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.learn.python.learn.estimators import logit_fns +from tensorflow.python.client import session +from tensorflow.python.estimator import model_fn +from tensorflow.python.framework import constant_op +from tensorflow.python.platform import test + + +class LogitFnTest(test.TestCase): + + def test_simple_call_logit_fn(self): + def dummy_logit_fn(features, mode): + if mode == model_fn.ModeKeys.TRAIN: + return features['f1'] + else: + return features['f2'] + features = { + 'f1': constant_op.constant([2., 3.]), + 'f2': constant_op.constant([4., 5.]) + } + logit_fn_result = logit_fns.call_logit_fn( + dummy_logit_fn, features, model_fn.ModeKeys.EVAL, 'fake_params', + 'fake_config') + with session.Session(): + self.assertAllClose([[4., 5.]], logit_fn_result.eval()) + + def test_should_return_tensor(self): + + def invalid_logit_fn(features, params): + return { + 'tensor1': features['f1'] * params['input_multiplier'], + 'tensor2': features['f2'] * params['input_multiplier'] + } + features = { + 'f1': constant_op.constant([2., 3.]), + 'f2': constant_op.constant([4., 5.]) + } + params = {'learning_rate': 0.001, 'input_multiplier': 2.0} + with self.assertRaisesRegexp(ValueError, 'model_fn should return a Tensor'): + logit_fns.call_logit_fn(invalid_logit_fn, features, 'fake_mode', params, + 'fake_config') From a3c4e980e00e9c332a4e9f8c232fb2a1cc2f5694 Mon Sep 17 00:00:00 2001 From: Pete Warden Date: Thu, 17 Aug 2017 17:05:12 -0700 Subject: [PATCH 02/70] Fixed input shape for freezing audio graphs PiperOrigin-RevId: 165649546 --- tensorflow/examples/speech_commands/freeze.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tensorflow/examples/speech_commands/freeze.py b/tensorflow/examples/speech_commands/freeze.py index 381f3d029e5..6d2f2102625 100644 --- a/tensorflow/examples/speech_commands/freeze.py +++ b/tensorflow/examples/speech_commands/freeze.py @@ -90,9 +90,14 @@ def create_inference_graph(wanted_words, sample_rate, clip_duration_ms, spectrogram, decoded_sample_data.sample_rate, dct_coefficient_count=dct_coefficient_count) + fingerprint_frequency_size = model_settings['dct_coefficient_count'] + fingerprint_time_size = model_settings['spectrogram_length'] + reshaped_input = tf.reshape(fingerprint_input, [ + -1, fingerprint_time_size * fingerprint_frequency_size + ]) logits = models.create_model( - fingerprint_input, model_settings, model_architecture, is_training=False) + reshaped_input, model_settings, model_architecture, is_training=False) # Create an output to use for inference. tf.nn.softmax(logits, name='labels_softmax') From 8c0853db731cf80cfeec9dfb4edab95961aaa585 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 17 Aug 2017 17:09:29 -0700 Subject: [PATCH 03/70] Add a test for negative and zero pow() input. 
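For reference, the new cases pin down IEEE-754 pow semantics: a negative base
with a non-integer exponent has no real result and yields NaN, and a zero base
with a negative exponent diverges to +infinity. A minimal standalone sketch of
the same expectations (illustration only, not part of the change; like the
test, it assumes fast math is disabled):

    #include <cassert>
    #include <cmath>

    int main() {
      // Negative base with a non-integer exponent: no real result -> NaN.
      assert(std::isnan(std::pow(-2.0f, 0.5f)));
      assert(std::isnan(std::pow(-0.6f, 0.6f)));
      assert(std::isnan(std::pow(-0.6f, -0.6f)));
      // Zero base with a negative exponent: pole -> +infinity.
      assert(std::isinf(std::pow(0.0f, -0.6f)));
      return 0;
    }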
PiperOrigin-RevId: 165650096 --- .../compiler/xla/tests/array_elementwise_ops_test.cc | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 192477555d0..532e2394c0d 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -785,6 +785,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) { &builder, {16.0f, 0.25f, 8.0f, NAN, NAN, -8.0f, 16.0f}, {}, error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) { + SetFastMathDisabled(true); + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({-2.0f, -0.6f, -0.6f, 0.0f}); + auto rhs = builder.ConstantR1({0.5f, 0.6f, -0.6f, -0.6f}); + auto minimum = builder.Pow(lhs, rhs); + + ComputeAndCompareR1(&builder, {NAN, NAN, NAN, INFINITY}, {}, + error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) { ComputationBuilder builder(client_, TestName()); auto lhs = builder.ConstantR1({}); From 19a55725af8102d72d4e081c5139f0e4bd5a4bb7 Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Thu, 17 Aug 2017 17:20:17 -0700 Subject: [PATCH 04/70] Allowing functions to run across devices. This change expands the ProcessFunctionLibraryRuntime library to Instantiate and Run functions on different devices. When a FunctionLibraryRuntime encounters a function with a target that is another device, it delegates Instantiate() and Run() calls to the ProcessFunctionLibraryRuntime. This change also moves the table_ containing all function instantiations to the PFLR instead of the FunctionLibraryRuntime. PiperOrigin-RevId: 165651194 --- tensorflow/c/eager/c_api.cc | 18 +- .../jit/encapsulate_subgraphs_pass.cc | 17 +- .../compiler/jit/mark_for_compilation_pass.cc | 11 +- tensorflow/compiler/tf2xla/xla_compiler.cc | 25 ++- tensorflow/compiler/tf2xla/xla_compiler.h | 8 +- tensorflow/contrib/cmake/tf_tests.cmake | 1 + .../contrib/data/python/kernel_tests/BUILD | 22 ++ .../kernel_tests/iterator_ops_cluster_test.py | 109 ++++++++++ .../python/kernel_tests/iterator_ops_test.py | 58 +++++ tensorflow/core/BUILD | 15 +- tensorflow/core/common_runtime/function.cc | 107 ++++++--- tensorflow/core/common_runtime/function.h | 14 +- .../core/common_runtime/function_test.cc | 203 ++++++++++-------- .../core/common_runtime/function_testlib.cc | 58 +++++ .../core/common_runtime/function_testlib.h | 31 +++ .../process_function_library_runtime.cc | 107 ++++++++- .../process_function_library_runtime.h | 55 ++++- .../process_function_library_runtime_test.cc | 129 ++++++++++- tensorflow/core/framework/function.h | 8 + tensorflow/core/framework/function_testlib.cc | 7 + tensorflow/core/framework/function_testlib.h | 18 ++ .../core/grappler/grappler_item_builder.cc | 11 +- tensorflow/core/kernels/captured_function.cc | 32 +-- tensorflow/core/kernels/captured_function.h | 13 +- tensorflow/core/kernels/function_ops.cc | 62 ++++++ tensorflow/core/ops/functional_ops.cc | 19 ++ .../kernel_tests/functional_ops_test.py | 54 +++++ 27 files changed, 1015 insertions(+), 197 deletions(-) create mode 100644 tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py create mode 100644 tensorflow/core/common_runtime/function_testlib.cc create mode 100644 tensorflow/core/common_runtime/function_testlib.h diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 22d5f233c31..b1baa5ce125 100644 --- 
a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -64,19 +64,14 @@ struct TFE_Context {
   // One FunctionLibraryRuntime per device.
   // func_libs[i] is the FunctionLibraryRuntime corresponding to
   // session->devices[i].
-  std::vector<std::unique_ptr<tensorflow::FunctionLibraryRuntime>> func_libs;
+  std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr;
 
   std::unordered_map kernel_cache;
 
   tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) {
-    for (int i = 0; i < session->devices.size(); ++i) {
-      if (session->devices[i] == d) {
-        return func_libs[i].get();
-      }
-    }
-    return nullptr;
+    return pflr->GetFLR(d->name());
   }
 
   const std::vector<tensorflow::Device*>& devices() { return session->devices; }
@@ -132,12 +127,9 @@ TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) {
   }
 
   TFE_Context* ret = new TFE_Context(session);
-  ret->func_libs.resize(ret->devices().size());
-  for (int i = 0; i < ret->devices().size(); ++i) {
-    ret->func_libs[i] = tensorflow::NewFunctionLibraryRuntime(
-        ret->session->device_mgr, opts->options.env, ret->devices()[i],
-        TF_GRAPH_DEF_VERSION, &ret->func_lib_def, {});
-  }
+  ret->pflr.reset(new tensorflow::ProcessFunctionLibraryRuntime(
+      ret->session->device_mgr, opts->options.env, TF_GRAPH_DEF_VERSION,
+      &ret->func_lib_def, {}));
   ret->rendezvous =
       new tensorflow::IntraProcessRendezvous(ret->session->device_mgr);
 
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index a1ddad3e9b8..22899ebeebc 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -624,15 +624,18 @@ Status EncapsulateSubgraphsPass::Run(
   FunctionLibraryDefinition* const library = options.flib_def;
 
   OptimizerOptions opts;
-  std::unique_ptr<FunctionLibraryRuntime> flr(
-      NewFunctionLibraryRuntime(nullptr, options.session_options->env, nullptr,
-                                TF_GRAPH_DEF_VERSION, library, opts));
+  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
+      new ProcessFunctionLibraryRuntime(nullptr, options.session_options->env,
+                                        TF_GRAPH_DEF_VERSION, library, opts));
+  FunctionLibraryRuntime* flr =
+      pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
 
-  auto rewrite_subgraph = [&flr](
-      std::unique_ptr<Graph>* subgraph, std::vector<int>* input_permutation,
-      std::vector<int>* output_permutation, NodeDef* node) {
+  auto rewrite_subgraph = [flr](std::unique_ptr<Graph>* subgraph,
+                                std::vector<int>* input_permutation,
+                                std::vector<int>* output_permutation,
+                                NodeDef* node) {
     // Optimize the subgraph.
-    OptimizeGraph(flr.get(), subgraph);
+    OptimizeGraph(flr, subgraph);
 
     const int num_args = input_permutation->size();
     std::vector<bool> const_args(num_args);
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 77b45aa11e2..2fe190e605f 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -176,8 +176,11 @@ Status FindCompilationCandidates(
     const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn,
     std::unordered_set<Node*>* candidates) {
   OptimizerOptions opts;
-  std::unique_ptr<FunctionLibraryRuntime> lib_runtime(NewFunctionLibraryRuntime(
-      nullptr, env, nullptr, TF_GRAPH_DEF_VERSION, flib_def, opts));
+  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
+      new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION,
+                                        flib_def, opts));
+  FunctionLibraryRuntime* lib_runtime =
+      pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
 
   for (Node* node : graph.op_nodes()) {
     DeviceType device_type("");
@@ -191,7 +194,7 @@ Status FindCompilationCandidates(
         XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration));
     DeviceType jit_device_type(registration->compilation_device_name);
     if (!HasXLAKernel(*node, jit_device_type) &&
-        !IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime.get())) {
+        !IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime)) {
       VLOG(2) << "Compilation rejected node: unsupported op " << node->name()
               << ": " << node->type_string();
       continue;
@@ -203,7 +206,7 @@ Status FindCompilationCandidates(
       continue;
     }
     if (node->type_string() == "While" &&
-        !IsCompilableWhile(*node, jit_device_type, 0, lib_runtime.get())) {
+        !IsCompilableWhile(*node, jit_device_type, 0, lib_runtime)) {
       continue;
     }
     candidates->insert(node);
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index d9bfaa93322..ae13147a18e 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -88,15 +88,18 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
   }
 
   local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
+                                                      FunctionDefLibrary{}));
-  local_flib_runtime_ = NewFunctionLibraryRuntime(
-      &device_mgr_, Env::Default(), device_, options.graph_def_version,
+  local_pflr_.reset(new ProcessFunctionLibraryRuntime(
+      &device_mgr_, Env::Default(), options.graph_def_version,
       local_flib_def_.get(), OptimizerOptions(),
-      nullptr /* custom_kernel_creator */);
-  flib_runtime_ = NewFunctionLibraryRuntime(
-      &device_mgr_, Env::Default(), device_, options.graph_def_version,
-      options.flib_def, OptimizerOptions(),
-      nullptr /* custom_kernel_creator */);
+      nullptr /* custom_kernel_creator */));
+  pflr_.reset(new ProcessFunctionLibraryRuntime(
+      &device_mgr_, Env::Default(), options.graph_def_version, options.flib_def,
+      OptimizerOptions(), nullptr /* custom_kernel_creator */));
+
+  local_flib_runtime_ = local_pflr_->GetFLR(device_->name());
+  flib_runtime_ = pflr_->GetFLR(device_->name());
 }
 
 XlaCompiler::~XlaCompiler() = default;
@@ -137,8 +140,8 @@ Status XlaCompiler::CompileFunction(
   }
 
   const FunctionBody* fbody;
-  if (!GetFunctionBody(function, local_flib_runtime_.get(), &fbody).ok()) {
-    TF_RETURN_IF_ERROR(GetFunctionBody(function, flib_runtime_.get(), &fbody));
+  if (!GetFunctionBody(function, local_flib_runtime_, &fbody).ok()) {
+    TF_RETURN_IF_ERROR(GetFunctionBody(function, flib_runtime_, &fbody));
   }
 
   TF_RETURN_IF_ERROR(CheckSignature(fbody->arg_types, args));
@@ -159,7 +162,7 @@ Status XlaCompiler::CompileFunction(
   opts.set_do_function_inlining(true);
   opts.set_do_constant_folding(true);
   GraphOptimizer optimizer(opts);
-  optimizer.Optimize(flib_runtime_.get(), flib_runtime_->env(),
+  optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
                      /*device=*/nullptr, &graph, /*shape_map=*/nullptr);
 
   VLOG(1) << "====================================================";
@@ -464,7 +467,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
   context->set_args(std::move(context_args));
   TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_,
-                                  flib_runtime_.get(), NextStepId()));
+                                  flib_runtime_, NextStepId()));
 
   int num_nonconst_outputs;
   int num_computation_outputs;
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 317f635bcbe..b5987c8ac8b 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -276,7 +276,7 @@ class XlaCompiler {
   xla::Client* client() const { return options_.client; }
   XlaCompilationDevice* device() const { return device_; }
   const DeviceMgr* device_mgr() const { return &device_mgr_; }
-  FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_.get(); }
+  FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }
 
   // Retrieves the channel handle associated with `key`. Allocates
   // a new channel handle if none exists.
@@ -303,9 +303,11 @@ class XlaCompiler {
   // library and runtime for functions created as part of the functionalize
   // control flow transformation.
   std::unique_ptr<FunctionLibraryDefinition> local_flib_def_;
-  std::unique_ptr<FunctionLibraryRuntime> local_flib_runtime_;
+  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+  std::unique_ptr<ProcessFunctionLibraryRuntime> local_pflr_;
 
-  std::unique_ptr<FunctionLibraryRuntime> flib_runtime_;
+  FunctionLibraryRuntime* local_flib_runtime_;  // owned by local_pflr_.
+  FunctionLibraryRuntime* flib_runtime_;        // owned by pflr_.
 
   struct SignatureHash {
     uint64 operator()(
diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake
index 8ed5c154bfd..25f00de81dd 100644
--- a/tensorflow/contrib/cmake/tf_tests.cmake
+++ b/tensorflow/contrib/cmake/tf_tests.cmake
@@ -241,6 +241,7 @@ if (tensorflow_BUILD_PYTHON_TESTS)
     "${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" # depends on python/framework/test_ops
     # Broken tensorboard test due to cmake issues.
     "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py"
+    "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" # Needs portpicker
     # tensor_forest tests (also note that we exclude the hybrid tests for now)
     "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order.
     "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py" # Results in wrong order.
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 25b419557e5..d9a3079b87c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -21,6 +21,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:functional_ops", "//tensorflow/python:gradients", "//tensorflow/python:math_ops", "//tensorflow/python:training", @@ -28,6 +29,27 @@ py_test( ], ) +py_test( + name = "iterator_ops_cluster_test", + size = "small", + srcs = ["iterator_ops_cluster_test.py"], + srcs_version = "PY2AND3", + tags = ["no_windows"], + deps = [ + "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:functional_ops", + "//tensorflow/python:training", + "//third_party/py/numpy", + ], +) + py_test( name = "batch_dataset_op_test", size = "small", diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py new file mode 100644 index 00000000000..faad6e925d7 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py @@ -0,0 +1,109 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the experimental input pipeline ops that need test_util.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import functional_ops +from tensorflow.python.platform import test + + +class IteratorClusterTest(test.TestCase): + + def testRemoteIteratorWithoutRemoteCallFail(self): + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + worker, _ = test_util.create_local_cluster( + 1, 1, worker_config=worker_config) + + with ops.device("/job:worker/replica:0/task:0/cpu:1"): + dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) + iterator_3 = dataset_3.make_one_shot_iterator() + iterator_3_handle = iterator_3.string_handle() + + with ops.device("/job:worker/replica:0/task:0/cpu:0"): + remote_it = dataset_ops.Iterator.from_string_handle( + iterator_3_handle, dataset_3.output_types, dataset_3.output_shapes) + get_next_op = remote_it.get_next() + + with session.Session(worker[0].target) as sess: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(get_next_op) + + def testRemoteIteratorUsingRemoteCallOp(self): + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + worker, _ = test_util.create_local_cluster( + 1, 1, worker_config=worker_config) + + with ops.device("/job:worker/replica:0/task:0/cpu:1"): + dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) + iterator_3 = dataset_3.make_one_shot_iterator() + iterator_3_handle = iterator_3.string_handle() + + @function.Defun(dtypes.string) + def _remote_fn(h): + remote_iterator = dataset_ops.Iterator.from_string_handle( + h, dataset_3.output_types, dataset_3.output_shapes) + return remote_iterator.get_next() + + with ops.device("/job:worker/replica:0/task:0/cpu:0"): + target_placeholder = array_ops.placeholder(dtypes.string, shape=[]) + remote_op = functional_ops.remote_call( + args=[iterator_3_handle], + Tout=[dtypes.int32], + f=_remote_fn, + target=target_placeholder) + + with session.Session(worker[0].target) as sess: + elem = sess.run( + remote_op, + feed_dict={target_placeholder: "/job:worker/replica:0/task:0/cpu:1"}) + self.assertEqual(elem, [1]) + # Fails when target is cpu:0 where the resource is not located. 
+ with self.assertRaises(errors.InvalidArgumentError): + sess.run( + remote_op, + feed_dict={ + target_placeholder: "/job:worker/replica:0/task:0/cpu:0" + }) + elem = sess.run( + remote_op, + feed_dict={target_placeholder: "/job:worker/replica:0/task:0/cpu:1"}) + self.assertEqual(elem, [2]) + elem = sess.run( + remote_op, + feed_dict={target_placeholder: "/job:worker/replica:0/task:0/cpu:1"}) + self.assertEqual(elem, [3]) + with self.assertRaises(errors.OutOfRangeError): + sess.run( + remote_op, + feed_dict={ + target_placeholder: "/job:worker/replica:0/task:0/cpu:1" + }) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py index 30f685842b0..b20742f7758 100644 --- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py @@ -25,8 +25,10 @@ from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -416,6 +418,62 @@ class IteratorTest(test.TestCase): feedable_int_vector.get_next(), feed_dict={handle_placeholder: handle_float_vector})) + def testRemoteIteratorUsingRemoteCallOpDirectSession(self): + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + + with ops.device("/job:localhost/replica:0/task:0/cpu:1"): + dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) + iterator_3 = dataset_3.make_one_shot_iterator() + iterator_3_handle = iterator_3.string_handle() + + @function.Defun(dtypes.string) + def _remote_fn(h): + remote_iterator = dataset_ops.Iterator.from_string_handle( + h, dataset_3.output_types, dataset_3.output_shapes) + return remote_iterator.get_next() + + with ops.device("/job:localhost/replica:0/task:0/cpu:0"): + target_placeholder = array_ops.placeholder(dtypes.string, shape=[]) + remote_op = functional_ops.remote_call( + args=[iterator_3_handle], + Tout=[dtypes.int32], + f=_remote_fn, + target=target_placeholder) + + with self.test_session(config=worker_config) as sess: + elem = sess.run( + remote_op, + feed_dict={ + target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" + }) + self.assertEqual(elem, [1]) + # Fails when target is cpu:0 where the resource is not located. 
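+    # remote_call executes _remote_fn on the device named by `target`. The
+    # iterator resource lives on cpu:1, so a cpu:0 target cannot find the
+    # resource and the lookup fails.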
+ with self.assertRaises(errors.InvalidArgumentError): + sess.run( + remote_op, + feed_dict={ + target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" + }) + elem = sess.run( + remote_op, + feed_dict={ + target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" + }) + self.assertEqual(elem, [2]) + elem = sess.run( + remote_op, + feed_dict={ + target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" + }) + self.assertEqual(elem, [3]) + with self.assertRaises(errors.OutOfRangeError): + sess.run( + remote_op, + feed_dict={ + target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" + }) + if __name__ == "__main__": test.main() diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 8322f0a8975..f7b79e82e16 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -811,12 +811,14 @@ cc_library( name = "testlib", testonly = 1, srcs = [ + "common_runtime/function_testlib.cc", "common_runtime/kernel_benchmark_testlib.cc", "framework/fake_input.cc", "framework/function_testlib.cc", "graph/testlib.cc", ], hdrs = [ + "common_runtime/function_testlib.h", "common_runtime/kernel_benchmark_testlib.h", "framework/fake_input.h", "framework/function_testlib.h", @@ -2661,17 +2663,14 @@ tf_cc_test( ":test_main", ":testlib", "//tensorflow/cc:cc_ops", - "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:functional_ops", + "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:cwise_op", - "//tensorflow/core/kernels:dense_update_ops", - "//tensorflow/core/kernels:fifo_queue_op", "//tensorflow/core/kernels:function_ops", - "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:matmul_op", - "//tensorflow/core/kernels:ops_util", - "//tensorflow/core/kernels:queue_ops", - "//tensorflow/core/kernels:session_ops", - "//tensorflow/core/kernels:variable_ops", + "//tensorflow/core/kernels:shape_ops", "//third_party/eigen3", ], ) diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 6b529d8f133..4b239606a84 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -139,15 +139,14 @@ static Node* AddRet(Graph* g, Endpoint input, int index) { return ret; } -static const FunctionLibraryRuntime::Handle kInvalidHandle = -1; - class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { public: FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env, Device* device, int graph_def_version, const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options, - CustomKernelCreator custom_kernel_creator); + CustomKernelCreator custom_kernel_creator, + ProcessFunctionLibraryRuntime* parent); ~FunctionLibraryRuntimeImpl() override; @@ -184,17 +183,13 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { const FunctionLibraryDefinition* const lib_def_; GraphOptimizer optimizer_; const CustomKernelCreator custom_kernel_creator_; + const string device_name_; std::function get_func_sig_; std::function create_kernel_; mutable mutex mu_; - // Maps function instantiation to a handle. The key is a - // canonicalized representation of the function name and - // instantiation attrs. The handle is an index into the items_. - std::unordered_map table_ GUARDED_BY(mu_); - // func_graphs_ never shrinks or reorders its members. 
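  // (With this change the canonicalized-key -> Handle table lives in the
  // parent ProcessFunctionLibraryRuntime; func_graphs_ and items_ are
  // indexed by this runtime's per-device LocalHandles.)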
std::vector func_graphs_ GUARDED_BY(mu_); @@ -208,12 +203,16 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { }; std::vector items_; + ProcessFunctionLibraryRuntime* parent_ = nullptr; // not owned. + Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs, FunctionBody** fbody); Status CreateItem(Handle handle, Item** item); Status GetOrCreateItem(Handle handle, Item** item); Status InstantiateSymbolicGradient(const NameAttrList& func, FunctionBody** g_body); + bool IsLocalTarget(const AttrSlice& attrs); + AttrValueMap FixAttrs(const AttrSlice& attrs); TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl); }; @@ -222,14 +221,19 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl( const DeviceMgr* dmgr, Env* env, Device* device, int graph_def_version, const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options, - CustomKernelCreator custom_kernel_creator) + CustomKernelCreator custom_kernel_creator, + ProcessFunctionLibraryRuntime* parent) : device_mgr_(dmgr), device_(device), env_(env), graph_def_version_(graph_def_version), lib_def_(lib_def), optimizer_(optimizer_options), - custom_kernel_creator_(std::move(custom_kernel_creator)) { + custom_kernel_creator_(std::move(custom_kernel_creator)), + device_name_(device_ == nullptr + ? ProcessFunctionLibraryRuntime::kDefaultFLRDevice + : device_->name()), + parent_(parent) { get_func_sig_ = [this](const string& op, const OpDef** sig) { return lib_def_->LookUpOpDef(op, sig); }; @@ -294,10 +298,17 @@ class CallOp : public AsyncOpKernel { }; const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) { + LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, h); + if (local_handle == kInvalidLocalHandle) { + LOG(ERROR) << "Could not find Handle: " << h + << " on device: " << device_name_; + return nullptr; + } + mutex_lock l(mu_); - CHECK_LE(static_cast(0), h); - CHECK_LT(h, func_graphs_.size()); - return func_graphs_[h]; + CHECK_LE(0, local_handle); + CHECK_LT(local_handle, func_graphs_.size()); + return func_graphs_[local_handle]; } Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef, @@ -393,22 +404,47 @@ Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient( return Status::OK(); } +bool FunctionLibraryRuntimeImpl::IsLocalTarget(const AttrSlice& attrs) { + if (device_ == nullptr) return true; + string target = ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs); + if (target.empty()) return true; + return target == device_->name(); +} + +AttrValueMap FunctionLibraryRuntimeImpl::FixAttrs(const AttrSlice& attrs) { + AttrValueMap value_map; + for (auto it : attrs) { + value_map[it.first] = it.second; + } + if (attrs.Find("_target") != nullptr) { + return value_map; + } + AttrValue v; + v.set_s(device_name_); + AddAttr("_target", v, &value_map); + return value_map; +} + Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name, AttrSlice attrs, Handle* handle) { - const string key = Canonicalize(function_name, attrs); - { - mutex_lock l(mu_); - *handle = gtl::FindWithDefault(table_, key, kInvalidHandle); - if (*handle != kInvalidHandle) { - return Status::OK(); - } + AttrValueMap value_map = FixAttrs(attrs); + AttrSlice new_attrs(&value_map); + + if (!IsLocalTarget(new_attrs)) { + return parent_->Instantiate(function_name, new_attrs, handle); + } + + const string key = Canonicalize(function_name, new_attrs); + *handle = parent_->GetHandle(key); + if (*handle != kInvalidHandle) { + return Status::OK(); } Status 
s; FunctionBody* fbody = nullptr; if (function_name == kGradientOp) { - const AttrValue* f = attrs.Find(kFuncAttr); + const AttrValue* f = new_attrs.Find(kFuncAttr); if (f == nullptr) { return errors::InvalidArgument("SymbolicGradient is missing attr: f"); } @@ -426,17 +462,16 @@ Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name, if (fdef == nullptr) { return errors::NotFound("Function ", function_name, " is not defined."); } - TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, &fbody)); + TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, new_attrs, &fbody)); } { mutex_lock l(mu_); - *handle = gtl::FindWithDefault(table_, key, kInvalidHandle); + *handle = parent_->GetHandle(key); if (*handle != kInvalidHandle) { delete fbody; } else { - *handle = func_graphs_.size(); - table_.insert({key, *handle}); + *handle = parent_->AddHandle(key, device_name_, func_graphs_.size()); func_graphs_.push_back(fbody); items_.resize(func_graphs_.size()); } @@ -494,13 +529,14 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) { } Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) { + LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle); { mutex_lock l(mu_); - if (handle >= items_.size()) { + if (local_handle >= items_.size()) { return errors::NotFound("Function handle ", handle, " is not valid. Likely an internal error."); } - *item = items_[handle]; + *item = items_[local_handle]; if (*item != nullptr) { (*item)->Ref(); return Status::OK(); @@ -512,9 +548,9 @@ Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) { { mutex_lock l(mu_); - if (items_[handle] == nullptr) { + if (items_[local_handle] == nullptr) { // Install *item in items_. - items_[handle] = *item; + items_[local_handle] = *item; (*item)->Ref(); } } @@ -528,6 +564,9 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) { return done(errors::Cancelled("")); } + if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) { + return parent_->Run(opts, handle, args, rets, done); + } const FunctionBody* fbody = GetFunctionBody(handle); FunctionCallFrame* frame = new FunctionCallFrame(fbody->arg_types, fbody->ret_types); @@ -616,19 +655,21 @@ std::unique_ptr NewFunctionLibraryRuntime( const DeviceMgr* device_mgr, Env* env, Device* device, int graph_def_version, const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options, - CustomKernelCreator custom_kernel_creator) { + CustomKernelCreator custom_kernel_creator, + ProcessFunctionLibraryRuntime* parent) { return std::unique_ptr(new FunctionLibraryRuntimeImpl( device_mgr, env, device, graph_def_version, lib_def, optimizer_options, - std::move(custom_kernel_creator))); + std::move(custom_kernel_creator), parent)); } std::unique_ptr NewFunctionLibraryRuntime( const DeviceMgr* device_mgr, Env* env, Device* device, int graph_def_version, const FunctionLibraryDefinition* lib_def, - const OptimizerOptions& optimizer_options) { + const OptimizerOptions& optimizer_options, + ProcessFunctionLibraryRuntime* parent) { return NewFunctionLibraryRuntime(device_mgr, env, device, graph_def_version, lib_def, optimizer_options, - GetCustomCreatorSingleton()->Get()); + GetCustomCreatorSingleton()->Get(), parent); } bool RemoveDeadNodes(Graph* g) { diff --git a/tensorflow/core/common_runtime/function.h b/tensorflow/core/common_runtime/function.h index 167f0955970..477340d87a3 100644 
--- a/tensorflow/core/common_runtime/function.h
+++ b/tensorflow/core/common_runtime/function.h
@@ -21,6 +21,7 @@ limitations under the License.
 
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/protobuf/config.pb.h"
@@ -36,9 +37,6 @@ static constexpr const char* const kNoInlineAttr = "_noinline";
 // takes ownership of the returned OpKernel.
 //
 // TODO(zhifengc/phawkins): b/32379046
-typedef std::function<Status(FunctionLibraryRuntime*, const NodeDef&,
-                             std::unique_ptr<OpKernel>*)>
-    CustomKernelCreator;
 void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb);
 
 // Creates a FunctionLibraryRuntime, which instantiates functions
@@ -50,11 +48,16 @@ void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb);
 // The returned object does not take ownerships of "device" or
 // "lib_def". The caller must ensure "device" and "lib_def" outlives
 // the returned object.
+//
+// The "parent" is a pointer to the ProcessFunctionLibraryRuntime object that
+// typically owns the created FunctionLibraryRuntime object. The parent pointer
+// is not owned by the FunctionLibraryRuntime object.
 std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
     const DeviceMgr* device_mgr, Env* env, Device* device,
     int graph_def_version, const FunctionLibraryDefinition* lib_def,
     const OptimizerOptions& optimizer_options,
-    CustomKernelCreator custom_kernel_creator);
+    CustomKernelCreator custom_kernel_creator,
+    ProcessFunctionLibraryRuntime* parent);
 
 // Same as above except that the returned runtime consults with the
 // global default custom kernel creator registered by
@@ -62,7 +65,8 @@ std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
     const DeviceMgr* device_mgr, Env* env, Device* device,
     int graph_def_version, const FunctionLibraryDefinition* lib_def,
-    const OptimizerOptions& optimizer_options);
+    const OptimizerOptions& optimizer_options,
+    ProcessFunctionLibraryRuntime* parent);
 
 // FunctionLibraryRuntime::GetFunctionBody returns a description of an
 // instantiated function that is represented as a Graph with arg/ret
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 3ca4457b00c..a9f06c4df03 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/function.h"
 
 #include <atomic>
+#include <utility>
 
 #include "tensorflow/cc/ops/array_ops_internal.h"
 #include "tensorflow/cc/ops/function_ops.h"
@@ -24,6 +25,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/function_testlib.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/function_testlib.h"
 #include "tensorflow/core/framework/op.h"
@@ -34,7 +36,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" @@ -49,40 +50,18 @@ Status GetOpSig(const string& op, const OpDef** sig) { return OpRegistry::Global()->LookUpOpDef(op, sig); } -void FunctionTestSchedClosure(std::function fn) { - static thread::ThreadPool* w = - new thread::ThreadPool(Env::Default(), "Test", 8); - w->Schedule(std::move(fn)); -} - void HasError(const Status& s, const string& substr) { EXPECT_TRUE(StringPiece(s.ToString()).contains(substr)) << s << ", expected substring " << substr; } -// A helper class to make AttrSlice from initializer lists -class Attrs { - public: - Attrs(const std::initializer_list< // NOLINT(runtime/explicit) - std::pair>& attrs) { - for (const auto& aval : attrs) { - map_.insert({aval.first, aval.second.proto}); - } - } - - operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit) - - private: - AttrValueMap map_; -}; - class FunctionTest : public ::testing::Test { protected: FunctionTest() : device_(DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:0")) {} - void Create(const FunctionDef& fdef, Attrs attrs) { + void Create(const FunctionDef& fdef, test::function::Attrs attrs) { exec_ = nullptr; InstantiationResult result; TF_CHECK_OK(InstantiateFunction(fdef, attrs, GetOpSig, &result)); @@ -117,7 +96,7 @@ class FunctionTest : public ::testing::Test { TF_CHECK_OK(frame.SetArgs(args)); Executor::Args exec_args; exec_args.call_frame = &frame; - exec_args.runner = FunctionTestSchedClosure; + exec_args.runner = test::function::FunctionTestSchedClosure; TF_CHECK_OK(exec_->Run(exec_args)); std::vector computed; TF_CHECK_OK(frame.GetRetvals(&computed)); @@ -154,41 +133,42 @@ TEST_F(FunctionTest, WXPlusB) { class FunctionLibraryRuntimeTest : public ::testing::Test { protected: - FunctionLibraryRuntimeTest() - : device_(DeviceFactory::NewDevice("CPU", {}, - "/job:localhost/replica:0/task:0")) {} - void Init(const std::vector& flib) { + SessionOptions options; + auto* device_count = options.config.mutable_device_count(); + device_count->insert({"CPU", 3}); + TF_CHECK_OK(DeviceFactory::AddDevices( + options, "/job:localhost/replica:0/task:0", &devices_)); + FunctionDefLibrary proto; for (const auto& fdef : flib) *(proto.add_function()) = fdef; lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto)); OptimizerOptions opts; - lib_ = - NewFunctionLibraryRuntime(nullptr, Env::Default(), device_.get(), - TF_GRAPH_DEF_VERSION, lib_def_.get(), opts); + device_mgr_.reset(new DeviceMgr(devices_)); + pflr_.reset(new ProcessFunctionLibraryRuntime( + device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), + opts)); + flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); + flr1_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:1"); + flr2_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:2"); fdef_lib_ = lib_def_->ToProto(); } - Status Run(const string& name, Attrs attrs, const std::vector& args, - std::vector rets) { - FunctionLibraryRuntime::Handle handle; - Status status = lib_->Instantiate(name, attrs, &handle); - if (!status.ok()) { - return status; - } - + Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle, + const std::vector& args, std::vector rets) { std::atomic 
call_count(0); std::function)> runner = [&call_count](std::function fn) { ++call_count; - FunctionTestSchedClosure(fn); + test::function::FunctionTestSchedClosure(fn); }; Notification done; FunctionLibraryRuntime::Options opts; opts.runner = &runner; std::vector out; - lib_->Run(opts, handle, args, &out, [&status, &done](const Status& s) { + Status status; + flr->Run(opts, handle, args, &out, [&status, &done](const Status& s) { status = s; done.Notify(); }); @@ -206,28 +186,54 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { return Status::OK(); } - std::unique_ptr GetFuncBody(const string& name, Attrs attrs) { + Status Instantiate(FunctionLibraryRuntime* flr, const string& name, + test::function::Attrs attrs, + FunctionLibraryRuntime::Handle* handle) { + Status status = flr->Instantiate(name, attrs, handle); + if (!status.ok()) { + return status; + } + return Status::OK(); + } + + Status InstantiateAndRun(FunctionLibraryRuntime* flr, const string& name, + test::function::Attrs attrs, + const std::vector& args, + std::vector rets) { FunctionLibraryRuntime::Handle handle; - Status status = lib_->Instantiate(name, attrs, &handle); + Status status = flr->Instantiate(name, attrs, &handle); + if (!status.ok()) { + return status; + } + return Run(flr, handle, args, std::move(rets)); + } + + std::unique_ptr GetFuncBody(FunctionLibraryRuntime* flr, + const string& name, + test::function::Attrs attrs) { + FunctionLibraryRuntime::Handle handle; + Status status = flr->Instantiate(name, attrs, &handle); if (!status.ok()) { LOG(ERROR) << status; return nullptr; } - const FunctionBody* fbody = lib_->GetFunctionBody(handle); + const FunctionBody* fbody = flr->GetFunctionBody(handle); CHECK_NOTNULL(fbody); std::unique_ptr ret(new Graph(lib_def_.get())); CopyGraph(*fbody->graph, ret.get()); return ret; } - std::unique_ptr GetGradBody(const string& func, Attrs attrs) { + std::unique_ptr GetGradBody(FunctionLibraryRuntime* flr, + const string& func, + test::function::Attrs attrs) { FunctionLibraryRuntime::Handle handle; - Status status = lib_->Instantiate(func, attrs, &handle); + Status status = flr->Instantiate(func, attrs, &handle); if (!status.ok()) { LOG(ERROR) << status; return nullptr; } - const FunctionBody* fbody = lib_->GetFunctionBody(handle); + const FunctionBody* fbody = flr->GetFunctionBody(handle); CHECK_NOTNULL(fbody); std::unique_ptr gbody(SymbolicGradient(*fbody)); CHECK_NOTNULL(gbody); @@ -236,24 +242,29 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { return ret; } - std::unique_ptr device_; + FunctionLibraryRuntime* flr0_; + FunctionLibraryRuntime* flr1_; + FunctionLibraryRuntime* flr2_; + std::vector devices_; + std::unique_ptr device_mgr_; std::unique_ptr lib_def_; - std::unique_ptr lib_; + std::unique_ptr pflr_; FunctionDefLibrary fdef_lib_; }; TEST_F(FunctionLibraryRuntimeTest, IsStateful) { Init({}); - EXPECT_TRUE(lib_->IsStateful("Variable")); - EXPECT_TRUE(lib_->IsStateful("VariableV2")); - EXPECT_FALSE(lib_->IsStateful("Matmul")); + EXPECT_TRUE(flr0_->IsStateful("Variable")); + EXPECT_TRUE(flr0_->IsStateful("VariableV2")); + EXPECT_FALSE(flr0_->IsStateful("Matmul")); } TEST_F(FunctionLibraryRuntimeTest, XTimesTwo) { Init({test::function::XTimesTwo()}); auto x = test::AsTensor({1, 2, 3, 4}); Tensor y; - TF_CHECK_OK(Run("XTimesTwo", {{"T", DT_FLOAT}}, {x}, {&y})); + TF_CHECK_OK( + InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, {x}, {&y})); test::ExpectTensorEqual(y, test::AsTensor({2, 4, 6, 8})); } @@ -262,11 +273,14 @@ 
TEST_F(FunctionLibraryRuntimeTest, XTimesN) { test::function::XTimes16()}); auto x = test::AsTensor({1, 2, 3, 4}); Tensor y; - TF_CHECK_OK(Run("XTimesTwo", {{"T", DT_FLOAT}}, {x}, {&y})); + TF_CHECK_OK( + InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, {x}, {&y})); test::ExpectTensorEqual(y, test::AsTensor({2, 4, 6, 8})); - TF_CHECK_OK(Run("XTimesFour", {{"T", DT_FLOAT}}, {x}, {&y})); + TF_CHECK_OK( + InstantiateAndRun(flr0_, "XTimesFour", {{"T", DT_FLOAT}}, {x}, {&y})); test::ExpectTensorEqual(y, test::AsTensor({4, 8, 12, 16})); - TF_CHECK_OK(Run("XTimes16", {{"T", DT_FLOAT}}, {x}, {&y})); + TF_CHECK_OK( + InstantiateAndRun(flr0_, "XTimes16", {{"T", DT_FLOAT}}, {x}, {&y})); test::ExpectTensorEqual(y, test::AsTensor({16, 32, 48, 64})); } @@ -294,7 +308,7 @@ Output Call(Scope* scope, const string& op_name, const string& fn_name, TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) { Init({test::function::XTimesTwo(), test::function::XTimesFour(), test::function::XTimes16()}); - std::unique_ptr g = GetFuncBody("XTimes16", {{"T", DT_FLOAT}}); + std::unique_ptr g = GetFuncBody(flr0_, "XTimes16", {{"T", DT_FLOAT}}); ASSERT_TRUE(g != nullptr); { @@ -312,7 +326,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) { TF_EXPECT_GRAPH_EQ(expected, actual); } - ExpandInlineFunctions(lib_.get(), g.get()); + ExpandInlineFunctions(flr0_, g.get()); { Scope s = Scope::NewRootScope(); TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_)); @@ -334,7 +348,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) { TF_EXPECT_GRAPH_EQ(expected, actual); } - ExpandInlineFunctions(lib_.get(), g.get()); + ExpandInlineFunctions(flr0_, g.get()); GraphDef e2; { Scope s = Scope::NewRootScope(); @@ -373,7 +387,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) { } // No further inlining. 
- ExpandInlineFunctions(lib_.get(), g.get()); + ExpandInlineFunctions(flr0_, g.get()); { GraphDef actual; g->ToGraphDef(&actual); @@ -425,7 +439,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsWithControlDeps) { TF_ASSERT_OK(s.ToGraph(g.get())); } - ExpandInlineFunctions(lib_.get(), g.get()); + ExpandInlineFunctions(flr0_, g.get()); { Scope s = Scope::NewRootScope(); TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_)); @@ -449,7 +463,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsWithControlDeps) { TF_EXPECT_GRAPH_EQ(expected, actual); } - ExpandInlineFunctions(lib_.get(), g.get()); + ExpandInlineFunctions(flr0_, g.get()); { Scope s = Scope::NewRootScope(); TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_)); @@ -495,10 +509,10 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsWithControlDeps) { TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) { Init({test::function::XTimesTwo(), test::function::XTimesFour(), test::function::XTimes16()}); - std::unique_ptr g = GetFuncBody("XTimes16", {{"T", DT_FLOAT}}); + std::unique_ptr g = GetFuncBody(flr0_, "XTimes16", {{"T", DT_FLOAT}}); ASSERT_TRUE(g != nullptr); - ExpandInlineFunctions(lib_.get(), g.get()); - OptimizeGraph(lib_.get(), &g); + ExpandInlineFunctions(flr0_, g.get()); + OptimizeGraph(flr0_, &g); { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); @@ -541,9 +555,9 @@ TEST_F(FunctionLibraryRuntimeTest, ManySwapsNodeDef) { // Return {{"o", "g:output"}}); Init({test::function::Swap(), func}); - std::unique_ptr g = GetFuncBody("ManySwapsNodeDef", {}); + std::unique_ptr g = GetFuncBody(flr0_, "ManySwapsNodeDef", {}); ASSERT_TRUE(g != nullptr); - OptimizeGraph(lib_.get(), &g); + OptimizeGraph(flr0_, &g); const char* e0 = R"P( (n3:float, n2:float) -> (n3:float) { } @@ -574,9 +588,9 @@ TEST_F(FunctionLibraryRuntimeTest, ControlDeps) { {{"o"}, "Add", {"x2:z:0", "y2:z:0"}, {{"T", DT_FLOAT}}}}, {{"o", "o:z:0"}}); Init({test::function::Swap(), func}); - std::unique_ptr g = GetFuncBody("ManySwapsFirst", {}); + std::unique_ptr g = GetFuncBody(flr0_, "ManySwapsFirst", {}); ASSERT_TRUE(g != nullptr); - OptimizeGraph(lib_.get(), &g); + OptimizeGraph(flr0_, &g); // NOTE: We can remove func0, func1, func2, func9 with a control edge n8->n5. // But we don't have a pass doing that. @@ -609,7 +623,7 @@ TEST_F(FunctionLibraryRuntimeTest, Error_NotFound) { Init({test::function::XTimesTwo(), test::function::XTimesFour()}); auto x = test::AsTensor({1, 2, 3, 4}); Tensor y; - HasError(Run("Foo", {{"T", DT_FLOAT}}, {x}, {&y}), + HasError(InstantiateAndRun(flr0_, "Foo", {{"T", DT_FLOAT}}, {x}, {&y}), "Not found: Function Foo is not defined."); } @@ -632,25 +646,27 @@ TEST_F(FunctionLibraryRuntimeTest, Error_InstantiaionError) { // Instantiating "XTimesTwo" should fail. FunctionLibraryRuntime::Handle handle; - HasError(lib_->Instantiate("XTimesTwo", Attrs({{"T", DT_FLOAT}}), &handle), + HasError(flr0_->Instantiate( + "XTimesTwo", test::function::Attrs({{"T", DT_FLOAT}}), &handle), "Not found: type attr not found"); // But XTimesFour and XTimes16 instantiation should succeed. Only // when they run, they fail because XTimesTwo is bad. 
- TF_CHECK_OK( - lib_->Instantiate("XTimesFour", Attrs({{"T", DT_FLOAT}}), &handle)); - TF_CHECK_OK(lib_->Instantiate("XTimes16", Attrs({{"T", DT_FLOAT}}), &handle)); + TF_CHECK_OK(flr0_->Instantiate( + "XTimesFour", test::function::Attrs({{"T", DT_FLOAT}}), &handle)); + TF_CHECK_OK(flr0_->Instantiate( + "XTimes16", test::function::Attrs({{"T", DT_FLOAT}}), &handle)); auto x = test::AsTensor({1, 2, 3, 4}); Tensor y; - HasError(Run("XTimes16", {{"T", DT_FLOAT}}, {x}, {&y}), + HasError(InstantiateAndRun(flr0_, "XTimes16", {{"T", DT_FLOAT}}, {x}, {&y}), "type attr not found"); } TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) { Init({test::function::XTimesTwo(), test::function::XTimesFour(), test::function::XTimes16()}); - std::unique_ptr f = GetFuncBody("XTimesTwo", {{"T", DT_FLOAT}}); + std::unique_ptr f = GetFuncBody(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}); { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); @@ -666,7 +682,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) { TF_EXPECT_GRAPH_EQ(expected, actual); } - std::unique_ptr g = GetGradBody("XTimesTwo", {{"T", DT_FLOAT}}); + std::unique_ptr g = GetGradBody(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}); { Scope s = Scope::NewRootScope(); @@ -690,7 +706,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) { TF_EXPECT_GRAPH_EQ(expected, actual); } - OptimizeGraph(lib_.get(), &g); + OptimizeGraph(flr0_, &g); { Scope s = Scope::NewRootScope(); @@ -726,7 +742,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_Add) { Init({}); auto T = DT_FLOAT; std::unique_ptr g = GetFuncBody( - "SymbolicGradient", {{"f", FDH::FunctionRef("Add", {{"T", T}})}}); + flr0_, "SymbolicGradient", {{"f", FDH::FunctionRef("Add", {{"T", T}})}}); { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); @@ -756,7 +772,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_Mul) { Init({}); auto T = DT_FLOAT; std::unique_ptr g = GetFuncBody( - "SymbolicGradient", {{"f", FDH::FunctionRef("Mul", {{"T", T}})}}); + flr0_, "SymbolicGradient", {{"f", FDH::FunctionRef("Mul", {{"T", T}})}}); { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); @@ -812,7 +828,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { Init({test, grad}); - std::unique_ptr g = GetFuncBody("TestGrad", {}); + std::unique_ptr g = GetFuncBody(flr0_, "TestGrad", {}); ASSERT_TRUE(g != nullptr); { Scope s = Scope::NewRootScope(); @@ -836,7 +852,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { TF_EXPECT_GRAPH_EQ(expected, actual); } - ExpandInlineFunctions(lib_.get(), g.get()); + ExpandInlineFunctions(flr0_, g.get()); { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); @@ -888,7 +904,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { TF_EXPECT_GRAPH_EQ(expected, actual); } - OptimizeGraph(lib_.get(), &g); + OptimizeGraph(flr0_, &g); { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); @@ -939,6 +955,25 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { } } +TEST_F(FunctionLibraryRuntimeTest, CrossDevice) { + Init({test::function::FindDevice()}); + FunctionLibraryRuntime::Handle handle; + TF_CHECK_OK(Instantiate( + flr0_, "FindDevice", + {{"_target", "/job:localhost/replica:0/task:0/cpu:1"}}, &handle)); + + Tensor y; + // Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1. 
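+  // The handle was instantiated above with _target set to cpu:1, so Run()
+  // calls issued through flr1_ or flr2_ are delegated via the parent
+  // ProcessFunctionLibraryRuntime to the runtime that owns the function.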
+ TF_CHECK_OK(Run(flr1_, handle, {}, {&y})); + test::ExpectTensorEqual( + y, test::AsTensor({"/job:localhost/replica:0/task:0/cpu:1"}, + TensorShape({}))); + TF_CHECK_OK(Run(flr2_, handle, {}, {&y})); + test::ExpectTensorEqual( + y, test::AsTensor({"/job:localhost/replica:0/task:0/cpu:1"}, + TensorShape({}))); +} + namespace { bool DoNothing(Graph* g) { return false; } diff --git a/tensorflow/core/common_runtime/function_testlib.cc b/tensorflow/core/common_runtime/function_testlib.cc new file mode 100644 index 00000000000..64e59762a2a --- /dev/null +++ b/tensorflow/core/common_runtime/function_testlib.cc @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/common_runtime/function_testlib.h" + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace test { +namespace function { + +typedef FunctionDefHelper FDH; + +class FindDeviceOpKernel : public OpKernel { + public: + explicit FindDeviceOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) override { + Tensor* device_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output("device_name", TensorShape{}, + &device_tensor)); + device_tensor->scalar()() = + ctx->function_library()->device()->name(); + } +}; + +REGISTER_KERNEL_BUILDER(Name("FindDeviceOp").Device(tensorflow::DEVICE_CPU), + FindDeviceOpKernel); +REGISTER_OP("FindDeviceOp").Output("device_name: string"); + +FunctionDef FindDevice() { + return FDH::Define( + // Name + "FindDevice", + // Args + {}, + // Return values + {"device_name: string"}, + // Attr def + {}, + // Nodes + {{{"device_name"}, "FindDeviceOp", {}, {}}}); +} + +} // namespace function +} // namespace test +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/function_testlib.h b/tensorflow/core/common_runtime/function_testlib.h new file mode 100644 index 00000000000..6b93b188b71 --- /dev/null +++ b/tensorflow/core/common_runtime/function_testlib.h @@ -0,0 +1,31 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_TESTLIB_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_TESTLIB_H_
+
+#include "tensorflow/core/framework/function.h"
+
+namespace tensorflow {
+namespace test {
+namespace function {
+
+// {} -> y:DT_STRING (device where this op runs).
+FunctionDef FindDevice();
+
+}  // namespace function
+}  // namespace test
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_TESTLIB_H_
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index 97d891fa16a..0caec036252 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -14,21 +14,58 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
 
+#include
+
 #include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
 
 namespace tensorflow {
 
+const char ProcessFunctionLibraryRuntime::kDefaultFLRDevice[] = "null";
+
 ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
     const DeviceMgr* device_mgr, Env* env, int graph_def_version,
     const FunctionLibraryDefinition* lib_def,
     const OptimizerOptions& optimizer_options) {
-  if (!device_mgr) return;
+  if (device_mgr == nullptr) {
+    flr_map_[kDefaultFLRDevice] =
+        NewFunctionLibraryRuntime(nullptr, env, nullptr, graph_def_version,
+                                  lib_def, optimizer_options, this);
+    return;
+  }
+  for (Device* d : device_mgr->ListDevices()) {
+    flr_map_[d->name()] =
+        NewFunctionLibraryRuntime(device_mgr, env, d, graph_def_version,
+                                  lib_def, optimizer_options, this);
+  }
+}
+
+ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
+    const DeviceMgr* device_mgr, Env* env, int graph_def_version,
+    const FunctionLibraryDefinition* lib_def,
+    const OptimizerOptions& optimizer_options,
+    CustomKernelCreator custom_kernel_creator) {
+  if (device_mgr == nullptr) {
+    flr_map_[kDefaultFLRDevice] = NewFunctionLibraryRuntime(
+        nullptr, env, nullptr, graph_def_version, lib_def, optimizer_options,
+        custom_kernel_creator, this);
+    return;
+  }
   for (Device* d : device_mgr->ListDevices()) {
     flr_map_[d->name()] = NewFunctionLibraryRuntime(
-        device_mgr, env, d, graph_def_version, lib_def, optimizer_options);
+        device_mgr, env, d, graph_def_version, lib_def, optimizer_options,
+        custom_kernel_creator, this);
   }
 }
 
+string ProcessFunctionLibraryRuntime::ObtainFunctionTarget(
+    const AttrSlice& attrs) {
+  const AttrValue* value;
+  if (!attrs.Find("_target", &value).ok()) {
+    return "";
+  }
+  return value->s();
+}
+
 FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
     const string& device_name) {
   if (flr_map_.find(device_name) == flr_map_.end()) {
@@ -38,4 +75,70 @@ FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
   return flr_map_[device_name].get();
 }
 
+FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle(
+    const string& function_key, const string& device_name,
+    FunctionLibraryRuntime::LocalHandle local_handle) {
+  mutex_lock l(mu_);
+  FunctionLibraryRuntime::Handle h =
+      gtl::FindWithDefault(table_, function_key, kInvalidHandle);
+  if (h != kInvalidHandle) {
+    return h;
+  }
+  h = function_data_.size();
+  function_data_.emplace_back(device_name,
local_handle); + table_[function_key] = h; + return h; +} + +FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::GetHandle( + const string& function_key) const { + mutex_lock l(mu_); + return gtl::FindWithDefault(table_, function_key, kInvalidHandle); +} + +bool ProcessFunctionLibraryRuntime::IsInstantiatedOnDevice( + const string& device_name, FunctionLibraryRuntime::Handle handle) { + return GetHandleOnDevice(device_name, handle) != -1; +} + +FunctionLibraryRuntime::LocalHandle +ProcessFunctionLibraryRuntime::GetHandleOnDevice( + const string& device_name, FunctionLibraryRuntime::Handle handle) { + mutex_lock l(mu_); + std::pair p = + function_data_[handle]; + if (p.first != device_name) { + return kInvalidLocalHandle; + } + return p.second; +} + +Status ProcessFunctionLibraryRuntime::Instantiate( + const string& function_name, AttrSlice attrs, + FunctionLibraryRuntime::Handle* handle) { + string target = ObtainFunctionTarget(attrs); + + FunctionLibraryRuntime* flr = GetFLR(target); + if (flr != nullptr) { + return flr->Instantiate(function_name, attrs, handle); + } + return errors::InvalidArgument("Target: ", target, " is not supported"); +} + +void ProcessFunctionLibraryRuntime::Run( + const FunctionLibraryRuntime::Options& opts, + FunctionLibraryRuntime::Handle handle, gtl::ArraySlice args, + std::vector* rets, FunctionLibraryRuntime::DoneCallback done) { + FunctionLibraryRuntime* flr = nullptr; + { + mutex_lock l(mu_); + std::pair p = + function_data_[handle]; + flr = GetFLR(p.first); + } + if (flr != nullptr) { + return flr->Run(opts, handle, args, rets, std::move(done)); + } +} + } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index 53b2223b28f..2259997005e 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -24,7 +24,6 @@ limitations under the License. namespace tensorflow { // A class that stores all the FunctionLibraryRuntime objects, one per device. -// This class is not thread safe. class ProcessFunctionLibraryRuntime { public: // Creates FunctionLibraryRuntime objects for each device in the provided @@ -35,10 +34,64 @@ class ProcessFunctionLibraryRuntime { const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options); + ProcessFunctionLibraryRuntime(const DeviceMgr* device_mgr, Env* env, + int graph_def_version, + const FunctionLibraryDefinition* lib_def, + const OptimizerOptions& optimizer_options, + CustomKernelCreator custom_kernel_creator); + + // Given a list of attrs on a function, extracts the "_target" attribute which + // indicates which device to run the function on. If it can't find the _target + // attribute, returns "". Canonicalizes the device name. + static string ObtainFunctionTarget(const AttrSlice& attrs); + + static const char kDefaultFLRDevice[]; // Returns the FunctionLibraryRuntime for the corresponding device_name. FunctionLibraryRuntime* GetFLR(const string& device_name); + // For a given canonicalized key signature of the function instantiated + // on device `device_name` and a `local_handle`, creates a handle and returns + // that value. Use core/common_runtime/framework/function.h::Canonicalize + // to canonicalize the function signature. 
+ FunctionLibraryRuntime::Handle AddHandle( + const string& function_key, const string& device_name, + FunctionLibraryRuntime::LocalHandle local_handle); + + // Returns a handle if found for the given key, else returns kInvalidHandle. + FunctionLibraryRuntime::Handle GetHandle(const string& function_key) const; + + // For the given handle instantiated on device `device_name` returns the local + // index of instantiation of that function. If the function was not + // instantiated on `device_name` returns kInvalidLocalHandle. + FunctionLibraryRuntime::LocalHandle GetHandleOnDevice( + const string& device_name, FunctionLibraryRuntime::Handle handle); + + // Returns true if function with handle `handle` was instantiated on device + // `device_name`. + bool IsInstantiatedOnDevice(const string& device_name, + FunctionLibraryRuntime::Handle handle); + + // Instantiates the function. See framework/function.h for more details. + // Allows for function_name to be instantiated on different devices + // as specified in attrs. + Status Instantiate(const string& function_name, AttrSlice attrs, + FunctionLibraryRuntime::Handle* handle); + + // Runs the function with given `handle`. Function could have been + // instantiated on any device. More details in framework/function.h + void Run(const FunctionLibraryRuntime::Options& opts, + FunctionLibraryRuntime::Handle handle, gtl::ArraySlice args, + std::vector* rets, + FunctionLibraryRuntime::DoneCallback done); + private: + mutable mutex mu_; + + // Holds all the function invocations here. + std::unordered_map table_ + GUARDED_BY(mu_); + std::vector> + function_data_ GUARDED_BY(mu_); std::unordered_map> flr_map_; }; diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc index d9a5cab88b9..1536aedde58 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc @@ -17,6 +17,9 @@ limitations under the License. 
 #include
 
 #include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/function_testlib.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/public/session_options.h"
 #include "tensorflow/core/public/version.h"
@@ -25,8 +28,8 @@ namespace tensorflow {
 namespace {
 
 class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
- public:
-  ProcessFunctionLibraryRuntimeTest() {
+ protected:
+  void Init(const std::vector<FunctionDef>& flib) {
     SessionOptions options;
     auto* device_count = options.config.mutable_device_count();
     device_count->insert({"CPU", 2});
@@ -34,6 +37,7 @@
                                            &devices_));
     device_mgr_.reset(new DeviceMgr(devices_));
     FunctionDefLibrary proto;
+    for (const auto& fdef : flib) *(proto.add_function()) = fdef;
    lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
     OptimizerOptions opts;
     proc_flr_.reset(new ProcessFunctionLibraryRuntime(
@@ -41,7 +45,43 @@
         opts));
   }
 
- protected:
+  Status Run(const string& name, test::function::Attrs attrs,
+             const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
+    FunctionLibraryRuntime::Handle handle;
+    Status status = proc_flr_->Instantiate(name, attrs, &handle);
+    if (!status.ok()) {
+      return status;
+    }
+
+    std::atomic<int32> call_count(0);
+    std::function<void(std::function<void()>)> runner =
+        [&call_count](std::function<void()> fn) {
+          ++call_count;
+          test::function::FunctionTestSchedClosure(fn);
+        };
+
+    Notification done;
+    FunctionLibraryRuntime::Options opts;
+    opts.runner = &runner;
+    std::vector<Tensor> out;
+    proc_flr_->Run(opts, handle, args, &out, [&status, &done](const Status& s) {
+      status = s;
+      done.Notify();
+    });
+    done.WaitForNotification();
+    if (!status.ok()) {
+      return status;
+    }
+    CHECK_EQ(rets.size(), out.size());
+    for (size_t i = 0; i < rets.size(); ++i) {
+      *rets[i] = out[i];
+    }
+
+    EXPECT_GE(call_count, 1);  // Test runner is used.
+ + return Status::OK(); + } + std::vector devices_; std::unique_ptr device_mgr_; std::unique_ptr lib_def_; @@ -49,6 +89,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { }; TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) { + Init({}); FunctionLibraryRuntime* flr = proc_flr_->GetFLR("/job:a/replica:0/task:0/cpu:0"); EXPECT_NE(flr, nullptr); @@ -60,5 +101,87 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) { EXPECT_EQ(flr, nullptr); } +TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) { + AttrSlice empty_attrs; + string target = + ProcessFunctionLibraryRuntime::ObtainFunctionTarget(empty_attrs); + EXPECT_EQ("", target); + + AttrValueMap attr_values; + AttrValue v; + v.set_s("/job:a/replica:0/task:0/cpu:1"); + AddAttr("_target", v, &attr_values); + AttrSlice attrs(&attr_values); + target = ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs); + EXPECT_EQ("/job:a/replica:0/task:0/cpu:1", target); +} + +TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCall) { + Init({test::function::XTimesTwo()}); + auto x = test::AsTensor({1, 2, 3, 4}); + Tensor y; + TF_CHECK_OK( + Run("XTimesTwo", + {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x}, + {&y})); + test::ExpectTensorEqual(y, test::AsTensor({2, 4, 6, 8})); +} + +TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) { + Init({test::function::FindDevice()}); + Tensor y; + TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:0"}}, + {}, {&y})); + test::ExpectTensorEqual( + y, test::AsTensor({"/job:a/replica:0/task:0/cpu:0"}, + TensorShape({}))); +} + +TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) { + Init({test::function::XTimesTwo(), test::function::XTimesFour()}); + auto x = test::AsTensor({1, 2, 3, 4}); + Tensor y; + TF_CHECK_OK( + Run("XTimesTwo", + {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x}, + {&y})); + test::ExpectTensorEqual(y, test::AsTensor({2, 4, 6, 8})); + TF_CHECK_OK( + Run("XTimesFour", + {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x}, + {&y})); + test::ExpectTensorEqual(y, test::AsTensor({4, 8, 12, 16})); +} + +TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceFindDevice) { + Init({test::function::FindDevice()}); + Tensor y; + TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, + {}, {&y})); + test::ExpectTensorEqual( + y, test::AsTensor({"/job:a/replica:0/task:0/cpu:1"}, + TensorShape({}))); + TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, + {}, {&y})); + test::ExpectTensorEqual( + y, test::AsTensor({"/job:a/replica:0/task:0/cpu:1"}, + TensorShape({}))); +} + +TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) { + Init({test::function::FindDevice()}); + Tensor y; + TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:0"}}, + {}, {&y})); + test::ExpectTensorEqual( + y, test::AsTensor({"/job:a/replica:0/task:0/cpu:0"}, + TensorShape({}))); + TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, + {}, {&y})); + test::ExpectTensorEqual( + y, test::AsTensor({"/job:a/replica:0/task:0/cpu:1"}, + TensorShape({}))); +} + } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 045976dd06a..717f0c85755 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -437,8 +437,16 @@ class FunctionLibraryRuntime { // 
Returns the graph version number. virtual int graph_def_version() = 0; + + typedef uint64 LocalHandle; }; +const FunctionLibraryRuntime::Handle kInvalidHandle = -1; +const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1; +typedef std::function*)> + CustomKernelCreator; + // To register a gradient function for a builtin op, one should use // REGISTER_OP_GRADIENT(, ); // diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc index 4ee23226daa..e6ef8425fb0 100644 --- a/tensorflow/core/framework/function_testlib.cc +++ b/tensorflow/core/framework/function_testlib.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/public/version.h" namespace tensorflow { @@ -172,6 +173,12 @@ FunctionDef Swap() { {{"o1"}, "Identity", {"i0"}, {{"T", "$T"}}}}); } +void FunctionTestSchedClosure(std::function fn) { + static thread::ThreadPool* w = + new thread::ThreadPool(Env::Default(), "Test", 8); + w->Schedule(std::move(fn)); +} + } // end namespace function } // end namespace test } // end namespace tensorflow diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h index 49e5b0c99d9..a742fe0ce7a 100644 --- a/tensorflow/core/framework/function_testlib.h +++ b/tensorflow/core/framework/function_testlib.h @@ -30,6 +30,22 @@ namespace tensorflow { namespace test { namespace function { +// A helper class to make AttrSlice from initializer lists +class Attrs { + public: + Attrs(const std::initializer_list< // NOLINT(runtime/explicit) + std::pair>& attrs) { + for (const auto& aval : attrs) { + map_.insert({aval.first, aval.second.proto}); + } + } + + operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit) + + private: + AttrValueMap map_; +}; + // Helper to construct a NodeDef. NodeDef NDef( const string& name, const string& op, gtl::ArraySlice inputs, @@ -62,6 +78,8 @@ FunctionDef NonZero(); // x:T, y:T -> y:T, x:T FunctionDef Swap(); +void FunctionTestSchedClosure(std::function fn); + } // end namespace function } // end namespace test } // end namespace tensorflow diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index 61366514102..b740e8a999e 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -104,9 +104,11 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def, optimizer_opts->set_do_function_inlining(cfg.inline_functions); // Create the function library runtime. - std::unique_ptr flib(NewFunctionLibraryRuntime( - dvc_mgr.get(), env, devices[0], inlined_graph_def.versions().producer(), - &function_library, *optimizer_opts)); + std::unique_ptr pflr( + new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env, + inlined_graph_def.versions().producer(), + &function_library, *optimizer_opts)); + FunctionLibraryRuntime* flr = pflr->GetFLR(devices[0]->name()); // Create the GraphOptimizer to optimize the graph def. GraphConstructorOptions graph_ctor_opts; @@ -122,8 +124,7 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def, // Optimize the graph. 
GraphOptimizer optimizer(*optimizer_opts); - optimizer.Optimize(flib.get(), env, devices[0], &graphptr, - /*shape_map=*/nullptr); + optimizer.Optimize(flr, env, devices[0], &graphptr, /*shape_map=*/nullptr); graphptr->ToGraphDef(output_graph_def); return Status::OK(); diff --git a/tensorflow/core/kernels/captured_function.cc b/tensorflow/core/kernels/captured_function.cc index eb52de6d85e..15e9680f262 100644 --- a/tensorflow/core/kernels/captured_function.cc +++ b/tensorflow/core/kernels/captured_function.cc @@ -40,9 +40,9 @@ Status CapturedFunction::Create( // NOTE(mrry): We need to assign a name to the device, and we choose // the same name as the calling context's device so that we do not // need to rewrite resource handles that are found in `captured_inputs`. - std::unique_ptr device(new ThreadPoolDevice( - SessionOptions(), ctx->device()->attributes().name(), Bytes(256 << 20), - DeviceLocality(), cpu_allocator())); + Device* device = + new ThreadPoolDevice(SessionOptions(), ctx->device()->attributes().name(), + Bytes(256 << 20), DeviceLocality(), cpu_allocator()); // TODO(mrry): Handle arbitrary resource types, which might require a // redesign (or opening up access to `ResourceMgr::DoLookup()` and @@ -82,20 +82,24 @@ Status CapturedFunction::Create( } #undef HANDLE_RESOURCE_TYPE + std::unique_ptr device_mgr(new DeviceMgr({device})); std::unique_ptr flib_def( new FunctionLibraryDefinition( *ctx->function_library()->GetFunctionLibraryDefinition())); - std::unique_ptr lib(NewFunctionLibraryRuntime( - nullptr /* device_mgr */, ctx->env(), device.get(), graph_def_version, - flib_def.get(), {} /* TODO(mrry): OptimizerOptions? */)); + std::unique_ptr pflr( + new ProcessFunctionLibraryRuntime( + device_mgr.get(), ctx->env(), graph_def_version, flib_def.get(), + {} /* TODO(mrry): OptimizerOptions? 
*/)); + + FunctionLibraryRuntime* lib = pflr->GetFLR(device->name()); FunctionLibraryRuntime::Handle f_handle; TF_RETURN_IF_ERROR( lib->Instantiate(func->name(), AttrSlice(&func->attr()), &f_handle)); out_function->reset(new CapturedFunction( - std::move(device), std::move(flib_def), std::move(lib), f_handle, - std::move(captured_inputs))); + device, std::move(device_mgr), std::move(flib_def), std::move(pflr), lib, + f_handle, std::move(captured_inputs))); return Status::OK(); } @@ -136,14 +140,16 @@ Status CapturedFunction::Run(FunctionLibraryRuntime::Options f_opts, } CapturedFunction::CapturedFunction( - std::unique_ptr device, + Device* device, std::unique_ptr device_mgr, std::unique_ptr flib_def, - std::unique_ptr lib, - FunctionLibraryRuntime::Handle f_handle, + std::unique_ptr pflr, + FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle, std::vector captured_inputs) - : device_(std::move(device)), + : device_(device), + device_mgr_(std::move(device_mgr)), flib_def_(std::move(flib_def)), - lib_(std::move(lib)), + pflr_(std::move(pflr)), + lib_(lib), f_handle_(f_handle), captured_inputs_(std::move(captured_inputs)) {} diff --git a/tensorflow/core/kernels/captured_function.h b/tensorflow/core/kernels/captured_function.h index e24bcb9d829..03679736f35 100644 --- a/tensorflow/core/kernels/captured_function.h +++ b/tensorflow/core/kernels/captured_function.h @@ -63,20 +63,23 @@ class CapturedFunction { gtl::ArraySlice args, std::vector* rets, const string& prefix); - Device* device() const { return device_.get(); } + const Device* device() const { return device_; } ResourceMgr* resource_manager() const { return device_->resource_manager(); } private: - CapturedFunction(std::unique_ptr device, + CapturedFunction(Device* device, std::unique_ptr device_mgr, std::unique_ptr flib_def, - std::unique_ptr lib, + std::unique_ptr pflr, + FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle, std::vector captured_inputs); - const std::unique_ptr device_; + Device* const device_; // owned by device_mgr_. + const std::unique_ptr device_mgr_; const std::unique_ptr flib_def_; - const std::unique_ptr lib_; + const std::unique_ptr pflr_; + FunctionLibraryRuntime* const lib_; // owned by pflr_. 
const FunctionLibraryRuntime::Handle f_handle_; const std::vector captured_inputs_; diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index b831b5bff55..a1dfd4c3d31 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -277,5 +277,67 @@ REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_GPU), REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_SYCL), SymbolicGradientOp); +#endif // TENSORFLOW_USE_SYCL + +class RemoteCallOp : public AsyncOpKernel { + public: + explicit RemoteCallOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); + } + + ~RemoteCallOp() override {} + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + const Tensor* target; + OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done); + AttrValueMap attr_values = func_->attr(); + AttrValue v; + v.set_s(target->scalar()()); + AddAttr("_target", v, &attr_values); + + FunctionLibraryRuntime* lib = ctx->function_library(); + OP_REQUIRES_ASYNC(ctx, lib != nullptr, + errors::Internal("No function library is provided."), + done); + FunctionLibraryRuntime::Handle handle; + OP_REQUIRES_OK_ASYNC( + ctx, lib->Instantiate(func_->name(), AttrSlice(&attr_values), &handle), + done); + + OpInputList arguments; + OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done); + + FunctionLibraryRuntime::Options opts; + opts.step_id = ctx->step_id(); + opts.runner = ctx->runner(); + std::vector args; + args.reserve(arguments.size()); + for (const Tensor& argument : arguments) { + args.push_back(argument); + } + auto* rets = new std::vector; + lib->Run(opts, handle, args, rets, [rets, done, ctx](const Status& status) { + if (!status.ok()) { + ctx->SetStatus(status); + } + for (size_t i = 0; i < rets->size(); ++i) { + ctx->set_output(i, (*rets)[i]); + } + delete rets; + done(); + }); + } + + private: + string target_; + const NameAttrList* func_; + TF_DISALLOW_COPY_AND_ASSIGN(RemoteCallOp); +}; + +REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_CPU), RemoteCallOp); +REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_GPU), RemoteCallOp); +#if TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_SYCL), RemoteCallOp); + #endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc index d1f9e949425..5fd21ec88fa 100644 --- a/tensorflow/core/ops/functional_ops.cc +++ b/tensorflow/core/ops/functional_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" @@ -65,4 +66,22 @@ to x_i. (Needs some math expert to say the comment above better.) )doc"); +REGISTER_OP("RemoteCall") + .Input("target: string") + .Input("args: Tin") + .Output("output: Tout") + .Attr("Tin: list(type)") + .Attr("Tout: list(type)") + .Attr("f: func") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Runs function `f` on a remote device indicated by `target`. + +target: A fully specified device name where we want to run the function. +args: A list of arguments for the function. +output: A list of return values. +Tin: The type list for the arguments. 
+Tout: The type list for the return values. +f: The function to run remotely. +)doc"); } // end namespace tensorflow diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py index e0231c460e8..a7bedc7199c 100644 --- a/tensorflow/python/kernel_tests/functional_ops_test.py +++ b/tensorflow/python/kernel_tests/functional_ops_test.py @@ -20,10 +20,14 @@ from __future__ import print_function import numpy as np +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gradients_impl @@ -446,6 +450,56 @@ class FunctionalOpsTest(test.TestCase): sess.run([result, result_t, result_grad, result_t_grad], feed_dict={x: [[1.0, 2.0]]}) + def testRemoteFunction(self): + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + worker, _ = test_util.create_local_cluster( + 1, 1, worker_config=worker_config) + + @function.Defun(dtypes.int32, dtypes.int32) + def _remote_fn(a, b): + return math_ops.multiply(a, b) + + with ops.device("/job:ps/task:0"): + a = variables.Variable(2, dtype=dtypes.int32) + b = variables.Variable(3, dtype=dtypes.int32) + + with ops.device("/job:worker/replica:0/task:0/cpu:0"): + remote_op = functional_ops.remote_call( + args=[a, b], + Tout=[dtypes.int32], + f=_remote_fn, + target="/job:worker/replica:0/task:0/cpu:1") + + with session.Session(worker[0].target) as sess: + sess.run(variables.global_variables_initializer()) + mul = sess.run(remote_op) + self.assertEqual(mul, [6]) + + def testRemoteFunctionDirectSession(self): + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + + @function.Defun(dtypes.int32, dtypes.int32) + def _remote_fn(a, b): + return math_ops.multiply(a, b) + + with ops.device("/job:localhost/replica:0/task:0/cpu:0"): + a = variables.Variable(2, dtype=dtypes.int32) + b = variables.Variable(3, dtype=dtypes.int32) + + with ops.device("/job:localhost/replica:0/task:0/cpu:0"): + remote_op = functional_ops.remote_call( + args=[a, b], + Tout=[dtypes.int32], + f=_remote_fn, + target="/job:localhost/replica:0/task:0/cpu:1") + + with self.test_session(config=worker_config) as sess: + sess.run(variables.global_variables_initializer()) + mul = sess.run(remote_op) + self.assertEqual(mul, [6]) + if __name__ == "__main__": test.main() From e31346452d91c48fa9b3deff8df575ccbd7f877a Mon Sep 17 00:00:00 2001 From: Jonathan Hseu Date: Thu, 17 Aug 2017 17:26:07 -0700 Subject: [PATCH 05/70] TPUEstimator: Fix the outfeed thread join. 
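The hook joins its worker threads when the session ends; the old code joined
the infeed thread controller a second time and never joined the outfeed
thread controller. As a minimal, self-contained sketch of this bug class
(hypothetical names, not code from this change):

    import threading

    class TwoThreadHook(object):
      def __init__(self):
        self._infeed_thd = threading.Thread(target=lambda: None)
        self._outfeed_thd = threading.Thread(target=lambda: None)
        self._infeed_thd.start()
        self._outfeed_thd.start()

      def end(self):
        self._infeed_thd.join()
        # Joining self._infeed_thd again here would return immediately and
        # leave the outfeed thread unjoined; the fix joins the outfeed thread.
        self._outfeed_thd.join()

    TwoThreadHook().end()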
PiperOrigin-RevId: 165651781 --- tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index c7b84f952f9..3622dff29b9 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -283,7 +283,7 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): if self._dequeue_ops is not None: logging.info('Stop output thread controller') - self._infeed_thd_controller.join() + self._outfeed_thd_controller.join() logging.info('Shutdown TPU system.') session.run(self._finalize_op) From 641943fd71c6e42ff3d6c71af45199dea4895976 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 17 Aug 2017 17:35:54 -0700 Subject: [PATCH 06/70] Update ops-related pbtxt files. PiperOrigin-RevId: 165652758 --- .../core/ops/compat/ops_history.v1.pbtxt | 31 +++++++++++++++ tensorflow/core/ops/ops.pbtxt | 38 +++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index e76573ffdb1..6ff1a3fc038 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -19996,6 +19996,37 @@ op { } } } +op { + name: "RemoteCall" + input_arg { + name: "target" + type: DT_STRING + } + input_arg { + name: "args" + type_list_attr: "Tin" + } + output_arg { + name: "output" + type_list_attr: "Tout" + } + attr { + name: "Tin" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "Tout" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "f" + type: "func" + } +} op { name: "RemoteFusedGraphExecute" input_arg { diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 06eabdcdcd6..87cdc30fb1b 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -19607,6 +19607,44 @@ op { } summary: "Computes rectified linear gradients for a Relu operation." } +op { + name: "RemoteCall" + input_arg { + name: "target" + description: "A fully specified device name where we want to run the function." + type: DT_STRING + } + input_arg { + name: "args" + description: "A list of arguments for the function." + type_list_attr: "Tin" + } + output_arg { + name: "output" + description: "A list of return values." + type_list_attr: "Tout" + } + attr { + name: "Tin" + type: "list(type)" + description: "The type list for the arguments." + has_minimum: true + minimum: 1 + } + attr { + name: "Tout" + type: "list(type)" + description: "The type list for the return values." + has_minimum: true + minimum: 1 + } + attr { + name: "f" + type: "func" + description: "The function to run remotely." + } + summary: "Runs function `f` on a remote device indicated by `target`." +} op { name: "RemoteFusedGraphExecute" input_arg { From 465c408196210efcdeb792b72801fdec7b7db868 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 17 Aug 2017 17:44:32 -0700 Subject: [PATCH 07/70] Fix the shape information propagation for Enter op. 
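An Enter op with is_constant=True emits the same loop-invariant tensor on
every iteration, so its static shape can safely mirror its input's shape. A
non-constant Enter may carry a different value (and shape) on each iteration,
so its output shape must stay unknown. A short sketch of the resulting
behavior (TF 1.x internal API, mirroring the test added below):

    from tensorflow.python.ops import control_flow_ops
    from tensorflow.python.ops import variables

    v = variables.Variable([0.0, 0.0])
    # Constant Enter: shape [2] is propagated from the input.
    print(control_flow_ops.enter(v, "frame1", is_constant=True).shape)
    # Non-constant Enter: shape stays unknown.
    print(control_flow_ops.enter(v, "frame2", is_constant=False).shape)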
PiperOrigin-RevId: 165653579 --- tensorflow/core/ops/control_flow_ops.cc | 7 +++++++ .../python/kernel_tests/control_flow_ops_py_test.py | 13 +++++++++++++ 2 files changed, 20 insertions(+) diff --git a/tensorflow/core/ops/control_flow_ops.cc b/tensorflow/core/ops/control_flow_ops.cc index 9e39b396e1f..61089658d71 100644 --- a/tensorflow/core/ops/control_flow_ops.cc +++ b/tensorflow/core/ops/control_flow_ops.cc @@ -204,6 +204,13 @@ REGISTER_OP("Enter") auto* handle_data = c->input_handle_shapes_and_types(0); if (handle_data != nullptr) { c->set_output_handle_shapes_and_types(0, *handle_data); + } else { + // Otherwise, propagate shape if output is a constant. + bool is_constant; + TF_RETURN_IF_ERROR(c->GetAttr("is_constant", &is_constant)); + if (is_constant) { + c->set_output(0, c->input(0)); + } } return Status::OK(); diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index fdecea1dc10..a43fe71b9f3 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -179,6 +179,19 @@ class ControlFlowTest(test.TestCase): result = exit_op.eval() self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result) + def testEnterShapePropagation(self): + with self.test_session(): + v = variables.Variable([0.0, 0.0], dtype=dtypes.float32) + + # If is_constant=True, the shape information should be propagated. + enter_v_constant = control_flow_ops.enter(v, "frame1", is_constant=True) + self.assertEqual(enter_v_constant.shape, [2]) + + # Otherwise, the shape should be unknown. + enter_v_non_constant = control_flow_ops.enter(v, "frame2", + is_constant=False) + self.assertEqual(enter_v_non_constant.shape, None) + def testSwitchMergeIndexedSlices(self): with self.test_session(): values = constant_op.constant([1, 2, 3, 4, 5, 6]) From d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 17 Aug 2017 17:48:29 -0700 Subject: [PATCH 08/70] Fix linear algebra benchmarks. 
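The benchmarks previously fed each matrix into the graph as a large constant,
presumably mixing input materialization into the measured time. Holding the
input in a Variable that is initialized once lets the timed runs exercise
only the linear-algebra op; the size lists are also extended and renamed from
`sizes` to `shapes`. The pattern in miniature (a sketch assuming TF 1.x; a
benchmark would loop over the final run):

    import numpy as np
    import tensorflow as tf

    n = 256
    data = np.ones((n, n), np.float32) / (2.0 * n) + np.eye(n, dtype=np.float32)
    matrix = tf.Variable(data)   # materialized once by the initializer
    chol = tf.cholesky(matrix)   # only this op is timed
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(chol.op)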
PiperOrigin-RevId: 165653891 --- .../python/kernel_tests/cholesky_op_test.py | 105 +++++++++++------- .../kernel_tests/determinant_op_test.py | 37 +++--- .../kernel_tests/matrix_inverse_op_test.py | 41 ++++--- 3 files changed, 110 insertions(+), 73 deletions(-) diff --git a/tensorflow/python/kernel_tests/cholesky_op_test.py b/tensorflow/python/kernel_tests/cholesky_op_test.py index d783522e820..de80fb30554 100644 --- a/tensorflow/python/kernel_tests/cholesky_op_test.py +++ b/tensorflow/python/kernel_tests/cholesky_op_test.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging @@ -81,8 +82,11 @@ def MatrixInverseCompositeGrad(l, grad): def TriAngInvCompositeGrad(l, grad): num_rows = array_ops.shape(l)[-1] batch_shape = array_ops.shape(l)[:-2] - l_inverse = linalg_ops.matrix_triangular_solve( - l, linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=l.dtype)) + l_inverse = linalg_ops.matrix_triangular_solve(l, + linalg_ops.eye( + num_rows, + batch_shape=batch_shape, + dtype=l.dtype)) return _GradWithInverseL(l, l_inverse, grad) @@ -281,75 +285,94 @@ class CholeskyGradTest(test.TestCase): class CholeskyBenchmark(test.Benchmark): - sizes = [ - (4, 4), (16, 16), (256, 256), (1024, 1024), (2048, 2048), - (513, 2, 2), (513, 8, 8), (4, 513, 2, 2) + shapes = [ + (4, 4), + (10, 10), + (16, 16), + (101, 101), + (256, 256), + (1000, 1000), + (1024, 1024), + (2048, 2048), + (513, 2, 2), + (513, 8, 8), + (513, 256, 256), + (4, 513, 2, 2), ] - def _GenerateData(self, size): - batch_shape = size[:-2] - size = size[-2:] - assert size[0] == size[1] - n = size[0] - data = np.ones(size).astype(np.float32) / (2.0 * n) + np.diag( - np.ones(n).astype(np.float32)) - return np.tile(data, batch_shape + (1, 1)) + def _GenerateMatrix(self, shape): + batch_shape = shape[:-2] + shape = shape[-2:] + assert shape[0] == shape[1] + n = shape[0] + matrix = np.ones(shape).astype(np.float32) / ( + 2.0 * n) + np.diag(np.ones(n).astype(np.float32)) + return np.tile(matrix, batch_shape + (1, 1)) def benchmarkCholeskyOp(self): - for size in self.sizes: - data = self._GenerateData(size) - + for shape in self.shapes: with ops.Graph().as_default(), \ session.Session() as sess, \ ops.device("/cpu:0"): - l = linalg_ops.cholesky(data) + matrix = variables.Variable(self._GenerateMatrix(shape)) + l = linalg_ops.cholesky(matrix) + variables.global_variables_initializer().run() self.run_op_benchmark( - sess, control_flow_ops.group(l,), + sess, + control_flow_ops.group( + l,), min_iters=25, - name="cholesky_cpu_{size}".format(size=size)) + name="cholesky_cpu_{shape}".format(shape=shape)) if test.is_gpu_available(True): with ops.Graph().as_default(), \ session.Session() as sess, \ ops.device("/device:GPU:0"): - l = linalg_ops.cholesky(data) + matrix = variables.Variable(self._GenerateMatrix(shape)) + l = linalg_ops.cholesky(matrix) + variables.global_variables_initializer().run() self.run_op_benchmark( sess, control_flow_ops.group( l,), min_iters=25, - name="cholesky_gpu_{size}".format(size=size)) + name="cholesky_gpu_{shape}".format(shape=shape)) def benchmarkGradVariants(self): + def _BenchmarkGrad(grad_fn, name, device): - for size in self.sizes: - data = self._GenerateData(size) - l = np.linalg.cholesky(data) - grad_data = 
np.random.randn(*data.shape).astype(np.float32) + for shape in self.shapes: + matrix = self._GenerateMatrix(shape) with ops.Graph().as_default(), \ session.Session() as sess, \ ops.device(device): - grad = grad_fn(l, grad_data) + l = variables.Variable(np.linalg.cholesky(matrix)) + grad_matrix = variables.Variable( + np.random.randn(*matrix.shape).astype(np.float32)) + grad = grad_fn(l, grad_matrix) + variables.global_variables_initializer().run() self.run_op_benchmark( - sess, control_flow_ops.group(grad,), + sess, + control_flow_ops.group( + grad,), min_iters=25, - name="{name}_{dev}_{size}".format( - name=name, dev=grad.device, size=size)) + name="{name}_{dev}_{shape}".format( + name=name, dev=grad.device, shape=shape)) if test.is_gpu_available(True): - _BenchmarkGrad( - MatrixInverseCompositeGrad, "composite_matrix_inverse", "/device:GPU:0") - _BenchmarkGrad( - TriAngInvCompositeGrad, "composite_tri_ang_inverse", "/device:GPU:0") - _BenchmarkGrad( - TriAngSolveCompositeGrad, "composite_triangular_solve", "/device:GPU:0") + _BenchmarkGrad(MatrixInverseCompositeGrad, "composite_matrix_inverse", + "/device:GPU:0") + _BenchmarkGrad(TriAngInvCompositeGrad, "composite_tri_ang_inverse", + "/device:GPU:0") + _BenchmarkGrad(TriAngSolveCompositeGrad, "composite_triangular_solve", + "/device:GPU:0") - _BenchmarkGrad( - MatrixInverseCompositeGrad, "composite_matrix_inverse", "/cpu:0") - _BenchmarkGrad( - TriAngInvCompositeGrad, "composite_tri_ang_inverse", "/cpu:0") - _BenchmarkGrad( - TriAngSolveCompositeGrad, "composite_triangular_solve", "/cpu:0") + _BenchmarkGrad(MatrixInverseCompositeGrad, "composite_matrix_inverse", + "/cpu:0") + _BenchmarkGrad(TriAngInvCompositeGrad, "composite_tri_ang_inverse", + "/cpu:0") + _BenchmarkGrad(TriAngSolveCompositeGrad, "composite_triangular_solve", + "/cpu:0") _BenchmarkGrad(SpecializedGrad, "specialized", "/cpu:0") diff --git a/tensorflow/python/kernel_tests/determinant_op_test.py b/tensorflow/python/kernel_tests/determinant_op_test.py index b9fc1104056..4f07322d61c 100644 --- a/tensorflow/python/kernel_tests/determinant_op_test.py +++ b/tensorflow/python/kernel_tests/determinant_op_test.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -130,49 +131,55 @@ class DeterminantOpTest(test.TestCase): class MatrixDeterminantBenchmark(test.Benchmark): - sizes = [ + shapes = [ (4, 4), + (10, 10), (16, 16), + (101, 101), (256, 256), + (1000, 1000), (1024, 1024), + (2048, 2048), (513, 4, 4), (513, 16, 16), (513, 256, 256), ] - def _GenerateData(self, size): - batch_shape = size[:-2] - size = size[-2:] - assert size[0] == size[1] - n = size[0] - data = np.ones(size).astype(np.float32) / ( + def _GenerateMatrix(self, shape): + batch_shape = shape[:-2] + shape = shape[-2:] + assert shape[0] == shape[1] + n = shape[0] + matrix = np.ones(shape).astype(np.float32) / ( 2.0 * n) + np.diag(np.ones(n).astype(np.float32)) - return np.tile(data, batch_shape + (1, 1)) + return variables.Variable(np.tile(matrix, batch_shape + (1, 1))) def benchmarkMatrixDeterminantOp(self): - for size in self.sizes: - data = self._GenerateData(size) - + for shape in self.shapes: with ops.Graph().as_default(), session.Session() as sess, ops.device( "/cpu:0"): - d = linalg_ops.matrix_determinant(data) + matrix = self._GenerateMatrix(shape) + d = 
linalg_ops.matrix_determinant(matrix) + variables.global_variables_initializer().run() self.run_op_benchmark( sess, control_flow_ops.group( d,), min_iters=25, - name="matrix_determinant_cpu_{size}".format(size=size)) + name="matrix_determinant_cpu_{shape}".format(shape=shape)) if test.is_gpu_available(True): with ops.Graph().as_default(), session.Session() as sess, ops.device( "/gpu:0"): - d = linalg_ops.matrix_determinant(data) + matrix = self._GenerateMatrix(shape) + d = linalg_ops.matrix_determinant(matrix) + variables.global_variables_initializer().run() self.run_op_benchmark( sess, control_flow_ops.group( d,), min_iters=25, - name="matrix_determinant_gpu_{size}".format(size=size)) + name="matrix_determinant_gpu_{shape}".format(shape=shape)) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py index 601084c8307..7343a02c2cd 100644 --- a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -137,52 +138,58 @@ class InverseOpTest(test.TestCase): class MatrixInverseBenchmark(test.Benchmark): - sizes = [ + shapes = [ (4, 4), + (10, 10), (16, 16), + (101, 101), (256, 256), + (1000, 1000), (1024, 1024), + (2048, 2048), (513, 4, 4), (513, 16, 16), (513, 256, 256), ] - def _GenerateData(self, size): - batch_shape = size[:-2] - size = size[-2:] - assert size[0] == size[1] - n = size[0] - data = np.ones(size).astype(np.float32) / ( + def _GenerateMatrix(self, shape): + batch_shape = shape[:-2] + shape = shape[-2:] + assert shape[0] == shape[1] + n = shape[0] + matrix = np.ones(shape).astype(np.float32) / ( 2.0 * n) + np.diag(np.ones(n).astype(np.float32)) - return np.tile(data, batch_shape + (1, 1)) + return variables.Variable(np.tile(matrix, batch_shape + (1, 1))) def benchmarkMatrixInverseOp(self): for adjoint in False, True: - for size in self.sizes: - data = self._GenerateData(size) - + for shape in self.shapes: with ops.Graph().as_default(), \ session.Session() as sess, \ ops.device("/cpu:0"): - inv = linalg_ops.matrix_inverse(data, adjoint=adjoint) + matrix = self._GenerateMatrix(shape) + inv = linalg_ops.matrix_inverse(matrix, adjoint=adjoint) + variables.global_variables_initializer().run() self.run_op_benchmark( sess, control_flow_ops.group(inv), min_iters=25, - name="matrix_inverse_cpu_{size}_{adjoint}".format( - size=size, adjoint="adjoint" if adjoint else "noadjoint")) + name="matrix_inverse_cpu_{shape}_adjoint_{adjoint}".format( + shape=shape, adjoint=adjoint)) if test.is_gpu_available(True): with ops.Graph().as_default(), \ session.Session() as sess, \ ops.device("/gpu:0"): - inv = linalg_ops.matrix_inverse(data, adjoint=adjoint) + matrix = self._GenerateMatrix(shape) + inv = linalg_ops.matrix_inverse(matrix, adjoint=adjoint) + variables.global_variables_initializer().run() self.run_op_benchmark( sess, control_flow_ops.group(inv), min_iters=25, - name="matrix_inverse_gpu_{size}_{adjoint}".format( - size=size, adjoint="adjoint" if adjoint else "noadjoint")) + name="matrix_inverse_gpu_{shape}_adjoint_{adjoint}".format( + shape=shape, adjoint=adjoint)) if __name__ == "__main__": From 513def0bb27e4a7c29f6ff533d8ca150b2ab78b4 Mon Sep 
17 00:00:00 2001 From: Benoit Steiner Date: Thu, 17 Aug 2017 17:48:53 -0700 Subject: [PATCH 09/70] Fixed BuildOpInfoWithoutDevice PiperOrigin-RevId: 165653933 --- tensorflow/core/grappler/costs/BUILD | 20 ++- tensorflow/core/grappler/costs/utils.cc | 25 +--- tensorflow/core/grappler/costs/utils_test.cc | 150 +++++++++++++++++++ 3 files changed, 177 insertions(+), 18 deletions(-) create mode 100644 tensorflow/core/grappler/costs/utils_test.cc diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index ea1990c0b19..f2c13d2b132 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -141,6 +141,24 @@ tf_cuda_library( ], ) +cc_test( + name = "utils_test", + srcs = ["utils_test.cc"], + visibility = ["//visibility:public"], + deps = [ + ":utils", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:all_kernels", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensor_testutil", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + cc_library( name = "cost_estimator", hdrs = ["cost_estimator.h"], @@ -170,7 +188,7 @@ cc_test( srcs = ["virtual_placer_test.cc"], deps = [ ":virtual_placer", - "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:core_cpu", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc index 4135d9b3313..db36f97500e 100644 --- a/tensorflow/core/grappler/costs/utils.cc +++ b/tensorflow/core/grappler/costs/utils.cc @@ -70,11 +70,12 @@ static std::vector ExtractTensors(const AttrValue& attr_value) { return tensors; } +// Annotate the op_info inputs with extra information when possible (e.g. the +// input value if it's known statically). static void ExtractExtraProperties( const NodeDef& node, const std::unordered_map& name_to_node, - std::vector* extra_inputs, - protobuf::Map* attr_map) { + OpInfo* op_info) { OpRegistry* op_registry = OpRegistry::Global(); const OpDef* op_def = nullptr; auto s = op_registry->LookUpOpDef(node.op(), &op_def); @@ -102,11 +103,8 @@ static void ExtractExtraProperties( if (tensors.empty()) continue; const TensorProto& t = tensors[0]; - OpInfo::TensorProperties input; - input.set_dtype(t.dtype()); - *(input.mutable_shape()) = t.tensor_shape(); - *(input.mutable_value()) = t; - extra_inputs->push_back(input); + OpInfo::TensorProperties* input = op_info->mutable_inputs(i); + *(input->mutable_value()) = t; // For filename input, the file size can also be useful. if (op_def && i < op_def->input_arg_size() && @@ -129,7 +127,7 @@ static void ExtractExtraProperties( AttrValue attr; attr.set_i(stat.length); string attr_key = strings::StrCat("input_", i, "_filesize"); - (*attr_map)[attr_key] = attr; + (*op_info->mutable_attr())[attr_key] = attr; } } @@ -140,7 +138,7 @@ static void ExtractExtraProperties( string new_key = strings::StrCat("parent_", i, "_op"); AttrValue attr; attr.set_s(input_node->op()); - (*attr_map)[new_key] = attr; + (*op_info->mutable_attr())[new_key] = attr; // TODO(yuefengz): Only parent node's op name is copied. Copy inputs // and attributes when necessary. 
} @@ -212,14 +210,7 @@ OpInfo BuildOpInfoWithoutDevice( for (auto& input : inputs) { *op_info.add_inputs() = input; } - - std::vector extra_inputs; - ExtractExtraProperties(node, name_to_node, &extra_inputs, - op_info.mutable_attr()); - for (auto& input : extra_inputs) { - *op_info.add_inputs() = input; - } - + ExtractExtraProperties(node, name_to_node, &op_info); return op_info; } diff --git a/tensorflow/core/grappler/costs/utils_test.cc b/tensorflow/core/grappler/costs/utils_test.cc new file mode 100644 index 00000000000..bdcb156c4e3 --- /dev/null +++ b/tensorflow/core/grappler/costs/utils_test.cc @@ -0,0 +1,150 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/costs/utils.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { + +class UtilsTest : public ::testing::Test { + public: + void CreateConstOp(const string& name, std::initializer_list dims, + NodeDef* node) { + Tensor tensor(DT_FLOAT, TensorShape(dims)); + for (int64 i = 0; i < tensor.NumElements(); ++i) { + tensor.flat()(i) = i / 10.0f; + } + TF_CHECK_OK(NodeDefBuilder(name, "Const") + .Attr("dtype", DT_FLOAT) + .Attr("value", tensor) + .Finalize(node)); + } + + void CreateConstSizesOp(const string& name, const std::vector& sizes, + NodeDef* node) { + TensorShape shape; + shape.AddDim(sizes.size()); + Tensor tensor(DT_INT32, shape); + for (int64 i = 0; i < tensor.NumElements(); ++i) { + tensor.flat()(i) = sizes[i]; + } + TF_CHECK_OK(NodeDefBuilder(name, "Const") + .Attr("dtype", DT_INT32) + .Attr("value", tensor) + .Finalize(node)); + } +}; + +TEST_F(UtilsTest, ConvOpInfo) { + int batch = 32; + int rows = 7; + int cols = 9; + int filter_rows = 3; + int filter_cols = 3; + int out_rows = 7; + int out_cols = 9; + int in_depth = 3; + int out_depth = 5; + int stride = 1; + + std::unordered_map name_to_node; + GraphDef graph; + NodeDef* input = graph.add_node(); + name_to_node["input"] = input; + CreateConstOp("input", {batch, rows, cols, in_depth}, input); + NodeDef* filter = graph.add_node(); + name_to_node["filter"] = filter; + CreateConstOp("filter", {filter_rows, filter_cols, in_depth, out_depth}, + filter); + NodeDef* output_backprop = graph.add_node(); + name_to_node["output_backprop"] = output_backprop; + CreateConstOp("output_backprop", {batch, out_rows, out_cols, out_depth}, + output_backprop); + NodeDef* input_sizes = graph.add_node(); + name_to_node["input_sizes"] = input; + CreateConstSizesOp("input_sizes", + std::vector({batch, rows, cols, in_depth}), + input_sizes); + 
NodeDef* filter_sizes = graph.add_node(); + name_to_node["filter_sizes"] = filter_sizes; + CreateConstSizesOp( + "filter_sizes", + std::vector({filter_rows, filter_cols, in_depth, out_depth}), + filter_sizes); + + TensorShape paddings_shape({4, 2}); + Tensor paddings_tensor(DT_INT32, paddings_shape); + for (int64 i = 0; i < paddings_tensor.NumElements(); ++i) { + paddings_tensor.flat()(i) = 0; + } + TF_CHECK_OK(NodeDefBuilder("paddings", "Const") + .Attr("dtype", DT_INT32) + .Attr("value", paddings_tensor) + .Finalize(graph.add_node())); + + // Now add the convolution op + NodeDef* conv = graph.add_node(); + TF_CHECK_OK(NodeDefBuilder("conv2d", "Conv2D") + .Input("input", 0, DT_FLOAT) + .Input("filter", 0, DT_FLOAT) + .Attr("strides", {1, stride, stride, 1}) + .Attr("padding", "SAME") + .Finalize(conv)); + + NodeDef* conv_bp_in = graph.add_node(); + TF_CHECK_OK(NodeDefBuilder("conv2d_bp_in", "Conv2DBackpropInput") + .Input("input_sizes", 0, DT_INT32) + .Input("filter", 0, DT_FLOAT) + .Input("output_backprop", 0, DT_FLOAT) + .Attr("strides", {1, stride, stride, 1}) + .Attr("padding", "SAME") + .Finalize(conv_bp_in)); + + NodeDef* conv_bp_filter = graph.add_node(); + TF_CHECK_OK(NodeDefBuilder("conv2d_bp_filter", "Conv2DBackpropFilter") + .Input("input", 0, DT_FLOAT) + .Input("filter_sizes", 0, DT_INT32) + .Input("output_backprop", 0, DT_FLOAT) + .Attr("strides", {1, stride, stride, 1}) + .Attr("padding", "SAME") + .Finalize(conv_bp_filter)); + + for (const auto& node : graph.node()) { + if (node.name().find("conv2d") != 0) { + continue; + } + std::vector inputs; + inputs.resize(node.input_size()); + OpInfo info = BuildOpInfoWithoutDevice(node, name_to_node, inputs); + if (node.name() == "conv2d") { + EXPECT_EQ(2, info.inputs_size()); + } else if (node.name() == "conv2dbp_in") { + EXPECT_EQ(3, info.inputs_size()); + } else if (node.name() == "conv2d_bp_filter") { + EXPECT_EQ(3, info.inputs_size()); + } + } +} + +} // end namespace grappler +} // end namespace tensorflow From a1225879cdedae7f2de24030a9c072a516d97040 Mon Sep 17 00:00:00 2001 From: Chris Leary Date: Thu, 17 Aug 2017 17:55:08 -0700 Subject: [PATCH 10/70] [XLA] Propagate error code in computation replay tool. 
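Previously the tool always returned 0, so scripts and test harnesses could
not detect replay failures. The change accumulates any per-file failure into
the process exit status while still replaying the remaining inputs. The same
pattern in a standalone sketch (Python for brevity; `replay` is a
hypothetical stand-in for ReplayComputation):

    import sys

    def replay(path):
      with open(path, "rb"):  # fails loudly on unreadable inputs
        pass

    def main(paths):
      exit_status = 0
      for path in paths:
        try:
          replay(path)
        except OSError as e:
          print("%s: error: %s" % (path, e), file=sys.stderr)
          exit_status = 1  # remember the failure, but keep going
      return exit_status

    if __name__ == "__main__":
      sys.exit(main(sys.argv[1:]))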
PiperOrigin-RevId: 165654497
---
 tensorflow/compiler/xla/tools/replay_computation.cc | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc
index 6228ca34c08..735c66e2d3e 100644
--- a/tensorflow/compiler/xla/tools/replay_computation.cc
+++ b/tensorflow/compiler/xla/tools/replay_computation.cc
@@ -82,9 +82,10 @@ StatusOr<std::unique_ptr<Literal>> ReplayComputation(
   return client->ExecuteAndTransfer(computation, execute_arguments);
 }

-void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool use_fake_data) {
+int RealMain(tensorflow::gtl::ArraySlice<char*> args, bool use_fake_data) {
   Client* client = ClientLibrary::LocalClientOrDie();
   tensorflow::Env* env = tensorflow::Env::Default();
+  int exit_status = EXIT_SUCCESS;
   for (char* arg : args) {
     SessionModule module;
     TF_CHECK_OK(tensorflow::ReadBinaryProto(env, arg, &module));
@@ -93,6 +94,7 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool use_fake_data) {
     if (!result_status.ok()) {
       fprintf(stderr, "%s: error: %s\n", arg,
               result_status.status().ToString().c_str());
+      exit_status = EXIT_FAILURE;
       continue;
     }
     std::unique_ptr<Literal> result = result_status.ConsumeValueOrDie();
@@ -105,6 +107,7 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool use_fake_data) {
               Literal(module.result()).ToString().c_str());
     }
   }
+  return exit_status;
 }

 }  // namespace tools
@@ -126,6 +129,5 @@ int main(int argc, char** argv) {
   tensorflow::gtl::ArraySlice<char*> args(argv, argc);
   args.pop_front();  // Pop off the binary name, argv[0]

-  xla::tools::RealMain(args, use_fake_data);
-  return 0;
+  return xla::tools::RealMain(args, use_fake_data);
 }

From f0da8bf56ba1b625d53b760683bc44f67e204199 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Thu, 17 Aug 2017 17:56:51 -0700
Subject: [PATCH 11/70] [Rematerialization] Reconsider rematerializing
 operations with control dependencies

We added conservative logic to not rematerialize operations with control
dependencies, since the rematerialized operations could result in undesired
ordering. However, when we rematerialize an operation we now also copy its
control dependencies, which guarantees that the rematerialized operation is
subject to the same constraints as the original operation.

PiperOrigin-RevId: 165654629
---
 .../xla/service/hlo_rematerialization.cc      | 41 ++++++++++++++-----
 1 file changed, 31 insertions(+), 10 deletions(-)

diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 9f65f1b8512..a0e5bb7911b 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -55,16 +55,6 @@ namespace {

 // Returns true if the given instruction is rematerializable.
 bool IsRematerializable(const HloInstruction* instruction) {
-  // Conservatively, don't rematerialize instructions with control
-  // dependencies. For one, control dependencies are added to prevent
-  // interference of aliased buffers (say, in while bodies) and
-  // rematerialization is ignorant of liveness and may break the intended
-  // ordering.
-  if (!instruction->control_predecessors().empty() ||
-      !instruction->control_successors().empty()) {
-    return false;
-  }
-
   // Don't rematerialize instructions with side effects or instructions which
   // cannot be cloned safely.
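+  // Note that control dependencies alone no longer disqualify an
+  // instruction: when an instruction is rematerialized its control edges
+  // are copied to the clone, and candidates whose control successors are
+  // already placed are skipped (see PickRematerializationCandidate).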
   switch (instruction->opcode()) {
@@ -906,6 +896,19 @@ Item* PickRematerializationCandidate(const MemoryUsageTracker& memory_tracker,
       continue;
     }

+    // If any of the candidate's control successors has been placed, we need
+    // to skip this candidate; otherwise we would violate a control
+    // dependency.
+    bool control_successor_placed =
+        std::any_of(candidate->control_successors().begin(),
+                    candidate->control_successors().end(),
+                    [&memory_tracker](const HloInstruction* inst) {
+                      return memory_tracker.IsPlaced(inst);
+                    });
+
+    if (control_successor_placed) {
+      continue;
+    }
+
     const int64 memory_reduced =
         memory_tracker.MemoryReducedIfRematerialized(item);

@@ -1047,6 +1050,15 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
     HloInstruction* remat =
         computation->AddInstruction(best->Clone(/*suffix=*/"remat"));
+
+    // Add control dependencies to the new operation.
+    for (auto successor : best->control_successors()) {
+      TF_RETURN_IF_ERROR(remat->AddControlDependencyTo(successor));
+    }
+    for (auto predecessor : best->control_predecessors()) {
+      TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(remat));
+    }
+
     Item* remat_item = instruction_list.CreateItem(remat);

     // Replace each remaining use of 'best' with the rematerialization.
@@ -1082,6 +1094,15 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
         }
       }
     }
+    // Insert the rematerialized instruction before any of its control
+    // successors to preserve ordering with respect to control dependencies.
+    for (auto successor : remat->control_successors()) {
+      Item* successor_item = instruction_list.GetItem(successor);
+      // Assert to make sure we never remat an operation with a control
+      // successor that is already placed.
+      CHECK(!successor_item->placed);
+      place_before.push_back(successor_item);
+    }
     instruction_list.InsertBeforeInstructions(remat_item, place_before);

     // If the rematerialized instruction is dead then rematerialization is

From 7359fec792e4efec1670a12332bb524a5608b215 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Thu, 17 Aug 2017 18:04:58 -0700
Subject: [PATCH 12/70] Implement Batchnorm Inference by expanding it into
 smaller ops.

1. Add batch norm inference support in batchnorm_rewriter
2.
Connect xla's batchnorm inference to tf's FusedBatchNorm RELNOTES: n/a PiperOrigin-RevId: 165655351 --- .../compiler/tests/fused_batchnorm_test.py | 33 ++++ .../compiler/tf2xla/kernels/batch_norm_op.cc | 42 ++--- .../xla/client/computation_builder.cc | 26 +++- .../xla/service/batchnorm_rewriter.cc | 98 +++++++++++- .../compiler/xla/service/batchnorm_rewriter.h | 3 + .../xla/service/batchnorm_rewriter_test.cc | 2 + .../compiler/xla/service/cpu/cpu_compiler.cc | 1 + .../compiler/xla/service/dfs_hlo_visitor.h | 3 + .../service/dfs_hlo_visitor_with_default.h | 4 + .../compiler/xla/service/gpu/gpu_compiler.cc | 1 + .../compiler/xla/service/hlo_cost_analysis.cc | 6 + .../compiler/xla/service/hlo_cost_analysis.h | 1 + .../compiler/xla/service/hlo_graph_dumper.cc | 1 + .../compiler/xla/service/hlo_instruction.cc | 26 ++++ .../compiler/xla/service/hlo_instruction.h | 6 + tensorflow/compiler/xla/service/hlo_opcode.cc | 2 + tensorflow/compiler/xla/service/hlo_opcode.h | 1 + .../xla/service/instruction_fusion.cc | 1 + tensorflow/compiler/xla/service/service.cc | 4 + .../compiler/xla/service/shape_inference.cc | 144 ++++++++++++++++++ .../compiler/xla/service/shape_inference.h | 7 + .../compiler/xla/service/user_computation.cc | 100 ++++++++++++ .../compiler/xla/service/user_computation.h | 4 + .../xla/tests/batch_normalization_test.cc | 103 +++++++++++++ tensorflow/compiler/xla/xla_data.proto | 13 +- 25 files changed, 605 insertions(+), 27 deletions(-) diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py index f8e9fc92681..936fcf8b6be 100644 --- a/tensorflow/compiler/tests/fused_batchnorm_test.py +++ b/tensorflow/compiler/tests/fused_batchnorm_test.py @@ -63,6 +63,39 @@ class FusedBatchNormTest(XLATestCase): grad_offset = np.sum(grad_y, axis=(0, 1, 2)) return grad_x, grad_scale, grad_offset + def testInference(self): + x_shape = [2, 2, 6, 2] + scale_shape = [2] + x_val = np.random.random_sample(x_shape).astype(np.float32) + scale_val = np.random.random_sample(scale_shape).astype(np.float32) + + offset_val = np.random.random_sample(scale_shape).astype(np.float32) + data_format = "NHWC" + with self.test_session() as sess, self.test_scope(): + # To avoid constant folding + t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x") + scale = array_ops.placeholder(np.float32, shape=[2], name="scale") + offset = array_ops.placeholder(np.float32, shape=[2], name="offset") + epsilon = 0.001 + y_ref, mean_ref, var_ref = self._reference_training( + x_val, scale_val, offset_val, epsilon, data_format) + y, mean, variance = nn.fused_batch_norm( + t_val, + scale, + offset, + mean=mean_ref, + variance=var_ref, + epsilon=epsilon, + data_format=data_format, + is_training=False) + + y_val, _, _ = sess.run( + [y, mean, + variance], {t_val: x_val, + scale: scale_val, + offset: offset_val}) + self.assertAllClose(y_val, y_ref, atol=1e-3) + def _testLearning(self, use_gradient_checker): x_shape = [2, 2, 6, 2] scale_shape = [2] diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index 3f23e459b98..9d2703bf952 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -39,28 +39,36 @@ class FusedBatchNormOp : public XlaOpKernel { errors::InvalidArgument("Not supported format")); feature_index_ = GetTensorFeatureDimIndex(/*num_dims=*/4, tensor_format); } - // TODO(b/62843645): Implement BatchNormInference. 
- OP_REQUIRES( - ctx, is_training_, - errors::InvalidArgument("Fused batch normalization for inference is " - "not supported yet on XLA backend.")); } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle output = ctx->builder()->BatchNormTraining( - ctx->Input(0), ctx->Input(1), ctx->Input(2), epsilon_, feature_index_); + if (is_training_) { + xla::ComputationDataHandle output = ctx->builder()->BatchNormTraining( + ctx->Input(0), ctx->Input(1), ctx->Input(2), epsilon_, + feature_index_); - // In training mode, outputs the normalized value as well as the calculated - // mean and variance. - for (int i = 0; i < 3; i++) { - ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i)); + // In training mode, outputs the normalized value as well as the + // calculated mean and variance. + for (int i = 0; i < 3; i++) { + ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i)); + } + // Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved + // space 1 & 2". They are used to pass the per-batch mean and + // variance to the gradient. Here we maintain the same behavior by setting + // them to the mean and variance calculated by BatchNormTraining. + ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1)); + ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2)); + } else { + xla::ComputationDataHandle output = ctx->builder()->BatchNormInference( + ctx->Input(0), ctx->Input(1), ctx->Input(2), ctx->Input(3), + ctx->Input(4), epsilon_, feature_index_); + ctx->SetOutput(0, output); + // Directly send input to output as mean and variance in inference mode. + ctx->SetOutput(1, ctx->Input(3)); + ctx->SetOutput(2, ctx->Input(4)); + ctx->SetOutput(3, ctx->Input(3)); + ctx->SetOutput(4, ctx->Input(4)); } - // Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved - // space 1 & 2". They are used to pass the per-batch mean and - // variance to the gradient. Here we maintain the same behavior by setting - // them to the mean and variance calculated by BatchNormTraining. - ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1)); - ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2)); } private: diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index e6ffc4f98de..30afaed7323 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -1477,9 +1477,29 @@ ComputationDataHandle ComputationBuilder::BatchNormInference( const ComputationDataHandle& operand, const ComputationDataHandle& scale, const ComputationDataHandle& offset, const ComputationDataHandle& mean, const ComputationDataHandle& variance, float epsilon, int64 feature_index) { - // TODO(b/62843645): Implement BatchNormInference. 
- NoteError(Unimplemented("BatchNormInference is not implemented yet.")); - return ComputationDataHandle(); + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + BatchNormInferenceRequest request; + *request.mutable_operand() = operand; + *request.mutable_scale() = scale; + *request.mutable_offset() = offset; + *request.mutable_mean() = mean; + *request.mutable_variance() = variance; + request.set_epsilon(epsilon); + request.set_feature_index(feature_index); + + OpRequest op_request; + *op_request.mutable_batch_norm_inference_request() = request; + *op_request.mutable_computation() = computation_.handle(); + AddOpMetadata(&op_request); + + OpResponse response; + + VLOG(2) << "making BatchNormInference request"; + + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); } ComputationDataHandle ComputationBuilder::BatchNormGrad( diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc index 721d99301a1..41d32d0c8b1 100644 --- a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc +++ b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc @@ -56,11 +56,14 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault { Status HandleBatchNormTraining(HloInstruction* batch_norm) override; + Status HandleBatchNormInference(HloInstruction* batch_norm) override; + Status HandleBatchNormGrad(HloInstruction* batch_norm) override; // Runs the visitor on a computation. static bool Run(HloComputation* computation, bool rewrite_training_op, - bool rewrite_grad_op, bool use_fusion); + bool rewrite_inference_op, bool rewrite_grad_op, + bool use_fusion); // Returns whether any batch norm ops were rewritten. const bool changed() const { return changed_; } @@ -70,9 +73,11 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault { private: explicit BatchNormRewriterVisitor(HloComputation* computation, bool rewrite_training_op, + bool rewrite_inference_op, bool rewrite_grad_op, bool use_fusion) : computation_(computation), rewrite_training_op_(rewrite_training_op), + rewrite_inference_op_(rewrite_inference_op), rewrite_grad_op_(rewrite_grad_op), use_fusion_(use_fusion) {} @@ -94,6 +99,7 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault { HloComputation* computation_; bool rewrite_training_op_; + bool rewrite_inference_op_; bool rewrite_grad_op_; bool use_fusion_; @@ -126,11 +132,14 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault { bool BatchNormRewriterVisitor::Run(HloComputation* computation, bool rewrite_training_op, + bool rewrite_inference_op, bool rewrite_grad_op, bool use_fusion) { - BatchNormRewriterVisitor visitor(computation, - /*rewrite_training_op=*/rewrite_training_op, - /*rewrite_grad_op=*/rewrite_grad_op, - /*use_fusion=*/use_fusion); + BatchNormRewriterVisitor visitor( + computation, + /*rewrite_training_op=*/rewrite_training_op, + /*rewrite_inference_op=*/rewrite_inference_op, + /*rewrite_grad_op=*/rewrite_grad_op, + /*use_fusion=*/use_fusion); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -268,6 +277,82 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining( return Status::OK(); } +Status BatchNormRewriterVisitor::HandleBatchNormInference( + HloInstruction* batch_norm) { + if (!rewrite_inference_op_) { + return Status::OK(); + } + // Expand batch norm inference into smaller HLO ops. 
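+  // In scalar terms the expansion below computes, for each feature index f
+  // (an illustrative formula, matching the HLO ops constructed below):
+  //   Y[..., f] = (X[..., f] - mean[f]) * (var[f] + epsilon)^(-1/2)
+  //               * scale[f] + offset[f]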
+ HloInstruction* operand = batch_norm->mutable_operand(0); + const Shape operand_shape = operand->shape(); + int64 feature_index = batch_norm->feature_index(); + + HloInstruction* scale = batch_norm->mutable_operand(1); + HloInstruction* offset = batch_norm->mutable_operand(2); + HloInstruction* mean = batch_norm->mutable_operand(3); + HloInstruction* var = batch_norm->mutable_operand(4); + const Shape feature_shape = scale->shape(); + + auto epsilon = computation_->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + + std::vector dimensions_without_feature; + + for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { + if (i != feature_index) { + dimensions_without_feature.push_back(i); + } + } + + auto scale_broadcasted = computation_->AddInstruction( + HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index})); + + auto offset_broadcasted = computation_->AddInstruction( + HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index})); + + auto mean_broadcasted = computation_->AddInstruction( + HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); + + auto var_broadcasted = computation_->AddInstruction( + HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); + + // Var[X] + epsilon. + auto var_add_epsilon = + computation_->AddInstruction(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); + + auto neg_half = computation_->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(-0.5f))); + + // 1 / Sqrt[Var[X] + epsilon]. + auto rsqrt_var_add_epsilon = + computation_->AddInstruction(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); + + // X - E[X]. + auto operand_minus_mean = + computation_->AddInstruction(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); + + // (X - E[X]) / Sqrt[Var[X] + epsilon]. + auto normalized = computation_->AddInstruction( + HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply, + operand_minus_mean, rsqrt_var_add_epsilon)); + + // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale. + auto scaled_normalized = + computation_->AddInstruction(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); + + // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset. + auto shifted_normalized = HloInstruction::CreateBinary( + operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted); + + TF_CHECK_OK( + ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized))); + return Status::OK(); +} + Status BatchNormRewriterVisitor::HandleBatchNormGrad( HloInstruction* batch_norm) { // Use the following formulas to calculate gradients: @@ -457,7 +542,8 @@ StatusOr BatchNormRewriter::Run(HloModule* module) { } for (auto& comp : computations) { if (BatchNormRewriterVisitor::Run(comp, rewrite_training_op_, - rewrite_grad_op_, use_fusion_)) { + rewrite_inference_op_, rewrite_grad_op_, + use_fusion_)) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter.h b/tensorflow/compiler/xla/service/batchnorm_rewriter.h index d3ffb31032e..f601741d964 100644 --- a/tensorflow/compiler/xla/service/batchnorm_rewriter.h +++ b/tensorflow/compiler/xla/service/batchnorm_rewriter.h @@ -30,8 +30,10 @@ class BatchNormRewriter : public HloPassInterface { public: // When use_fusion is set, a multi-output fusion node is created. 
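+  // When rewrite_inference_op is set, batch-norm-inference ops are likewise
+  // expanded into smaller HLO ops (see HandleBatchNormInference).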
BatchNormRewriter(bool rewrite_training_op = false, + bool rewrite_inference_op = false, bool rewrite_grad_op = false, bool use_fusion = true) : rewrite_training_op_(rewrite_training_op), + rewrite_inference_op_(rewrite_inference_op), rewrite_grad_op_(rewrite_grad_op), use_fusion_(use_fusion) {} ~BatchNormRewriter() = default; @@ -43,6 +45,7 @@ class BatchNormRewriter : public HloPassInterface { private: bool rewrite_training_op_; + bool rewrite_inference_op_; bool rewrite_grad_op_; bool use_fusion_; }; diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter_test.cc b/tensorflow/compiler/xla/service/batchnorm_rewriter_test.cc index cc8dffcda51..07775623e75 100644 --- a/tensorflow/compiler/xla/service/batchnorm_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/batchnorm_rewriter_test.cc @@ -64,6 +64,7 @@ TEST_F(BatchNormRewriterTest, BatchNormTraining) { HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormTraining); BatchNormRewriter rewriter(/*rewrite_training_op=*/true, + /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); @@ -105,6 +106,7 @@ TEST_F(BatchNormRewriterTest, BatchNormGrad) { HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormGrad); BatchNormRewriter rewriter(/*rewrite_training_op=*/true, + /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index eca9b0f4bef..8a37c8108ea 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -260,6 +260,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module) { pipeline.AddPass>("simplification"); pass.AddPass( /*rewrite_training_op=*/true, + /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true, /*use_fusion=*/false); pass.AddPass( diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index e450b31ff18..4baa56658f7 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -228,6 +228,9 @@ class DfsHloVisitor { virtual Status HandleBatchNormTraining(HloInstruction* batchNormTraining) = 0; + virtual Status HandleBatchNormInference( + HloInstruction* batchNormInference) = 0; + virtual Status HandleBatchNormGrad(HloInstruction* batchNormGrad) = 0; // Invoked to inform the visitor that the traversal has completed, and that diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index c447165cecc..10f8ae9b044 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -54,6 +54,10 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor { return DefaultAction(hlo); } + Status HandleBatchNormInference(HloInstruction* hlo) override { + return DefaultAction(hlo); + } + Status HandleBatchNormGrad(HloInstruction* hlo) override { return DefaultAction(hlo); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 2a7486af881..cd913a4b5d6 100644 --- 
a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -135,6 +135,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, // instead. pass.AddPass( /*rewrite_training_op=*/true, + /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true, /*use_fusion=*/false); pass.AddPass( diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index d113ca2a76b..9dbde0ec243 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -374,6 +374,12 @@ Status HloCostAnalysis::HandleBatchNormTraining( return Status::OK(); } +Status HloCostAnalysis::HandleBatchNormInference( + HloInstruction* batchNormInference) { + // TODO(b/62294698): Implement cost analysis for batch-norm-inference. + return Status::OK(); +} + Status HloCostAnalysis::HandleBatchNormGrad(HloInstruction* batchNormGrad) { // TODO(b/62294698): Implement cost analysis for batch-norm-grad. return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index ec48c8a0fd8..6d8fdfa64b5 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -89,6 +89,7 @@ class HloCostAnalysis : public DfsHloVisitor { tensorflow::gtl::ArraySlice dimensions, HloComputation* function_handle) override; Status HandleBatchNormTraining(HloInstruction* batchNormTraining) override; + Status HandleBatchNormInference(HloInstruction* batchNormInference) override; Status HandleBatchNormGrad(HloInstruction* batchNormGrad) override; Status HandleFusion(HloInstruction* fusion) override; Status HandleCall(HloInstruction* call) override; diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index d1c31963665..38b1291d440 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -742,6 +742,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kParameter: return kOrange; case HloOpcode::kBatchNormTraining: + case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormGrad: case HloOpcode::kReduce: case HloOpcode::kSelectAndScatter: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 825f3f8f60e..fb9dbd64216 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -406,6 +406,23 @@ HloInstruction::CreateBatchNormTraining(const Shape& shape, return instruction; } +/* static */ std::unique_ptr +HloInstruction::CreateBatchNormInference( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, + float epsilon, int64 feature_index) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kBatchNormInference, shape)); + instruction->AppendOperand(operand); + instruction->AppendOperand(scale); + instruction->AppendOperand(offset); + instruction->AppendOperand(mean); + instruction->AppendOperand(variance); + instruction->epsilon_ = epsilon; + instruction->feature_index_ = feature_index; + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand, HloInstruction* scale, 
HloInstruction* mean, @@ -1065,6 +1082,12 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( return CreateBatchNormTraining(shape, new_operands[0], new_operands[1], new_operands[2], epsilon(), feature_index()); + + case HloOpcode::kBatchNormInference: + CHECK_EQ(new_operands.size(), 5); + return CreateBatchNormInference( + shape, new_operands[0], new_operands[1], new_operands[2], + new_operands[3], new_operands[4], epsilon(), feature_index()); case HloOpcode::kInfeed: CHECK_EQ(new_operands.size(), 0); return CreateInfeed(shape, infeed_config()); @@ -1355,6 +1378,7 @@ bool HloInstruction::IdenticalSlowPath( ShapeUtil::Compatible(shape(), other.shape()); case HloOpcode::kBatchNormTraining: + case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormGrad: return feature_index() == other.feature_index() && epsilon() == other.epsilon(); @@ -1952,6 +1976,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { return visitor->HandleAbs(this, operands_[0]); case HloOpcode::kBatchNormTraining: return visitor->HandleBatchNormTraining(this); + case HloOpcode::kBatchNormInference: + return visitor->HandleBatchNormInference(this); case HloOpcode::kBatchNormGrad: return visitor->HandleBatchNormGrad(this); case HloOpcode::kSign: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index f2005380d8e..d246720b3cf 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -224,6 +224,12 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, HloInstruction* scale, HloInstruction* offset, float epsilon, int64 feature_index); + // Creates a batch-norm-inference instruction. + static std::unique_ptr CreateBatchNormInference( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, + float epsilon, int64 feature_index); + // Creates a batch-norm-grad instruction. static std::unique_ptr CreateBatchNormGrad( const Shape& shape, HloInstruction* operand, HloInstruction* scale, diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index 3888f757ada..314512d0a8d 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -33,6 +33,8 @@ string HloOpcodeString(HloOpcode opcode) { return "add"; case HloOpcode::kBatchNormTraining: return "batch-norm-training"; + case HloOpcode::kBatchNormInference: + return "batch-norm-inference"; case HloOpcode::kBatchNormGrad: return "batch-norm-grad"; case HloOpcode::kBitcast: diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 8a6376b2d1c..c4d5efad903 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -31,6 +31,7 @@ enum class HloOpcode { kAbs, kAdd, kBatchNormTraining, + kBatchNormInference, kBatchNormGrad, kBitcast, kBroadcast, diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 4333db17e75..edfcb0922d6 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -78,6 +78,7 @@ namespace xla { // Expensive instructions. 
case HloOpcode::kBatchNormTraining: + case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormGrad: case HloOpcode::kCall: case HloOpcode::kConvolution: diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index ad2d5235f8d..d63d33ecb00 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -1211,6 +1211,10 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddBatchNormTrainingInstruction( arg->batch_norm_training_request()); break; + case OpRequest::kBatchNormInferenceRequest: + handle_status = computation->AddBatchNormInferenceInstruction( + arg->batch_norm_inference_request()); + break; case OpRequest::kBatchNormGradRequest: handle_status = computation->AddBatchNormGradInstruction( arg->batch_norm_grad_request()); diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 2c2b0cca5fd..8eeb1cd5d20 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -885,6 +885,150 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( output_shape_for_mean_and_var}); } +/* static */ StatusOr ShapeInference::InferBatchNormInferenceShape( + const Shape& operand_shape, const Shape& offset_shape, + const Shape& scale_shape, const Shape& mean_shape, + const Shape& variance_shape, int64 feature_index) { + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm inference")); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( + offset_shape, "offset input of batch norm inference")); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( + scale_shape, "scale input of batch norm inference")); + + TF_RET_CHECK(ShapeUtil::ValidateShape(operand_shape) == + tensorflow::Status::OK()); + TF_RET_CHECK(ShapeUtil::ValidateShape(offset_shape) == + tensorflow::Status::OK()); + TF_RET_CHECK(ShapeUtil::ValidateShape(scale_shape) == + tensorflow::Status::OK()); + TF_RET_CHECK(ShapeUtil::ValidateShape(mean_shape) == + tensorflow::Status::OK()); + TF_RET_CHECK(ShapeUtil::ValidateShape(variance_shape) == + tensorflow::Status::OK()); + + if (feature_index >= ShapeUtil::Rank(operand_shape)) { + return InvalidArgument( + "Expected feature_index of batch-norm-inference to be " + "smaller than the rank of operand_shape; " + "got feature_index %lld, and rank %lld", + feature_index, ShapeUtil::Rank(operand_shape)); + } + + if (feature_index < 0) { + return InvalidArgument( + "Expected feature_index of batch-norm-inference to " + "be a non-negative number, got %lld", + feature_index); + } + + if (ShapeUtil::Rank(operand_shape) < 1) { + return InvalidArgument( + "Expected the rank of operand to " + "batch-norm-inference to be at least 1; got %lld", + ShapeUtil::Rank(operand_shape)); + } + + if (ShapeUtil::Rank(offset_shape) != 1) { + return InvalidArgument( + "Offset input of batch-norm-inference must have" + " rank 1, but has rank %lld.", + ShapeUtil::Rank(offset_shape)); + } + + if (ShapeUtil::Rank(scale_shape) != 1) { + return InvalidArgument( + "Scale input of batch-norm-inference must have" + " rank 1, but has rank %lld.", + ShapeUtil::Rank(scale_shape)); + } + + if (!ShapeUtil::ElementIsFloating(operand_shape)) { + return InvalidArgument( + "The operand to batch-norm-inference must have a floating point " + "element type, but the shape is %s", + PrimitiveType_Name(operand_shape.element_type()).c_str()); + } + + if 
(!ShapeUtil::SameElementType(offset_shape, operand_shape)) {
+    return InvalidArgument(
+        "The inputs should have the same element type for "
+        "batch-norm-inference, "
+        "but the shape of offset factor is %s "
+        "and the shape of operand is %s",
+        PrimitiveType_Name(offset_shape.element_type()).c_str(),
+        PrimitiveType_Name(operand_shape.element_type()).c_str());
+  }
+
+  if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) {
+    return InvalidArgument(
+        "The inputs should have the same element type for "
+        "batch-norm-inference, "
+        "but the shape of scale factor is %s "
+        "and the shape of operand is %s",
+        PrimitiveType_Name(scale_shape.element_type()).c_str(),
+        PrimitiveType_Name(operand_shape.element_type()).c_str());
+  }
+
+  if (!ShapeUtil::SameElementType(mean_shape, operand_shape)) {
+    return InvalidArgument(
+        "The inputs should have the same element type for "
+        "batch-norm-inference, "
+        "but the shape of mean is %s "
+        "and the shape of operand is %s",
+        PrimitiveType_Name(mean_shape.element_type()).c_str(),
+        PrimitiveType_Name(operand_shape.element_type()).c_str());
+  }
+
+  if (!ShapeUtil::SameElementType(variance_shape, operand_shape)) {
+    return InvalidArgument(
+        "The inputs should have the same element type for "
+        "batch-norm-inference, "
+        "but the shape of variance is %s "
+        "and the shape of operand is %s",
+        PrimitiveType_Name(variance_shape.element_type()).c_str(),
+        PrimitiveType_Name(operand_shape.element_type()).c_str());
+  }
+
+  const int64 feature_count = operand_shape.dimensions(feature_index);
+  Shape output_shape_for_mean_and_var =
+      ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
+
+  if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
+    return InvalidArgument(
+        "The size of offset factor should be the same as feature count, "
+        "but the size of offset factor is %lld "
+        "and the feature count is %lld",
+        ShapeUtil::GetDimension(offset_shape, 0), feature_count);
+  }
+
+  if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
+    return InvalidArgument(
+        "The size of scale factor should be the same as feature count, "
+        "but the size of scale factor is %lld "
+        "and the feature count is %lld",
+        ShapeUtil::GetDimension(scale_shape, 0), feature_count);
+  }
+
+  if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
+    return InvalidArgument(
+        "The size of mean should be the same as feature count, "
+        "but the size of mean is %lld "
+        "and the feature count is %lld",
+        ShapeUtil::GetDimension(mean_shape, 0), feature_count);
+  }
+
+  if (ShapeUtil::GetDimension(variance_shape, 0) != feature_count) {
+    return InvalidArgument(
+        "The size of variance should be the same as feature count, "
+        "but the size of variance is %lld "
+        "and the feature count is %lld",
+        ShapeUtil::GetDimension(variance_shape, 0), feature_count);
+  }
+
+  return operand_shape;
+}
+
 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormGradShape(
     const Shape& operand_shape, const Shape& scale_shape,
     const Shape& mean_shape, const Shape& var_shape,

diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index f3f0176a434..5d55df92a91 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -71,6 +71,13 @@ class ShapeInference {
                                                    const Shape& scale_shape,
                                                    int64 feature_index);

+  // Infers the shape produced by InferBatchNormInference with the given
+  // operands.
+ static StatusOr InferBatchNormInferenceShape( + const Shape& operand_shape, const Shape& offset_shape, + const Shape& scale_shape, const Shape& mean_shape, + const Shape& variance_shape, int64 feature_index); + // Infers the shape produced by InferBatchNormGrad with the given operands. static StatusOr InferBatchNormGradShape(const Shape& operand_shape, const Shape& scale_shape, diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 3b280c97278..cfa5c98f593 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -507,6 +507,53 @@ UserComputation::AddBatchNormTrainingInstruction( return handle; } +StatusOr +UserComputation::AddBatchNormInferenceInstruction( + const BatchNormInferenceRequest& batch_norm_inference_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookUpRequest(batch_norm_inference_request.operand())); + + TF_ASSIGN_OR_RETURN(const OperationRequest* scale, + LookUpRequest(batch_norm_inference_request.scale())); + + TF_ASSIGN_OR_RETURN(const OperationRequest* offset, + LookUpRequest(batch_norm_inference_request.offset())); + + TF_ASSIGN_OR_RETURN(const OperationRequest* mean, + LookUpRequest(batch_norm_inference_request.mean())); + + TF_ASSIGN_OR_RETURN(const OperationRequest* variance, + LookUpRequest(batch_norm_inference_request.variance())); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + + TF_ASSIGN_OR_RETURN(Shape inferred_shape, + ShapeInference::InferBatchNormInferenceShape( + operand->output_shape(), scale->output_shape(), + offset->output_shape(), mean->output_shape(), + variance->output_shape(), + batch_norm_inference_request.feature_index())); + + *request.mutable_output_shape() = inferred_shape; + + *request.mutable_output_handle() = handle; + + *request.mutable_request()->mutable_batch_norm_inference_request() = + batch_norm_inference_request; + + VLOG(1) << "AddBatchNormInferenceInstruction (" + << GetVersionedHandleInternal() << "), data handle " + << handle.handle() << ": " + << batch_norm_inference_request.ShortDebugString(); + + return handle; +} + StatusOr UserComputation::AddBatchNormGradInstruction( const BatchNormGradRequest& batch_norm_grad_request) { tensorflow::mutex_lock lock(mutex_); @@ -1678,6 +1725,25 @@ void ConstantVisitor(const SessionComputation& session_computation, break; } + case OpRequest::kBatchNormInferenceRequest: { + const BatchNormInferenceRequest& batch_norm_inference_request = + request.request().batch_norm_inference_request(); + ConstantVisitor(session_computation, + batch_norm_inference_request.operand(), visited, + is_constant); + ConstantVisitor(session_computation, batch_norm_inference_request.scale(), + visited, is_constant); + ConstantVisitor(session_computation, + batch_norm_inference_request.offset(), visited, + is_constant); + ConstantVisitor(session_computation, batch_norm_inference_request.mean(), + visited, is_constant); + ConstantVisitor(session_computation, + batch_norm_inference_request.variance(), visited, + is_constant); + break; + } + case OpRequest::kBatchNormGradRequest: { const BatchNormGradRequest& batch_norm_grad_request = request.request().batch_norm_grad_request(); @@ -2119,6 +2185,18 @@ static void ForEachOperand( break; } + case OpRequest::kBatchNormInferenceRequest: { + const 
BatchNormInferenceRequest& batch_norm_inference_request = + request.request().batch_norm_inference_request(); + + apply(batch_norm_inference_request.operand()); + apply(batch_norm_inference_request.scale()); + apply(batch_norm_inference_request.offset()); + apply(batch_norm_inference_request.mean()); + apply(batch_norm_inference_request.variance()); + break; + } + case OpRequest::kBatchNormGradRequest: { const BatchNormGradRequest& batch_norm_grad_request = request.request().batch_norm_grad_request(); @@ -2647,6 +2725,28 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kBatchNormInferenceRequest: { + const BatchNormInferenceRequest& batch_norm_inference_request = + request.request().batch_norm_inference_request(); + HloInstruction* operand = + lookup_instruction(batch_norm_inference_request.operand()); + HloInstruction* scale = + lookup_instruction(batch_norm_inference_request.scale()); + HloInstruction* offset = + lookup_instruction(batch_norm_inference_request.offset()); + HloInstruction* mean = + lookup_instruction(batch_norm_inference_request.mean()); + HloInstruction* variance = + lookup_instruction(batch_norm_inference_request.variance()); + + hlo_instruction = + add_instruction(HloInstruction::CreateBatchNormInference( + request.output_shape(), operand, scale, offset, mean, variance, + batch_norm_inference_request.epsilon(), + batch_norm_inference_request.feature_index())); + break; + } + case OpRequest::kBatchNormGradRequest: { const BatchNormGradRequest& batch_norm_grad_request = request.request().batch_norm_grad_request(); diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h index 36b1d34e05d..b779b1f76c8 100644 --- a/tensorflow/compiler/xla/service/user_computation.h +++ b/tensorflow/compiler/xla/service/user_computation.h @@ -89,6 +89,10 @@ class UserComputation { StatusOr AddBatchNormTrainingInstruction( const BatchNormTrainingRequest& batch_norm_training_request); + // Enqueues a batch norm inference instruction onto this user computation. + StatusOr AddBatchNormInferenceInstruction( + const BatchNormInferenceRequest& batch_norm_inference_request); + // Enqueues a batch norm grad instruction onto this user computation. 
StatusOr<ComputationDataHandle> AddBatchNormGradInstruction(
      const BatchNormGradRequest& batch_norm_grad_request);

diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
index 34b3abb8c75..028d1251b45 100644
--- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc
+++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
@@ -306,6 +306,109 @@ XLA_TEST_P(BatchNormTest, RandomizedTests) {
       ErrorSpec(0.01, 1));
 }

+XLA_TEST_P(BatchNormTest, RandomizedInferencingTests) {
+  float epsilon = 0.001;
+  ComputationBuilder builder(client_, TestName());
+  const std::vector<int64>& bounds = GetParam().bounds;
+  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
+  input_array.FillRandom(GetParam().random_value_var,
+                         GetParam().random_value_mean);
+
+  const int64 feature_index = GetParam().feature_index;
+  const int64 num_elements_per_feature =
+      Product(bounds) / bounds[feature_index];
+  const int64 feature_bound = bounds[feature_index];
+  std::vector<float> offset(feature_bound, 1);
+  std::vector<float> scale(feature_bound, 2);
+
+  auto input_squared =
+      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
+  std::vector<int64> reduce_dims;
+  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
+    if (i != feature_index) {
+      reduce_dims.push_back(i);
+    }
+  }
+
+  auto sum =
+      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
+                                  [](float a, float b) { return a + b; });
+
+  auto sum_squared =
+      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
+                                  [](float a, float b) { return a + b; });
+
+  std::vector<float> mean(feature_bound);
+
+  for (int64 i = 0; i < feature_bound; ++i) {
+    mean[i] = sum[i] / num_elements_per_feature;
+  }
+
+  std::vector<float> mean_square(feature_bound);
+  for (int64 i = 0; i < feature_bound; ++i) {
+    mean_square[i] = mean[i] * mean[i];
+  }
+
+  std::vector<float> square_mean(feature_bound);
+  for (int64 i = 0; i < feature_bound; ++i) {
+    square_mean[i] = sum_squared[i] / num_elements_per_feature;
+  }
+
+  std::vector<float> var(feature_bound);
+  for (int64 i = 0; i < feature_bound; ++i) {
+    var[i] = square_mean[i] - mean_square[i];
+  }
+
+  Array4D<float> mean4D =
+      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
+  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
+  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
+  auto offset4D =
+      *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);
+
+  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
+                                                scale4D, offset4D, epsilon);
+
+  auto offset_literal = Literal::CreateR1<float>(offset);
+  auto scale_literal = Literal::CreateR1<float>(scale);
+  auto mean_literal = Literal::CreateR1<float>(mean);
+  auto var_literal = Literal::CreateR1<float>(var);
+  auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
+
+  auto input_activations =
+      builder.Parameter(0, input_literal->shape(), "input");
+  auto scale_activations =
+      builder.Parameter(1, scale_literal->shape(), "scale");
+  auto offset_activations =
+      builder.Parameter(2, offset_literal->shape(), "offset");
+  auto mean_activations = builder.Parameter(3, mean_literal->shape(), "mean");
+  auto variance_activations =
+      builder.Parameter(4, var_literal->shape(), "variance");
+
+  Array4D<float> expected = normalized;
+
+  std::unique_ptr<GlobalData> input_data =
+      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+  std::unique_ptr<GlobalData> scale_data =
+      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
+  std::unique_ptr<GlobalData> offset_data =
client_->TransferToServer(*offset_literal).ConsumeValueOrDie(); + std::unique_ptr mean_data = + client_->TransferToServer(*mean_literal).ConsumeValueOrDie(); + std::unique_ptr variance_data = + client_->TransferToServer(*var_literal).ConsumeValueOrDie(); + + builder.BatchNormInference(input_activations, scale_activations, + offset_activations, mean_activations, + variance_activations, epsilon, feature_index); + + ComputeAndCompareR4( + &builder, expected, + {input_data.get(), scale_data.get(), offset_data.get(), mean_data.get(), + variance_data.get()}, + ErrorSpec(0.01, 1)); +} + XLA_TEST_P(BatchNormTest, RandomizedGradTests) { float epsilon = 0.001; ComputationBuilder builder(client_, TestName()); diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 38e6675ab7e..185ca7e681c 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -491,6 +491,16 @@ message BatchNormTrainingRequest { int64 feature_index = 5; } +message BatchNormInferenceRequest { + ComputationDataHandle operand = 1; + ComputationDataHandle scale = 2; + ComputationDataHandle offset = 3; + ComputationDataHandle mean = 4; + ComputationDataHandle variance = 5; + float epsilon = 6; + int64 feature_index = 7; +} + message BatchNormGradRequest { ComputationDataHandle operand = 1; ComputationDataHandle scale = 2; @@ -813,7 +823,8 @@ message OpRequest { OutfeedRequest outfeed_request = 32; BatchNormTrainingRequest batch_norm_training_request = 35; BatchNormGradRequest batch_norm_grad_request = 37; - // Next: 38 + BatchNormInferenceRequest batch_norm_inference_request = 38; + // Next: 39 } } From 00594ecdd685a2b1eaebb3bcc6b9764bfd4ae5d6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 17 Aug 2017 19:27:58 -0700 Subject: [PATCH 13/70] New landing page and leftnav for Programmer's Guide. PiperOrigin-RevId: 165660897 --- .../docs_src/programmers_guide/dims_types.md | 69 ----------------- tensorflow/docs_src/programmers_guide/faq.md | 51 +++++-------- .../docs_src/programmers_guide/index.md | 75 ++++++++++--------- .../docs_src/programmers_guide/leftnav_files | 6 +- 4 files changed, 63 insertions(+), 138 deletions(-) delete mode 100644 tensorflow/docs_src/programmers_guide/dims_types.md diff --git a/tensorflow/docs_src/programmers_guide/dims_types.md b/tensorflow/docs_src/programmers_guide/dims_types.md deleted file mode 100644 index 65b748d56ec..00000000000 --- a/tensorflow/docs_src/programmers_guide/dims_types.md +++ /dev/null @@ -1,69 +0,0 @@ -# Tensor Ranks, Shapes, and Types - -TensorFlow programs use a tensor data structure to represent all data. You can -think of a TensorFlow tensor as an n-dimensional array or list. -A tensor has a static type and dynamic dimensions. Only tensors may be passed -between nodes in the computation graph. - -## Rank - -In the TensorFlow system, tensors are described by a unit of dimensionality -known as *rank*. Tensor rank is not the same as matrix rank. Tensor rank -(sometimes referred to as *order* or *degree* or *n-dimension*) is the number -of dimensions of the tensor. For example, the following tensor (defined as a -Python list) has a rank of 2: - - t = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] - -A rank two tensor is what we typically think of as a matrix, a rank one tensor -is a vector. For a rank two tensor you can access any element with the syntax -`t[i, j]`. For a rank three tensor you would need to address an element with -`t[i, j, k]`. 
- -Rank | Math entity | Python example ---- | --- | --- -0 | Scalar (magnitude only) | `s = 483` -1 | Vector (magnitude and direction) | `v = [1.1, 2.2, 3.3]` -2 | Matrix (table of numbers) | `m = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]` -3 | 3-Tensor (cube of numbers) | `t = [[[2], [4], [6]], [[8], [10], [12]], [[14], [16], [18]]]` -n | n-Tensor (you get the idea) | `....` - -## Shape - -The TensorFlow documentation uses three notational conventions to describe -tensor dimensionality: rank, shape, and dimension number. The following table -shows how these relate to one another: - -Rank | Shape | Dimension number | Example ---- | --- | --- | --- -0 | [] | 0-D | A 0-D tensor. A scalar. -1 | [D0] | 1-D | A 1-D tensor with shape [5]. -2 | [D0, D1] | 2-D | A 2-D tensor with shape [3, 4]. -3 | [D0, D1, D2] | 3-D | A 3-D tensor with shape [1, 4, 3]. -n | [D0, D1, ... Dn-1] | n-D | A tensor with shape [D0, D1, ... Dn-1]. - -Shapes can be represented via Python lists / tuples of ints, or with the -@{tf.TensorShape}. - -## Data types - -In addition to dimensionality, Tensors have a data type. You can assign any one -of the following data types to a tensor: - -Data type | Python type | Description ---- | --- | --- -`DT_FLOAT` | `tf.float32` | 32 bits floating point. -`DT_DOUBLE` | `tf.float64` | 64 bits floating point. -`DT_INT8` | `tf.int8` | 8 bits signed integer. -`DT_INT16` | `tf.int16` | 16 bits signed integer. -`DT_INT32` | `tf.int32` | 32 bits signed integer. -`DT_INT64` | `tf.int64` | 64 bits signed integer. -`DT_UINT8` | `tf.uint8` | 8 bits unsigned integer. -`DT_UINT16` | `tf.uint16` | 16 bits unsigned integer. -`DT_STRING` | `tf.string` | Variable length byte arrays. Each element of a Tensor is a byte array. -`DT_BOOL` | `tf.bool` | Boolean. -`DT_COMPLEX64` | `tf.complex64` | Complex number made of two 32 bits floating points: real and imaginary parts. -`DT_COMPLEX128` | `tf.complex128` | Complex number made of two 64 bits floating points: real and imaginary parts. -`DT_QINT8` | `tf.qint8` | 8 bits signed integer used in quantized Ops. -`DT_QINT32` | `tf.qint32` | 32 bits signed integer used in quantized Ops. -`DT_QUINT8` | `tf.quint8` | 8 bits unsigned integer used in quantized Ops. diff --git a/tensorflow/docs_src/programmers_guide/faq.md b/tensorflow/docs_src/programmers_guide/faq.md index 56486a48b7a..865016dc02d 100644 --- a/tensorflow/docs_src/programmers_guide/faq.md +++ b/tensorflow/docs_src/programmers_guide/faq.md @@ -53,10 +53,6 @@ TensorFlow assigns operations to devices, and the @{$deep_cnn$CIFAR-10 tutorial} for an example model that uses multiple GPUs. -#### What are the different types of tensors that are available? - -TensorFlow supports a variety of different data types and tensor shapes. See the -@{$dims_types$ranks, shapes, and types reference} for more details. ## Running a TensorFlow computation @@ -171,7 +167,8 @@ available. These operations allow you to build sophisticated @{$reading_data$input pipelines}, at the cost of making the TensorFlow computation somewhat more complicated. See the how-to documentation for -@{$reading_data#creating-threads-to-prefetch-using-queuerunner-objects$using `QueueRunner` objects to drive queues and readers} +@{$reading_data#creating-threads-to-prefetch-using-queuerunner-objects$using +`QueueRunner` objects to drive queues and readers} for more information on how to use them. ## Variables @@ -240,11 +237,6 @@ to encode the batch size as a Python constant, but instead to use a symbolic * Use @{tf.reduce_mean} instead of `tf.reduce_sum(...) 
/ batch_size`.
-* If you use
-  @{$reading_data#feeding$placeholders for feeding input},
-  you can specify a variable batch dimension by creating the placeholder with
-  [`tf.placeholder(..., shape=[None, ...])`](../api_docs/python/io_ops.md#placeholder). The
-  `None` element of the shape corresponds to a variable-sized dimension.

 ## TensorBoard

@@ -269,36 +261,33 @@ the flag --host=localhost. This should quiet any security warnings.

 ## Extending TensorFlow

-See also the how-to documentation for
+See the how-to documentation for
 @{$adding_an_op$adding a new operation to TensorFlow}.

 #### My data is in a custom format. How do I read it using TensorFlow?

-There are two main options for dealing with data in a custom format.
+There are three main options for dealing with data in a custom format.

-The easier option is to write parsing code in Python that transforms the data
-into a numpy array, then feed a
-@{tf.placeholder} a tensor with
-that data. See the documentation on
-@{$reading_data#feeding$using placeholders for input} for
-more details. This approach is easy to get up and running, but the parsing can
-be a performance bottleneck.
+The easiest option is to write parsing code in Python that transforms the data
+into a numpy array. Then use @{tf.contrib.data.Dataset.from_tensor_slices} to
+create an input pipeline from the in-memory data.
+
+If your data doesn't fit in memory, try doing the parsing in the Dataset
+pipeline. Start with an appropriate file reader, like
+@{tf.contrib.data.TextLineDataset}. Then convert the dataset by
+@{tf.contrib.data.Dataset.map$mapping} appropriate operations over it.
+Prefer predefined TensorFlow operations such as @{tf.decode_raw},
+@{tf.decode_csv}, @{tf.parse_example}, or @{tf.image.decode_png}.
+
+If your data is not easily parsable with the built-in TensorFlow operations,
+consider converting it, offline, to a format that is easily parsable, such
+as the @{tf.python_io.TFRecordWriter$`TFRecord`} format.
+
+The most efficient way to customize the parsing behavior is to
 @{$adding_an_op$add a new op written in C++} that parses your
-data format. The
-@{$new_data_formats$guide to handling new data formats} has
+data format. The @{$new_data_formats$guide to handling new data formats} has
 more information about the steps for doing this.

-#### How do I define an operation that takes a variable number of inputs?
-
-The TensorFlow op registration mechanism allows you to define inputs that are a
-single tensor, a list of tensors with the same type (for example when adding
-together a variable-length list of tensors), or a list of tensors with different
-types (for example when enqueuing a tuple of tensors to a queue). See the
-how-to documentation for
-@{$adding_an_op#list-inputs-and-outputs$adding an op with a list of inputs or outputs}
-for more details of how to define these different input types.

 ## Miscellaneous

diff --git a/tensorflow/docs_src/programmers_guide/index.md b/tensorflow/docs_src/programmers_guide/index.md
index aa2e12504dd..214f3028e07 100644
--- a/tensorflow/docs_src/programmers_guide/index.md
+++ b/tensorflow/docs_src/programmers_guide/index.md
@@ -1,38 +1,45 @@
 # Programmer's Guide

 The documents in this unit dive into the details of writing TensorFlow
-code. This section begins with the following guides, each of which
-explains a particular aspect of TensorFlow:
+code. For TensorFlow 1.3, we revised this document extensively.
+The units are now as follows: - * @{$variables$Variables: Creation, Initialization, Saving, Loading, and - Sharing}, which details the mechanics of TensorFlow Variables. - * @{$dims_types$Tensor Ranks, Shapes, and Types}, which explains Tensor - rank (the number of dimensions), shape (the size of each dimension), - and datatypes. - * @{$threading_and_queues$Threading and Queues}, which explains TensorFlow's - rich queuing system. - * @{$reading_data$Reading Data}, which documents three different mechanisms - for getting data into a TensorFlow program. - -The following guide is helpful when training a complex model over multiple -days: - - * @{$supervisor$Supervisor: Training Helper for Days-Long Trainings}, which - explains how to gracefully handle system crashes during a lengthy training - session. - -TensorFlow provides a debugger named `tfdbg`, which is documented in the -following guide: - - * @{$debugger$Debugging TensorFlow Programs}, - which walks you through the use of `tfdbg` within an application. It covers - using `tfdbg` with both the low-level TensorFlow API and the Estimator API. - -To learn about the TensorFlow versioning scheme consult: - - * @{$version_compat$The TensorFlow Version Compatibility Guide}, which explains -TensorFlow's versioning nomenclature and compatibility rules. - -We conclude this section with a FAQ about TensorFlow programming: - - * @{$faq$Frequently Asked Questions} + * @{$programmers_guide/tensors$Tensors}, which explains how to create, + manipulate, and access Tensors--the fundamental object in TensorFlow. + * @{$programmers_guide/variables$Variables}, which details how + to represent shared, persistent state in your program. + * @{$programmers_guide/graphs$Graphs and Sessions}, which explains: + * dataflow graphs, which are TensorFlow's representation of computations + as dependencies between operations. + * sessions, which are TensorFlow's mechanism for running dataflow graphs + across one or more local or remote devices. + If you are programming with the low-level TensorFlow API, this unit + is essential. If you are programming with a high-level TensorFlow API + such as Estimators or Keras, the high-level API creates and manages + graphs and sessions for you, but understanding graphs and sessions + can still be helpful. + * @{$programmers_guide/estimators$Estimators}, which introduces a high-level + TensorFlow API that greatly simplifies ML programming. + * @{$programmers_guide/saved_model$Saving and Restoring}, which + explains how to save and restore variables and models. + * @{$programmers_guide/datasets$Input Pipelines}, which explains how to + set up data pipelines to read data sets into your TensorFlow program. + * @{$programmers_guide/threading_and_queues$Threading and Queues}, which + explains TensorFlow's older system for multi-threaded, queue-based input + pipelines. Beginning with TensorFlow 1.2, we recommend using the + `tf.contrib.data` module instead, which is documented in the + "Input Pipelines" unit. + * @{$programmers_guide/embedding$Embeddings}, which introduces the concept + of embeddings, provides a simple example of training an embedding in + TensorFlow, and explains how to view embeddings with the TensorBoard + Embedding Projector. + * @{$programmers_guide/debugger$Debugging TensorFlow Programs}, which + explains how to use the TensorFlow debugger (tfdbg). 
+ * @{$programmers_guide/supervisor$Supervisor: Training Helper for Days-Long Trainings}, + which explains how to gracefully handle system crashes during lengthy + training sessions. (We have not revised this document for v1.3.) + * @{$programmers_guide/version_compat$TensorFlow Version Compatibility}, + which explains backward compatibility guarantees and non-guarantees. + * @{$programmers_guide/faq$FAQ}, which contains frequently asked + questions about TensorFlow. (We have not revised this document for v1.3, + except to remove some obsolete information.) diff --git a/tensorflow/docs_src/programmers_guide/leftnav_files b/tensorflow/docs_src/programmers_guide/leftnav_files index 2a58c4647d1..5082e7f36c8 100644 --- a/tensorflow/docs_src/programmers_guide/leftnav_files +++ b/tensorflow/docs_src/programmers_guide/leftnav_files @@ -1,15 +1,13 @@ index.md tensors.md variables.md -dims_types.md graphs.md +estimators.md +saved_model.md datasets.md threading_and_queues.md -reading_data.md embedding.md debugger.md supervisor.md -saved_model.md -meta_graph.md version_compat.md faq.md From 711be6adcffde0688e3bf04b791b517a28fc5045 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Thu, 17 Aug 2017 20:21:45 -0700 Subject: [PATCH 14/70] `Dataset.from_generator()` constructs a dataset from a Python generator. With this change, it becomes possible to use a Python generator as the source dataset for a `tf.contrib.data` input pipeline. This enables easier integration with non-TensorFlow data sources. The generator can yield a nested structure of NumPy arrays, or values convertible to NumPy arrays. This addresses a concern raised in issue #7951. PiperOrigin-RevId: 165663857 --- .../dataset_constructor_op_test.py | 210 ++++++++++++++++++ tensorflow/contrib/data/python/ops/BUILD | 1 + .../contrib/data/python/ops/dataset_ops.py | 170 ++++++++++++++ tensorflow/core/kernels/map_dataset_op.cc | 12 +- 4 files changed, 391 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py index 6a7bc99fa88..1de2f8e4da5 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import threading + import numpy as np from tensorflow.contrib.data.python.ops import dataset_ops @@ -255,6 +257,214 @@ class DatasetConstructorTest(test.TestCase): self.assertEquals(dtypes.int64, get_next.dtype) self.assertEquals([3], get_next.shape) + def _testFromGenerator(self, generator, elem_sequence, num_repeats): + iterator = ( + dataset_ops.Dataset.from_generator(generator, output_types=dtypes.int64) + .repeat(num_repeats) + .prefetch(5) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + for _ in range(2): # Run twice to test reinitialization. 
+ sess.run(init_op) + for _ in range(num_repeats): + for elem in elem_sequence: + self.assertAllEqual(elem, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def _testFromGeneratorOneShot(self, generator, elem_sequence, num_repeats): + iterator = ( + dataset_ops.Dataset.from_generator(generator, output_types=dtypes.int64) + .repeat(num_repeats) + .prefetch(5) + .make_one_shot_iterator()) + get_next = iterator.get_next() + + with self.test_session() as sess: + for _ in range(num_repeats): + for elem in elem_sequence: + self.assertAllEqual(elem, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testFromGeneratorUsingFunction(self): + def generator(): + for i in range(1, 100): + yield [i] * i + elem_sequence = list(generator()) + self._testFromGenerator(generator, elem_sequence, 1) + self._testFromGenerator(generator, elem_sequence, 5) + self._testFromGeneratorOneShot(generator, elem_sequence, 1) + self._testFromGeneratorOneShot(generator, elem_sequence, 5) + + def testFromGeneratorUsingList(self): + generator = lambda: [[i] * i for i in range(1, 100)] + elem_sequence = list(generator()) + self._testFromGenerator(generator, elem_sequence, 1) + self._testFromGenerator(generator, elem_sequence, 5) + + def testFromGeneratorUsingNdarray(self): + generator = lambda: np.arange(100, dtype=np.int64) + elem_sequence = list(generator()) + self._testFromGenerator(generator, elem_sequence, 1) + self._testFromGenerator(generator, elem_sequence, 5) + + def testFromGeneratorUsingGeneratorExpression(self): + # NOTE(mrry): Generator *expressions* are not repeatable (or in + # general reusable), because they eagerly evaluate the `for` + # expression as `iter(range(1, 100))` and discard the means of + # reconstructing `range(1, 100)`. Wrapping the generator + # expression in a `lambda` makes it repeatable. + generator = lambda: ([i] * i for i in range(1, 100)) + elem_sequence = list(generator()) + self._testFromGenerator(generator, elem_sequence, 1) + self._testFromGenerator(generator, elem_sequence, 5) + + def testFromMultipleConcurrentGenerators(self): + num_inner_repeats = 5 + num_outer_repeats = 100 + + def generator(): + for i in range(1, 10): + yield ([i] * i, [i, i ** 2, i ** 3]) + input_list = list(generator()) + + # The interleave transformation is essentially a flat map that + # draws from multiple input datasets concurrently (in a cyclic + # fashion). By placing `Dataset.from_generator()` inside an + # interleave, we test its behavior when multiple iterators are + # active at the same time; by additionally prefetching inside the + # interleave, we create the possibility of parallel (modulo GIL) + # invocations to several iterators created by the same dataset.
+ def interleave_fn(_): + return (dataset_ops.Dataset.from_generator( + generator, output_types=(dtypes.int64, dtypes.int64), + output_shapes=([None], [3])) + .repeat(num_inner_repeats).prefetch(5)) + + iterator = ( + dataset_ops.Dataset.range(num_outer_repeats) + .interleave(interleave_fn, cycle_length=10, + block_length=len(input_list)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for _ in range(num_inner_repeats * num_outer_repeats): + for elem in input_list: + val0, val1 = sess.run(get_next) + self.assertAllEqual(elem[0], val0) + self.assertAllEqual(elem[1], val1) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testFromGeneratorsRunningInParallel(self): + num_parallel_iterators = 3 + + # Define shared state that multiple iterator instances will access to + # demonstrate their concurrent activity. + lock = threading.Lock() + condition = threading.Condition(lock) + next_ticket = [0] # GUARDED_BY(lock) + + def generator(): + # NOTE(mrry): We yield one element before the barrier, because + # the current implementation of `Dataset.interleave()` must + # fetch one element from each incoming dataset to start the + # prefetching. + yield 0 + + # Define a barrier that `num_parallel_iterators` iterators must enter + # before any can proceed. Demonstrates that multiple iterators may be + # active at the same time. + condition.acquire() + ticket = next_ticket[0] + next_ticket[0] += 1 + if ticket == num_parallel_iterators - 1: + # The last iterator to join the barrier notifies the others. + condition.notify_all() + else: + # Wait until the last iterator enters the barrier. + while next_ticket[0] < num_parallel_iterators: + condition.wait() + condition.release() + + yield 1 + + # As in `testFromMultipleConcurrentGenerators()`, we use a combination of + # `Dataset.interleave()` and `Dataset.prefetch()` to cause multiple + # iterators to be active concurrently. 
+ def interleave_fn(_): + return dataset_ops.Dataset.from_generator( + generator, output_types=dtypes.int64, output_shapes=[]).prefetch(2) + + iterator = ( + dataset_ops.Dataset.range(num_parallel_iterators) + .interleave( + interleave_fn, cycle_length=num_parallel_iterators, block_length=1) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for elem in [0, 1]: + for _ in range(num_parallel_iterators): + self.assertAllEqual(elem, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testFromGeneratorTypeError(self): + def generator(): + yield np.array([1, 2, 3], dtype=np.int64) + yield np.array([4, 5, 6], dtype=np.int64) + yield "ERROR" + yield np.array([7, 8, 9], dtype=np.int64) + + iterator = (dataset_ops.Dataset.from_generator( + generator, output_types=dtypes.int64, output_shapes=[3]) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + self.assertAllEqual([1, 2, 3], sess.run(get_next)) + self.assertAllEqual([4, 5, 6], sess.run(get_next)) + with self.assertRaisesOpError(r"element of type .*int64.* was expected"): + sess.run(get_next) + self.assertAllEqual([7, 8, 9], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testFromGeneratorShapeError(self): + def generator(): + yield np.array([1, 2, 3], dtype=np.int64) + yield np.array([4, 5, 6], dtype=np.int64) + yield np.array([7, 8, 9, 10], dtype=np.int64) + yield np.array([11, 12, 13], dtype=np.int64) + + iterator = (dataset_ops.Dataset.from_generator( + generator, output_types=dtypes.int64, output_shapes=[3]) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + self.assertAllEqual([1, 2, 3], sess.run(get_next)) + self.assertAllEqual([4, 5, 6], sess.run(get_next)) + with self.assertRaisesOpError(r"element of shape \(3,\) was expected"): + sess.run(get_next) + self.assertAllEqual([11, 12, 13], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index f49350505ae..8afd122d82d 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -24,6 +24,7 @@ py_library( "//tensorflow/python:random_ops", "//tensorflow/python:random_seed", "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:script_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_util", diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py index 6ef960037f0..ed3359730c4 100644 --- a/tensorflow/contrib/data/python/ops/dataset_ops.py +++ b/tensorflow/contrib/data/python/ops/dataset_ops.py @@ -18,6 +18,8 @@ from __future__ import division from __future__ import print_function import abc +import collections +import threading import warnings import numpy as np @@ -40,6 +42,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import script_ops from 
tensorflow.python.platform import gfile @@ -559,6 +562,168 @@ class Dataset(object): """ return SparseTensorSliceDataset(sparse_tensor) + class _GeneratorState(object): + """Stores outstanding iterators created from a Python generator. + + This class keeps track of potentially multiple iterators that may have + been created from a generator, e.g. in the case that the dataset is + repeated, or nested within a parallel computation. + """ + + def __init__(self, generator): + self._generator = generator + self._lock = threading.Lock() + self._next_id = 0 # GUARDED_BY(self._lock) + self._iterators = collections.defaultdict(lambda: iter(generator())) + + def get_next_id(self): + with self._lock: + ret = self._next_id + self._next_id += 1 + return ret + + def get_iterator(self, iterator_id): + return self._iterators[iterator_id] + + def iterator_completed(self, iterator_id): + del self._iterators[iterator_id] + + @staticmethod + def from_generator(generator, output_types, output_shapes=None): + """Creates a `Dataset` whose elements are generated by `generator`. + + The `generator` argument must be a callable object that returns + an object that supports the `iter()` protocol (e.g. a generator function). + The elements generated by `generator` must be compatible with the given + `output_types` and (optional) `output_shapes` arguments. + + Args: + generator: A callable object that takes no arguments and returns an + object that supports the `iter()` protocol. + output_types: A nested structure of `tf.DType` objects corresponding to + each component of an element yielded by `generator`. + output_shapes: (Optional.) A nested structure of `tf.TensorShape` + objects corresponding to each component of an element yielded by + `generator`. + + Returns: + A `Dataset`. + """ + if not callable(generator): + raise TypeError("`generator` must be callable.") + if output_shapes is None: + output_shapes = nest.map_structure( + lambda _: tensor_shape.TensorShape(None), output_types) + else: + output_shapes = nest.map_structure_up_to( + output_types, tensor_shape.as_shape, output_shapes) + + flattened_types = nest.flatten(output_types) + flattened_shapes = nest.flatten(output_shapes) + + generator_state = Dataset._GeneratorState(generator) + + def get_iterator_id_map_fn(unused_dummy): + """Creates a unique `iterator_id` for each pass over the dataset. + + The "iterator_id" disambiguates between multiple concurrently + existing iterators. + + Args: + unused_dummy: Ignored value. + + Returns: + A `tf.int64` tensor whose value uniquely identifies an iterator in + `generator_state`. + """ + return script_ops.py_func( + generator_state.get_next_id, [], dtypes.int64, stateful=True) + + def generator_map_fn(iterator_id_t): + """Generates the next element from iterator with ID `iterator_id_t`. + + We map this function across an infinite repetition of the + `iterator_id_t`, and raise `StopIteration` to terminate the iteration. + + Args: + iterator_id_t: A `tf.int64` tensor whose value uniquely identifies + the iterator in `generator_state` from which to generate an element. + + Returns: + A nested structure of tensors representing an element from the iterator.
+ """ + def generator_py_func(iterator_id): + """A `py_func` that will be called to invoke the iterator.""" + try: + values = next(generator_state.get_iterator(iterator_id)) + except StopIteration: + generator_state.iterator_completed(iterator_id) + raise StopIteration("Iteration finished.") + + # Use the same _convert function from the py_func() implementation to + # convert the returned values to arrays early, so that we can inspect + # their values. + # pylint: disable=protected-access + ret_arrays = [script_ops.FuncRegistry._convert(ret) + for ret in nest.flatten_up_to(output_types, values)] + # pylint: enable=protected-access + + # Additional type and shape checking to ensure that the components + # of the generated element match the `output_types` and `output_shapes` + # arguments. + for (ret_array, expected_dtype, expected_shape) in zip( + ret_arrays, flattened_types, flattened_shapes): + if ret_array.dtype != expected_dtype.as_numpy_dtype: + raise TypeError( + "`generator` yielded an element of type %s where an element " + "of type %s was expected." + % (ret_array.dtype, expected_dtype.as_numpy_dtype)) + if not expected_shape.is_compatible_with(ret_array.shape): + raise ValueError( + "`generator` yielded an element of shape %s where an element " + "of shape %s was expected." % (ret_array.shape, expected_shape)) + + return ret_arrays + + flat_values = script_ops.py_func( + generator_py_func, [iterator_id_t], flattened_types, stateful=True) + + # The `py_func()` op drops the inferred shapes, so we add them back in + # here. + if output_shapes is not None: + for ret_t, shape in zip(flat_values, flattened_shapes): + ret_t.set_shape(shape) + + return nest.pack_sequence_as(output_types, flat_values) + + # This function associates each traversal of `generator` with a unique + # iterator ID. + def flat_map_fn(iterator_id_t): + # First, generate an infinite dataset containing the iterator ID repeated + # forever. + repeated_id = Dataset.from_tensors(iterator_id_t).repeat(None) + + # The `generator_map_fn` gets the next element from the iterator with the + # relevant ID, and raises StopIteration when that iterator contains no + # more elements. + return repeated_id.map(generator_map_fn) + + # A single-element dataset that, each time it is evaluated, contains a + # freshly-generated and unique (for the returned dataset) int64 + # ID that will be used to identify the appropriate Python state, which + # is encapsulated in `generator_state`, and captured in + # `get_iterator_id_map_fn`. + dummy = 0 + id_dataset = Dataset.from_tensors(dummy).map(get_iterator_id_map_fn) + + # A dataset that contains all of the elements generated by a + # single iterator created from `generator`, identified by the + # iterator ID contained in `id_dataset`. Lifting the iteration + # into a flat_map here enables multiple repetitions and/or nested + # versions of the returned dataset to be created, because it forces + # the generation of a new ID for each version. + return id_dataset.flat_map(flat_map_fn) + @staticmethod def range(*args): """Creates a `Dataset` of a step-separated range of values. @@ -1123,6 +1288,11 @@ class Dataset(object): } ``` + NOTE: The order of elements yielded by this transformation is + deterministic, as long as `map_func` is a pure function. If + `map_func` contains any stateful operations, the order in which + that state is accessed is undefined. 
+ Args: map_func: A function mapping a nested structure of tensors (having shapes and types defined by `self.output_shapes` and `self.output_types`) to a diff --git a/tensorflow/core/kernels/map_dataset_op.cc b/tensorflow/core/kernels/map_dataset_op.cc index 13a1ceaadff..bd6b0bce889 100644 --- a/tensorflow/core/kernels/map_dataset_op.cc +++ b/tensorflow/core/kernels/map_dataset_op.cc @@ -127,8 +127,16 @@ class MapDatasetOp : public UnaryDatasetOpKernel { opts.runner = ctx->runner(); // TODO(mrry): Avoid blocking a threadpool thread. We will need to // stack-rip the iterators and use async kernels. - return dataset()->captured_func_->Run(opts, args, out_tensors, - prefix()); + Status s = + dataset()->captured_func_->Run(opts, args, out_tensors, prefix()); + if (errors::IsOutOfRange(s)) { + // `f` may deliberately raise `errors::OutOfRange` to indicate + // that we should terminate the iteration early. + *end_of_sequence = true; + return Status::OK(); + } else { + return s; + } } private: From 573b303ac8204d626bee266798e1eb3df0fed491 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 18 Aug 2017 03:20:39 -0700 Subject: [PATCH 15/70] BUILD cleanup in tensorflow/core/kernels PiperOrigin-RevId: 165688864 --- tensorflow/core/kernels/BUILD | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 8d96999f3bb..d833ed9e38a 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -145,9 +145,7 @@ cc_library( "concat_lib.h", "concat_lib_cpu.h", ], - deps = [ - "//third_party/eigen3", - ], + deps = ["//third_party/eigen3"], ) cc_library( @@ -229,8 +227,11 @@ cc_library( hdrs = ["ops_testutil.h"], deps = [ "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensor_testutil", "//tensorflow/core:test", ], @@ -251,9 +252,7 @@ cc_library( cc_library( name = "ops_util_hdrs", hdrs = ["ops_util.h"], - deps = [ - "//third_party/eigen3", - ], + deps = ["//third_party/eigen3"], ) cc_library( @@ -402,6 +401,7 @@ cc_library( "split_lib.h", ], deps = [ + "//tensorflow/core:framework_lite", "//third_party/eigen3", ], ) @@ -411,6 +411,7 @@ cc_library( hdrs = ["typed_queue.h"], deps = [ ":queue_base", + "//tensorflow/core:framework", ], ) @@ -461,6 +462,8 @@ cc_library( ], visibility = ["//tensorflow:__subpackages__"], deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:lib", ], ) @@ -488,6 +491,8 @@ cc_library( hdrs = ["image_resizer_state.h"], visibility = ["//visibility:private"], deps = [ + ":bounds_check", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//third_party/eigen3", ], @@ -799,6 +804,7 @@ tf_kernel_library( "tile_functor_gpu.cu.cc", ], prefix = "tile_ops", + textual_hdrs = ["tile_ops_gpu_impl.h"], deps = ARRAY_DEPS, ) @@ -1680,6 +1686,7 @@ cc_library( "conditional_accumulator_base_op.h", ], deps = [ + ":conditional_accumulator_base", ":fill_functor", ":typed_conditional_accumulator_base", ], @@ -3128,6 +3135,7 @@ tf_kernel_library( deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", + "//third_party/eigen3", ], ) @@ -3547,7 +3555,10 @@ cc_library( "smooth-hinge-loss.h", "squared-loss.h", ], - deps = ["//tensorflow/core:framework_headers_lib"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:lib", + 
], ) cc_test( From a6729325a3534ef4aeb2065be82bb2963b9b03de Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Fri, 18 Aug 2017 07:39:41 -0700 Subject: [PATCH 16/70] Deletes convert_n_to_eager_tensor. Moves convert_to_eager_tensor to constant_op. PiperOrigin-RevId: 165704074 --- tensorflow/python/eager/ops_test.py | 4 +- .../python/eager/python_eager_op_gen.cc | 4 +- tensorflow/python/eager/tensor.py | 2 - tensorflow/python/framework/constant_op.py | 23 +++++++-- tensorflow/python/framework/ops.py | 49 +++++-------------- 5 files changed, 34 insertions(+), 48 deletions(-) diff --git a/tensorflow/python/eager/ops_test.py b/tensorflow/python/eager/ops_test.py index dee339f7f19..78ff2f67771 100644 --- a/tensorflow/python/eager/ops_test.py +++ b/tensorflow/python/eager/ops_test.py @@ -272,9 +272,7 @@ class TargetTest(test_util.TensorFlowTestCase): def testInvalidInputDataType(self): # Fill requires the first input to be an int32 tensor. - with self.assertRaisesRegexp( - TypeError, - 'Expected tensor with type tf.int32 not tf.int64'): + with self.assertRaisesRegexp(ValueError, 'int64'): array_ops.fill(tensor.Tensor([2], dtype=dtypes.int64), tensor.Tensor(1)) def testOutputOnHostMemory(self): diff --git a/tensorflow/python/eager/python_eager_op_gen.cc b/tensorflow/python/eager/python_eager_op_gen.cc index 511ce82eeba..c46a3d8db37 100644 --- a/tensorflow/python/eager/python_eager_op_gen.cc +++ b/tensorflow/python/eager/python_eager_op_gen.cc @@ -624,8 +624,8 @@ void GenEagerPythonOp::AddEagerInputCasts() { const string fn = arg.number_attr().empty() ? "" : "n_"; const string dtype = python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes."); - strings::StrAppend(&result_, " ", param, " = _tensor.convert_", fn, - "to_eager_tensor(", param, ", ", dtype, ")\n"); + strings::StrAppend(&result_, " ", param, " = _ops.convert_", fn, + "to_tensor(", param, ", ", dtype, ")\n"); } } diff --git a/tensorflow/python/eager/tensor.py b/tensorflow/python/eager/tensor.py index 1c2f4d74c7c..69269d1975f 100644 --- a/tensorflow/python/eager/tensor.py +++ b/tensorflow/python/eager/tensor.py @@ -24,8 +24,6 @@ import numpy as np # ops.py. 
# pylint: disable=unused-import from tensorflow.python.framework.ops import _tensor_from_handle -from tensorflow.python.framework.ops import convert_n_to_eager_tensor -from tensorflow.python.framework.ops import convert_to_eager_tensor from tensorflow.python.framework.ops import EagerTensor as Tensor # pylint: enable=unused-import diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py index af3be7230c2..9de63607e12 100644 --- a/tensorflow/python/framework/constant_op.py +++ b/tensorflow/python/framework/constant_op.py @@ -41,6 +41,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from autograd import core as ag_core import numpy as np from tensorflow.core.framework import attr_value_pb2 @@ -66,13 +67,29 @@ def _eager_reshape(tensor, shape): def _eager_fill(dims, value): """Eager-only version of Fill op; requires value is an eager Tensor.""" attr_t = value.dtype.as_datatype_enum - dims = ops.convert_to_eager_tensor(dims, dtypes.int32) + dims = convert_to_eager_tensor(dims, dtypes.int32) inputs_flat = [dims, value] attrs = ("T", attr_t) result, = execute.execute("Fill", 1, inputs=inputs_flat, attrs=attrs) return result +def convert_to_eager_tensor(t, dtype=None): + """Converts the given `value` to an `EagerTensor`.""" + if isinstance(ag_core.getval(t), ops.EagerTensor): + if dtype is not None and t.dtype != dtype: + raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype)) + return t + # Handle converting ResourceVariable to Tensor. + # TODO(josh11b): get rid of this explicit ugly conversion once we have a more + # general scheme in place. + try: + return t._dense_var_to_tensor(dtype=dtype, as_ref=False) # pylint: disable=protected-access + except AttributeError: + pass + return ops.EagerTensor(t, dtype=dtype) + + def constant(value, dtype=None, shape=None, name="Const", verify_shape=False): """Creates a constant tensor. @@ -123,8 +140,8 @@ def constant(value, dtype=None, shape=None, name="Const", verify_shape=False): """ if not context.in_graph_mode(): if shape is None: - return ops.convert_to_eager_tensor(value, dtype) - t = ops.convert_to_eager_tensor(value, dtype) + return convert_to_eager_tensor(value, dtype) + t = convert_to_eager_tensor(value, dtype) shape = tensor_shape.as_shape(shape) if shape == t.shape: return t diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 862dd706f41..6f1954537ec 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -876,29 +876,6 @@ class EagerTensor(Tensor): raise NotImplementedError("eval not supported for Eager Tensors.") -# TODO(josh11b): Support other cases like converting TensorShape, lists/tuples and -# other custom conversion functions. -def convert_to_eager_tensor(t, dtype=None): - """Converts the given `value` to an `EagerTensor`.""" - if isinstance(ag_core.getval(t), EagerTensor): - if dtype is not None and t.dtype != dtype: - raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype)) - return t - # Handle converting ResourceVariable to Tensor. - # TODO(josh11b): get rid of this explicit ugly conversion once we have a more - # general scheme in place. 
- try: - return t._dense_var_to_tensor(dtype=dtype, as_ref=False) # pylint: disable=protected-access - except AttributeError: - pass - return EagerTensor(t, dtype=dtype) - - -def convert_n_to_eager_tensor(values, dtype): - """Converts the given `values` to a list of `EagerTensor`.""" - return [convert_to_eager_tensor(t, dtype) for t in values] - - def _tensor_from_handle(handle): """'Private' constructor for the Tensor object. @@ -1112,21 +1089,17 @@ def internal_convert_n_to_tensor(values, """ if not isinstance(values, collections.Sequence): raise TypeError("values must be a list.") - if context.in_graph_mode(): - ret = [] - for i, value in enumerate(values): - n = None if name is None else "%s_%d" % (name, i) - ret.append( - internal_convert_to_tensor( - value, - dtype=dtype, - name=n, - as_ref=as_ref, - preferred_dtype=preferred_dtype)) - return ret - else: - # TODO(josh11b): handle preferred_dtype, as_ref - return convert_n_to_eager_tensor(values, dtype=dtype) + ret = [] + for i, value in enumerate(values): + n = None if name is None else "%s_%d" % (name, i) + ret.append( + internal_convert_to_tensor( + value, + dtype=dtype, + name=n, + as_ref=as_ref, + preferred_dtype=preferred_dtype)) + return ret def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None): From 7d01f89cc3a05fbd4d79dd5713b9856a8e2764e1 Mon Sep 17 00:00:00 2001 From: Pete Warden Date: Fri, 18 Aug 2017 09:32:30 -0700 Subject: [PATCH 17/70] Android demo app for speech recognition PiperOrigin-RevId: 165714459 --- WORKSPACE | 10 + tensorflow/contrib/makefile/Makefile | 19 +- .../contrib/makefile/download_dependencies.sh | 2 + tensorflow/contrib/makefile/tf_op_files.txt | 8 + tensorflow/core/BUILD | 2 + tensorflow/core/kernels/BUILD | 14 + .../docs_src/tutorials/audio_recognition.md | 47 ++- .../examples/android/AndroidManifest.xml | 10 + tensorflow/examples/android/BUILD | 1 + tensorflow/examples/android/README.md | 116 +++--- .../examples/android/download-models.gradle | 3 +- .../examples/android/res/drawable/border.xml | 19 + .../android/res/layout/activity_speech.xml | 55 +++ .../android/res/layout/list_text_item.xml | 25 ++ .../android/res/values/base-strings.xml | 1 + .../tensorflow/demo/RecognizeCommands.java | 186 +++++++++ .../org/tensorflow/demo/SpeechActivity.java | 353 ++++++++++++++++++ 17 files changed, 807 insertions(+), 64 deletions(-) create mode 100644 tensorflow/examples/android/res/drawable/border.xml create mode 100644 tensorflow/examples/android/res/layout/activity_speech.xml create mode 100644 tensorflow/examples/android/res/layout/list_text_item.xml create mode 100644 tensorflow/examples/android/src/org/tensorflow/demo/RecognizeCommands.java create mode 100644 tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java diff --git a/WORKSPACE b/WORKSPACE index 959587387ee..5e9b991fcca 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -80,3 +80,13 @@ new_http_archive( "http://download.tensorflow.org/models/stylize_v1.zip", ], ) + +new_http_archive( + name = "speech_commands", + build_file = "models.BUILD", + sha256 = "c3ec4fea3158eb111f1d932336351edfe8bd515bb6e87aad4f25dbad0a600d0c", + urls = [ + "http://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip", + "http://download.tensorflow.org/models/speech_commands_v0.01.zip", + ], +) diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile index a4f7453ed5c..f8837e3f586 100644 --- a/tensorflow/contrib/makefile/Makefile +++ b/tensorflow/contrib/makefile/Makefile @@ -73,8 
+73,9 @@ HOST_INCLUDES := \ -I. \ -I$(MAKEFILE_DIR)/downloads/ \ -I$(MAKEFILE_DIR)/downloads/eigen \ - -I$(MAKEFILE_DIR)/downloads/gemmlowp \ +-I$(MAKEFILE_DIR)/downloads/gemmlowp \ -I$(MAKEFILE_DIR)/downloads/nsync/public \ +-I$(MAKEFILE_DIR)/downloads/fft2d \ -I$(HOST_GENDIR) ifeq ($(HAS_GEN_HOST_PROTOC),true) HOST_INCLUDES += -I$(MAKEFILE_DIR)/gen/protobuf-host/include @@ -156,6 +157,7 @@ INCLUDES := \ -I$(MAKEFILE_DIR)/downloads/eigen \ -I$(MAKEFILE_DIR)/downloads/gemmlowp \ -I$(MAKEFILE_DIR)/downloads/nsync/public \ +-I$(MAKEFILE_DIR)/downloads/fft2d \ -I$(PROTOGENDIR) \ -I$(PBTGENDIR) ifeq ($(HAS_GEN_HOST_PROTOC),true) @@ -237,6 +239,7 @@ ifeq ($(TARGET),ANDROID) $(error "NDK_ROOT is not defined.") endif CXX := $(CC_PREFIX) $(NDK_ROOT)/toolchains/arm-linux-androideabi-4.9/prebuilt/$(OS_PATH)-x86_64/bin/arm-linux-androideabi-g++ + CC := $(CC_PREFIX) $(NDK_ROOT)/toolchains/arm-linux-androideabi-4.9/prebuilt/$(OS_PATH)-x86_64/bin/arm-linux-androideabi-gcc CXXFLAGS +=\ --sysroot $(NDK_ROOT)/platforms/android-21/arch-arm \ -Wno-narrowing \ @@ -244,7 +247,6 @@ ifeq ($(TARGET),ANDROID) -mfloat-abi=softfp \ -mfpu=neon \ -fPIE - INCLUDES = \ -I$(NDK_ROOT)/sources/android/support/include \ -I$(NDK_ROOT)/sources/cxx-stl/gnu-libstdc++/4.9/include \ @@ -254,6 +256,7 @@ ifeq ($(TARGET),ANDROID) -I$(MAKEFILE_DIR)/downloads/eigen \ -I$(MAKEFILE_DIR)/downloads/gemmlowp \ -I$(MAKEFILE_DIR)/downloads/nsync/public \ +-I$(MAKEFILE_DIR)/downloads/fft2d \ -I$(MAKEFILE_DIR)/gen/protobuf/include \ -I$(PROTOGENDIR) \ -I$(PBTGENDIR) @@ -507,6 +510,7 @@ $(wildcard tensorflow/core/grappler/clusters/single_machine.*) TF_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS)) # Add in any extra files that don't fit the patterns easily TF_CC_SRCS += tensorflow/core/platform/default/gpu_tracer.cc +TF_CC_SRCS += tensorflow/contrib/makefile/downloads/fft2d/fftsg.c # Also include the op and kernel definitions. TF_CC_SRCS += $(shell cat $(MAKEFILE_DIR)/tf_op_files.txt) PBT_CC_SRCS := $(shell cat $(MAKEFILE_DIR)/tf_pb_text_files.txt) @@ -529,7 +533,8 @@ tensorflow/core/kernels/hexagon/hexagon_remote_fused_graph_executor_build.cc endif # File names of the intermediate files target compilation generates. -TF_CC_OBJS := $(addprefix $(OBJDIR), $(TF_CC_SRCS:.cc=.o)) +TF_CC_OBJS := $(addprefix $(OBJDIR), \ +$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_CC_SRCS)))) PBT_GEN_FILES := $(addprefix $(PBTGENDIR), $(PBT_CC_SRCS)) PBT_OBJS := $(addprefix $(OBJDIR), $(PBT_CC_SRCS:.cc=.o)) PROTO_CC_SRCS := $(addprefix $(PROTOGENDIR), $(PROTO_SRCS:.proto=.pb.cc)) @@ -567,6 +572,14 @@ $(OBJDIR)%.o: %.cc | $(PBT_GEN_FILES) $(CXX) $(CXXFLAGS) $(DEPFLAGS) $(INCLUDES) -c $< -o $@ @mv -f $(DEPDIR)/$*.Td $(DEPDIR)/$*.d +# Matches on plain C files. +$(OBJDIR)%.o: %.c + @mkdir -p $(dir $@) + @mkdir -p $(dir $(DEPDIR)$*) + $(CXX) $(patsubst --std=c++11,--std=c99, $(CXXFLAGS)) -x c $(DEPFLAGS) \ +$(INCLUDES) -c $< -o $@ + @mv -f $(DEPDIR)/$*.Td $(DEPDIR)/$*.d + # Compiles C++ source files that have been generated by protoc. 
$(OBJDIR)%.pb.o: $(PROTOGENDIR)%.pb.cc @mkdir -p $(dir $@) diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index bb30a3b5a7b..1e9958584c9 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -25,6 +25,7 @@ GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.g NSYNC_URL="$(grep -o 'http.*github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" PROTOBUF_URL="$(grep -o 'http.*github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" RE2_URL="$(grep -o 'http.*github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" +FFT2D_URL="$(grep -o 'http.*fft\.tgz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" # TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64, # so work around it by patching the source. @@ -60,6 +61,7 @@ download_and_extract "${GOOGLETEST_URL}" "${DOWNLOADS_DIR}/googletest" download_and_extract "${NSYNC_URL}" "${DOWNLOADS_DIR}/nsync" download_and_extract "${PROTOBUF_URL}" "${DOWNLOADS_DIR}/protobuf" download_and_extract "${RE2_URL}" "${DOWNLOADS_DIR}/re2" +download_and_extract "${FFT2D_URL}" "${DOWNLOADS_DIR}/fft2d" replace_by_sed 's#static uint32x4_t p4ui_CONJ_XOR = vld1q_u32( conj_XOR_DATA );#static uint32x4_t p4ui_CONJ_XOR; // = vld1q_u32( conj_XOR_DATA ); - Removed by script#' \ "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h" diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index 9132a4344bf..a7f2be9790d 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -38,6 +38,8 @@ tensorflow/core/kernels/stack_ops.cc tensorflow/core/kernels/split_op.cc tensorflow/core/kernels/split_v_op.cc tensorflow/core/kernels/split_lib_cpu.cc +tensorflow/core/kernels/spectrogram_op.cc +tensorflow/core/kernels/spectrogram.cc tensorflow/core/kernels/sparse_to_dense_op.cc tensorflow/core/kernels/sparse_matmul_op.cc tensorflow/core/kernels/softsign_op.cc @@ -100,6 +102,10 @@ tensorflow/core/kernels/mirror_pad_op_cpu_impl_2.cc tensorflow/core/kernels/mirror_pad_op_cpu_impl_3.cc tensorflow/core/kernels/mirror_pad_op_cpu_impl_4.cc tensorflow/core/kernels/mirror_pad_op_cpu_impl_5.cc +tensorflow/core/kernels/mfcc_op.cc +tensorflow/core/kernels/mfcc_mel_filterbank.cc +tensorflow/core/kernels/mfcc_dct.cc +tensorflow/core/kernels/mfcc.cc tensorflow/core/kernels/maxpooling_op.cc tensorflow/core/kernels/matmul_op.cc tensorflow/core/kernels/lrn_op.cc @@ -117,6 +123,7 @@ tensorflow/core/kernels/fill_functor.cc tensorflow/core/kernels/fifo_queue.cc tensorflow/core/kernels/fake_quant_ops.cc tensorflow/core/kernels/example_parsing_ops.cc +tensorflow/core/kernels/encode_wav_op.cc tensorflow/core/kernels/dynamic_stitch_op.cc tensorflow/core/kernels/dynamic_partition_op.cc tensorflow/core/kernels/decode_bmp_op.cc @@ -124,6 +131,7 @@ tensorflow/core/kernels/depthtospace_op.cc tensorflow/core/kernels/spacetodepth_op.cc tensorflow/core/kernels/dense_update_ops.cc tensorflow/core/kernels/deep_conv2d.cc +tensorflow/core/kernels/decode_wav_op.cc tensorflow/core/kernels/xsmm_conv2d.cc tensorflow/core/kernels/cwise_ops_common.cc tensorflow/core/kernels/cwise_op_tanh.cc diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index f7b79e82e16..54f2ff7e132 100644 --- a/tensorflow/core/BUILD +++ 
b/tensorflow/core/BUILD @@ -981,6 +981,8 @@ cc_library( deps = [ ":protos_cc", "//third_party/eigen3", + "//third_party/fft2d:fft2d_headers", + "@fft2d//:fft2d", "@gemmlowp//:gemmlowp", "@nsync//:nsync_cpp", ], diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index d833ed9e38a..9f638eebee4 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -4322,6 +4322,9 @@ filegroup( "gemm_functors.h", "image_resizer_state.h", "maxpooling_op.h", + "mfcc.h", + "mfcc_dct.h", + "mfcc_mel_filterbank.h", "mirror_pad_op.h", "mirror_pad_op_cpu_impl.h", "pad_op.h", @@ -4338,6 +4341,7 @@ filegroup( "softsign_op.h", "spacetobatch_functor.h", "spacetodepth_op.h", + "spectrogram.h", "tensor_array.h", "tile_functor.h", "tile_ops_cpu_impl.h", @@ -4411,10 +4415,12 @@ filegroup( "cwise_op_squared_difference.cc", "cwise_op_sub.cc", "cwise_op_tanh.cc", + "decode_wav_op.cc", "deep_conv2d.cc", "deep_conv2d.h", "depthwise_conv_op.cc", "dynamic_partition_op.cc", + "encode_wav_op.cc", "fake_quant_ops.cc", "fifo_queue.cc", "fused_batch_norm_op.cc", @@ -4443,6 +4449,10 @@ filegroup( "logging_ops.cc", "lrn_op.cc", "maxpooling_op.cc", + "mfcc.cc", + "mfcc_dct.cc", + "mfcc_mel_filterbank.cc", + "mfcc_op.cc", "mirror_pad_op.cc", "mirror_pad_op_cpu_impl_1.cc", "mirror_pad_op_cpu_impl_2.cc", @@ -4478,6 +4488,8 @@ filegroup( "spacetobatch_op.cc", "spacetodepth_op.cc", "sparse_to_dense_op.cc", + "spectrogram.cc", + "spectrogram_op.cc", "stack_ops.cc", "string_join_op.cc", "summary_op.cc", @@ -4614,6 +4626,8 @@ cc_library( "//tensorflow/core:android_tensorflow_lib_lite", "//tensorflow/core:protos_cc", "//third_party/eigen3", + "//third_party/fft2d:fft2d_headers", + "@fft2d//:fft2d", "@gemmlowp//:gemmlowp", ], alwayslink = 1, diff --git a/tensorflow/docs_src/tutorials/audio_recognition.md b/tensorflow/docs_src/tutorials/audio_recognition.md index 57d3ebb9968..2caa3ec0d2d 100644 --- a/tensorflow/docs_src/tutorials/audio_recognition.md +++ b/tensorflow/docs_src/tutorials/audio_recognition.md @@ -214,6 +214,41 @@ of the other .wav files in that same folder to see how well it does. The scores are between zero and one, and higher values mean the model is more confident in its prediction. +## Running the Model in an Android App + +The easiest way to see how this model works in a real application is to download +[the prebuilt Android demo +applications](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#prebuilt-components) +and install them on your phone. You'll see 'TF Speech' appear in your app list, +and opening it will show you the same list of action words we've just trained +our model on, starting with "Yes" and "No". Once you've given the app permission +to use the microphone, you should be able to try saying those words and see them +highlighted in the UI when the model recognizes one of them. + +You can also build this application yourself, since it's open source and +[available as part of the TensorFlow repository on +github](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#building-in-android-studio-using-the-tensorflow-aar-from-jcenter). +By default it downloads [a pretrained model from +tensorflow.org](http://download.tensorflow.org/models/speech_commands_v0.01.zip), +but you can easily [replace it with a model you've trained +yourself](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install-model-files-optional). 
+If you do this, you'll need to make sure that the constants in [the main +SpeechActivity Java source +file](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java) +like `SAMPLE_RATE` and `SAMPLE_DURATION` match any changes you've made to the +defaults while training. You'll also see that there's a [Java version of the +RecognizeCommands +module](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android/src/org/tensorflow/demo/RecognizeCommands.java) +that's very similar to the C++ version in this tutorial. If you've tweaked +parameters for that, you can also update them in SpeechActivity to get the same +results as in your server testing. + +The demo app updates its UI list of results automatically based on the labels +text file you copy into assets alongside your frozen graph, which means you can +easily try out different models without needing to make any code changes. You +will need to update `LABEL_FILENAME` and `MODEL_FILENAME` to point to the files +you've added if you change the paths though. + ## How does this Model Work? The architecture used in this tutorial is based on some described in the paper @@ -341,13 +376,14 @@ aren't detected (high precision). The numbers from the tool give you an idea of how your model will perform in an application, and you can try tweaking the signal averaging parameters to tune it to give the kind of performance you want. To understand what the right parameters are for your application, you can look -at generating an [ROC curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) -to help you understand the tradeoffs. +at generating an [ROC +curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) to help +you understand the tradeoffs. ## RecognizeCommands -The streaming accuracy tool uses a simple decoder contained in a small -C++ class called +The streaming accuracy tool uses a simple decoder contained in a small C++ class +called [RecognizeCommands](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands/recognize_commands.h). This class is fed the output of running the TensorFlow model over time, it averages the signals, and returns information about a label when it has enough @@ -480,7 +516,8 @@ variations in starting time in the training data, and is controlled with the `--time_shift_ms` flag, which defaults to 100ms. Increasing this value will provide more variation, but at the risk of cutting off important parts of the audio. A related way of augmenting the data with realistic distortions is by -using [time stretching and pitch scaling](https://en.wikipedia.org/wiki/Audio_time_stretching_and_pitch_scaling), +using [time stretching and pitch +scaling](https://en.wikipedia.org/wiki/Audio_time_stretching_and_pitch_scaling), but that's outside the scope of this tutorial.
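As a rough sketch of tuning the time-shift augmentation discussed above (assuming the tutorial's `train.py` script; the 200ms value is an arbitrary example):

```bash
# Widen the random time-shift window from the default 100ms to 200ms
# for more variation in where each training sample starts.
python tensorflow/examples/speech_commands/train.py --time_shift_ms=200
```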
## Customizing the Model diff --git a/tensorflow/examples/android/AndroidManifest.xml b/tensorflow/examples/android/AndroidManifest.xml index 9f229d8b9d4..bb75431a1f8 100644 --- a/tensorflow/examples/android/AndroidManifest.xml +++ b/tensorflow/examples/android/AndroidManifest.xml @@ -22,6 +22,7 @@ + + + + + + + + diff --git a/tensorflow/examples/android/BUILD b/tensorflow/examples/android/BUILD index 2d3b0911fce..2347e6b0231 100644 --- a/tensorflow/examples/android/BUILD +++ b/tensorflow/examples/android/BUILD @@ -93,6 +93,7 @@ filegroup( srcs = [ "@inception5h//:model_files", "@mobile_ssd//:model_files", + "@speech_commands//:model_files", "@stylize//:model_files", ], ) diff --git a/tensorflow/examples/android/README.md b/tensorflow/examples/android/README.md index f9881287cdf..883f8e664fd 100644 --- a/tensorflow/examples/android/README.md +++ b/tensorflow/examples/android/README.md @@ -8,10 +8,11 @@ devices. The demos in this folder are designed to give straightforward samples of using TensorFlow in mobile applications. -Inference is done using the [TensorFlow Android Inference Interface](../../../tensorflow/contrib/android), -which may be built separately if you want a standalone library to drop into your -existing application. Object tracking and efficient YUV -> RGB conversion are -handled by `libtensorflow_demo.so`. +Inference is done using the [TensorFlow Android Inference +Interface](../../../tensorflow/contrib/android), which may be built separately +if you want a standalone library to drop into your existing application. Object +tracking and efficient YUV -> RGB conversion are handled by +`libtensorflow_demo.so`. A device running Android 5.0 (API 21) or higher is required to run the demo due to the use of the camera2 API, although the native libraries themselves can run @@ -33,6 +34,12 @@ on API >= 14 devices. Uses a model based on [A Learned Representation For Artistic Style](https://arxiv.org/abs/1610.07629) to restyle the camera preview image to that of a number of different artists. +4. [TF + Speech](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java): + Runs a simple speech recognition model built by the [audio training + tutorial](https://www.tensorflow.org/tutorials/image_retraining). Listens + for a small set of words, and highlights them in the UI when they are + recognized. @@ -51,20 +58,22 @@ for more details. ## Running the Demo -Once the app is installed it can be started via the "TF Classify", "TF Detect" -and "TF Stylize" icons, which have the orange TensorFlow logo as their icon. +Once the app is installed it can be started via the "TF Classify", "TF Detect", +"TF Stylize", and "TF Speech" icons, which have the orange TensorFlow logo as +their icon. While running the activities, pressing the volume keys on your device will -toggle debug visualizations on/off, rendering additional info to the screen -that may be useful for development purposes. +toggle debug visualizations on/off, rendering additional info to the screen that +may be useful for development purposes. ## Building in Android Studio using the TensorFlow AAR from JCenter The simplest way to compile the demo app yourself, and try out changes to the -project code is to use AndroidStudio. Simply set this `android` directory as the project root. +project code is to use AndroidStudio. Simply set this `android` directory as the +project root. 
-Then edit the `build.gradle` file and change the value of `nativeBuildSystem` -to `'none'` so that the project is built in the simplest way possible: +Then edit the `build.gradle` file and change the value of `nativeBuildSystem` to +`'none'` so that the project is built in the simplest way possible: ```None def nativeBuildSystem = 'none' @@ -77,8 +86,8 @@ Note: Currently, in this build mode, YUV -> RGB is done using a less efficient Java implementation, and object tracking is not available in the "TF Detect" activity. Setting the build system to `'cmake'` currently only builds `libtensorflow_demo.so`, which provides fast YUV -> RGB conversion and object -tracking, while still acquiring TensorFlow support via the downloaded AAR, so -it may be a lightweight way to enable these features. +tracking, while still acquiring TensorFlow support via the downloaded AAR, so it +may be a lightweight way to enable these features. For any project that does not include custom low level TensorFlow code, this is likely sufficient. @@ -104,50 +113,51 @@ protobuf compilation. NOTE: Bazel does not currently support building for Android on Windows. Full support for gradle/cmake builds is coming soon, but in the meantime we suggest -that Windows users download the -[prebuilt binaries](https://ci.tensorflow.org/view/Nightly/job/nightly-android/) -instead. +that Windows users download the [prebuilt +binaries](https://ci.tensorflow.org/view/Nightly/job/nightly-android/) instead. ##### Install Bazel and Android Prerequisites -Bazel is the primary build system for TensorFlow. To build with Bazel, -it and the Android NDK and SDK must be installed on your system. +Bazel is the primary build system for TensorFlow. To build with Bazel, it and +the Android NDK and SDK must be installed on your system. -1. Install the latest version of Bazel as per the instructions [on the Bazel website](https://bazel.build/versions/master/docs/install.html). -2. The Android NDK is required to build the native (C/C++) TensorFlow code. - The current recommended version is 12b, which may be found - [here](https://developer.android.com/ndk/downloads/older_releases.html#ndk-12b-downloads). -3. The Android SDK and build tools may be obtained - [here](https://developer.android.com/tools/revisions/build-tools.html), - or alternatively as part of - [Android Studio](https://developer.android.com/studio/index.html). Build - tools API >= 23 is required to build the TF Android demo (though it will - run on API >= 21 devices). +1. Install the latest version of Bazel as per the instructions [on the Bazel + website](https://bazel.build/versions/master/docs/install.html). +2. The Android NDK is required to build the native (C/C++) TensorFlow code. The + current recommended version is 12b, which may be found + [here](https://developer.android.com/ndk/downloads/older_releases.html#ndk-12b-downloads). +3. The Android SDK and build tools may be obtained + [here](https://developer.android.com/tools/revisions/build-tools.html), or + alternatively as part of [Android + Studio](https://developer.android.com/studio/index.html). Build tools API >= + 23 is required to build the TF Android demo (though it will run on API >= 21 + devices). ##### Edit WORKSPACE -The Android entries in [`/WORKSPACE`](../../../WORKSPACE#L19-L36) -must be uncommented with the paths filled in appropriately depending on where -you installed the NDK and SDK. Otherwise an error such as: -"The external label '//external:android/sdk' is not bound to anything" will -be reported. 
+The Android entries in +[`/WORKSPACE`](../../../WORKSPACE#L19-L36) must be uncommented +with the paths filled in appropriately depending on where you installed the NDK +and SDK. Otherwise an error such as: "The external label +'//external:android/sdk' is not bound to anything" will be reported. -Also edit the API levels for the SDK in WORKSPACE to the highest level you -have installed in your SDK. This must be >= 23 (this is completely independent -of the API level of the demo, which is defined in AndroidManifest.xml). -The NDK API level may remain at 14. +Also edit the API levels for the SDK in WORKSPACE to the highest level you have +installed in your SDK. This must be >= 23 (this is completely independent of the +API level of the demo, which is defined in AndroidManifest.xml). The NDK API +level may remain at 14. ##### Install Model Files (optional) -The TensorFlow `GraphDef`s that contain the model definitions and weights -are not packaged in the repo because of their size. They are downloaded +The TensorFlow `GraphDef`s that contain the model definitions and weights are +not packaged in the repo because of their size. They are downloaded automatically and packaged with the APK by Bazel via a new_http_archive defined -in `WORKSPACE` during the build process, and by Gradle via download-models.gradle. +in `WORKSPACE` during the build process, and by Gradle via +download-models.gradle. -**Optional**: If you wish to place the models in your assets manually, -remove all of the `model_files` entries from the `assets` -list in `tensorflow_demo` found in the `[BUILD](BUILD)` file. Then download -and extract the archives yourself to the `assets` directory in the source tree: +**Optional**: If you wish to place the models in your assets manually, remove +all of the `model_files` entries from the `assets` list in `tensorflow_demo` +found in the `[BUILD](BUILD)` file. Then download and extract the archives +yourself to the `assets` directory in the source tree: ```bash BASE_URL=https://storage.googleapis.com/download.tensorflow.org/models @@ -162,27 +172,23 @@ This will extract the models and their associated metadata files to the local assets/ directory. If you are using Gradle, make sure to remove download-models.gradle reference -from build.gradle after your manually download models; otherwise gradle -might download models again and overwrite your models. +from build.gradle after you manually download models; otherwise gradle might +download models again and overwrite your models. ##### Build -After editing your WORKSPACE file to update the SDK/NDK configuration, -you may build the APK. Run this from your workspace root: +After editing your WORKSPACE file to update the SDK/NDK configuration, you may +build the APK. Run this from your workspace root: ```bash bazel build -c opt //tensorflow/examples/android:tensorflow_demo ``` -If you get build errors about protocol buffers, run -`git submodule update --init` and make sure that you've modified your WORKSPACE -file as instructed, then try building again.
- ##### Install -Make sure that adb debugging is enabled on your Android 5.0 (API 21) or -later device, then after building use the following command from your workspace -root to install the APK: +Make sure that adb debugging is enabled on your Android 5.0 (API 21) or later +device, then after building use the following command from your workspace root +to install the APK: ```bash adb install -r bazel-bin/tensorflow/examples/android/tensorflow_demo.apk diff --git a/tensorflow/examples/android/download-models.gradle b/tensorflow/examples/android/download-models.gradle index a19ca36d7f6..0e2cf65f538 100644 --- a/tensorflow/examples/android/download-models.gradle +++ b/tensorflow/examples/android/download-models.gradle @@ -11,7 +11,8 @@ // LINT.IfChange def models = ['inception5h.zip', 'object_detection/ssd_mobilenet_v1_android_export.zip', - 'stylize_v1.zip'] + 'stylize_v1.zip', + 'speech_commands_conv_actions.zip'] // LINT.ThenChange(//tensorflow/examples/android/BUILD) // Root URL for model archives diff --git a/tensorflow/examples/android/res/drawable/border.xml b/tensorflow/examples/android/res/drawable/border.xml new file mode 100644 index 00000000000..dd1d64d1d61 --- /dev/null +++ b/tensorflow/examples/android/res/drawable/border.xml @@ -0,0 +1,19 @@ + + + + + diff --git a/tensorflow/examples/android/res/layout/activity_speech.xml b/tensorflow/examples/android/res/layout/activity_speech.xml new file mode 100644 index 00000000000..2fe1338da57 --- /dev/null +++ b/tensorflow/examples/android/res/layout/activity_speech.xml @@ -0,0 +1,55 @@ + + + + + + + +