diff --git a/configure b/configure index 7b08dca2047..87ef6e99be3 100755 --- a/configure +++ b/configure @@ -49,6 +49,15 @@ while true; do # Retry done +## Set up architecture-dependent optimization flags. +if [ -z "$CC_OPT_FLAGS" ]; then + default_cc_opt_flags="-march=native" + read -p "Please specify optimization flags to use during compilation [Default is $default_cc_opt_flags]: " CC_OPT_FLAGS + if [ -z "$CC_OPT_FLAGS" ]; then + CC_OPT_FLAGS=$default_cc_opt_flags + fi +fi + if is_windows; then TF_NEED_GCP=0 TF_NEED_HDFS=0 @@ -148,6 +157,12 @@ fi # Invoke python_config and set up symlinks to python includes ./util/python/python_config.sh --setup "$PYTHON_BIN_PATH" +# Append CC optimization flags to bazel.rc +echo >> tools/bazel.rc +for opt in $CC_OPT_FLAGS; do + echo "build:opt --cxxopt=$opt --copt=$opt" >> tools/bazel.rc +done + # Run the gen_git_source to create links where bazel can track dependencies for # git hash propagation GEN_GIT_SOURCE=tensorflow/tools/git/gen_git_source.py diff --git a/tensorflow/cc/framework/ops.cc b/tensorflow/cc/framework/ops.cc index 94c6b801776..50df891a4c4 100644 --- a/tensorflow/cc/framework/ops.cc +++ b/tensorflow/cc/framework/ops.cc @@ -17,7 +17,6 @@ limitations under the License. #include "tensorflow/core/lib/hash/hash.h" namespace tensorflow { -namespace ops { Operation::Operation(Node* n) : inputs_(GetInputs(n)), node_(n) {} @@ -110,5 +109,4 @@ Input::Initializer::Initializer( tensor = t; } -} // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/framework/ops.h b/tensorflow/cc/framework/ops.h index 71bfc6617c1..82ba9c68f0a 100644 --- a/tensorflow/cc/framework/ops.h +++ b/tensorflow/cc/framework/ops.h @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { -namespace ops { class Output; @@ -193,6 +192,7 @@ class Input { // * A scalar, or a multi-dimensional tensor specified as a recursive // initializer list. This enables directly passing constants as // inputs to op wrappers. + // * A Tensor object. Input(const Output& o) : output_(o) {} // NOLINT(runtime/explicit) template <typename T, typename = typename std::enable_if< @@ -249,7 +249,7 @@ typedef std::vector<Output> OutputList; class InputList { public: // Implicitly convert a list of outputs to a list of inputs. This is useful to - // write code such as tf.Concat(tf.Split(x, 4)). + // write code such as ops::Concat(ops::Split(x, 4)). InputList(const OutputList& out) { // NOLINT(runtime/explicit) for (auto const& x : out) { inputs_.push_back(x); @@ -284,7 +284,19 @@ class InputList { std::vector<Input> inputs_; }; +// These symbols used to live in the ops namespace, so we temporarily +// declare some aliases there. TODO(josh11b): Delete this! +namespace ops { + +using ::tensorflow::Input; +using ::tensorflow::InputList; +using ::tensorflow::Operation; +using ::tensorflow::Output; +using ::tensorflow::OutputHash; +using ::tensorflow::OutputList; + } // namespace ops + } // namespace tensorflow #endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_OPS_H_ diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 72c440abe88..c1e61462085 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -602,12 +602,10 @@ Status EncapsulateSubgraphsPass::Run( std::unique_ptr<Graph>* subgraph, std::vector<int>* input_permutation, std::vector<int>* output_permutation, NodeDef* node) { // Optimize the subgraph. - Graph* g = subgraph->release(); - OptimizeGraph(flr.get(), &g); - subgraph->reset(g); + OptimizeGraph(flr.get(), subgraph); std::vector<bool> const_args(input_permutation->size()); - TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*g, &const_args)); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args)); // Compute a permutation of the arguments such that the constant arguments // are first. diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index 41403858a69..abc0cb2cce7 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -28,7 +28,7 @@ limitations under the License. // Run tests, comparing the Tensorflow CPU operators with their XLA-compiled // counterparts: // randomized_tests \ -// --tf_xla_test_use_jit=true --tf_xla_test_device=CPU \ +// --tf_xla_test_use_jit=true --tf_xla_test_device=CPU:0 \ // --tf_xla_test_repetitions=20 // TODO(phawkins): add tests for: @@ -50,6 +50,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" @@ -66,6 +67,7 @@ limitations under the License. #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { namespace { @@ -76,9 +78,8 @@ int32 tf_xla_test_repetitions = 20; string* tf_xla_test_device_ptr; // initial value set in main() bool tf_xla_test_use_jit = true; -string DeviceTypeToDeviceName(DeviceType type) { - return strings::StrCat("/job:localhost/replica:0/task:0/device:", type.type(), - ":0"); +string LocalDeviceToFullDeviceName(const string& device) { + return strings::StrCat("/job:localhost/replica:0/task:0/device:", device); } constexpr std::array<DataType, 3> kAllXlaTypes = { @@ -575,9 +576,14 @@ Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol, void OpTest::ExpectTfAndXlaOutputsAreClose(const OpTestBuilder& builder, double atol, double rtol) { - string cpu_device = DeviceTypeToDeviceName(DEVICE_CPU); - DeviceType test_device_type(*tf_xla_test_device_ptr); - string test_device = DeviceTypeToDeviceName(test_device_type); + string cpu_device = + LocalDeviceToFullDeviceName(strings::StrCat(DEVICE_CPU, ":0")); + string test_device = LocalDeviceToFullDeviceName(*tf_xla_test_device_ptr); + + DeviceNameUtils::ParsedName parsed_name; + ASSERT_TRUE( + DeviceNameUtils::ParseLocalName(*tf_xla_test_device_ptr, &parsed_name)); + DeviceType test_device_type(parsed_name.type); ++num_tests_; GraphDef graph; @@ -2058,7 +2064,7 @@ TEST_F(OpTest, ZerosLike) { } // namespace tensorflow int main(int argc, char** argv) { - tensorflow::tf_xla_test_device_ptr = new tensorflow::string("GPU"); + tensorflow::tf_xla_test_device_ptr = new tensorflow::string("GPU:0"); std::vector<tensorflow::Flag> flag_list = { tensorflow::Flag( "tf_xla_random_seed", &tensorflow::tf_xla_random_seed, @@ -2085,13 +2091,18 @@ int main(int argc, char** argv) { LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; return 2; } - // XLA devices register kernels at construction time; create and destroy all - // known devices to make sure the kernels are registered. + // XLA devices register kernels at construction time; create all known devices + // to make sure the kernels are registered. std::vector<tensorflow::Device*> devices; TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices( tensorflow::SessionOptions(), "", &devices)); - for (tensorflow::Device* device : devices) { - delete device; - } + tensorflow::DeviceMgr device_mgr(devices); + + tensorflow::Device* ignored; + TF_QCHECK_OK( + device_mgr.LookupDevice(*tensorflow::tf_xla_test_device_ptr, &ignored)) + << "Unknown test device (" << *tensorflow::tf_xla_test_device_ptr + << "). Did you build in the right configuration (e.g., is CUDA enabled)?"; + return RUN_ALL_TESTS(); } diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 4d861c71c41..299b5e98c03 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -97,7 +97,6 @@ cc_test( "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", - "//tensorflow/cc:sendrecv_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla/client:client_library", diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index a0edbc5cbc3..d291888a758 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -186,9 +186,7 @@ Status XlaCompiler::CompileFunctionBody( // for devices other than CPU. OptimizerOptions opts; GraphOptimizer optimizer(opts); - Graph* g = graph.release(); - OptimizeGraph(flr, &g); - graph.reset(g); + OptimizeGraph(flr, &graph); if (VLOG_IS_ON(1)) { dump_graph::DumpGraphToFile( diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 24efd3ed0b8..af0b9c478dd 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/function_ops.h" -#include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/xla/client/client_library.h" diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index b08f859270a..dea6bb33d34 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -51,18 +51,9 @@ bool IsLiteralWithValue(const HloInstruction* operand, int value) { // Returns whether the given transpose produces a result which is bit-wise // identical to its operand and thus may be replaced with a bitcast. -bool TransposeIsBitcast( - const HloInstruction* transpose, - const AlgebraicSimplifier::ValidBitcastCallback& valid_bitcast_callback) { +bool TransposeIsBitcast(const HloInstruction* transpose) { CHECK_EQ(HloOpcode::kTranspose, transpose->opcode()); const HloInstruction* operand = transpose->operand(0); - - // Can't insert bitcasts if the compiler used a memory layout which isn't - // compatible. - if (!valid_bitcast_callback(operand->shape(), transpose->shape())) { - return false; - } - return ShapeUtil::TransposeIsBitcast(operand->shape(), transpose->shape(), transpose->dimensions()); } @@ -80,11 +71,8 @@ bool ReshapeIsBitcast( const HloInstruction* operand = reshape->operand(0); // Can't insert bitcasts if the compiler used a memory layout which isn't // compatible. - if (!valid_bitcast_callback(operand->shape(), reshape->shape())) { - return false; - } - - return ShapeUtil::ReshapeIsBitcast(operand->shape(), reshape->shape()); + return ShapeUtil::ReshapeIsBitcast(operand->shape(), reshape->shape()) && + valid_bitcast_callback(operand->shape(), reshape->shape()); } } // namespace @@ -199,7 +187,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Whether layout is considered during transformation. bool is_layout_sensitive_; - // Callback used to determine if a bitcast is valid. + // Callback used to determine if a bitcast is possible. AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback_; }; @@ -287,7 +275,8 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide, HloInstruction* rhs) { // A/1 => A VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString(); - if (IsLiteralWithValue(rhs, 1) && ReplaceInstructionIfSameShape(divide, lhs)) { + if (IsLiteralWithValue(rhs, 1) && + ReplaceInstructionIfSameShape(divide, lhs)) { return Status::OK(); } @@ -717,8 +706,7 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { return Status::OK(); } - if (is_layout_sensitive_ && - TransposeIsBitcast(transpose, valid_bitcast_callback_)) { + if (is_layout_sensitive_ && TransposeIsBitcast(transpose)) { ReplaceWithBitcast(transpose); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index 4aa06cab53c..0519aced5fa 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -26,9 +26,11 @@ namespace xla { // A pass which performs AlgebraicSimplications. class AlgebraicSimplifier : public HloPass { public: - // Given two shapes, determines if it is valid to bitcast between them. - // Precondition: the two shapes have layouts and have the same number of - // elements. + // Given two shapes, determines if it is valid to bitcast between them after + // considering platform dependent effects on layout like alignment + // restrictions. + // Precondition: the two shapes have layouts, the same number of + // elements and ShapeUtil::ReshapeIsBitcast returns true. using ValidBitcastCallback = std::function<bool(const Shape&, const Shape&)>; // If is_layout_sensitive is true, then the simplifier preserves layout during diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 4865a8fb45c..990173e2e52 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -495,11 +495,13 @@ string DumpGraph(const HloComputation& computation, const string& label, } void DumpText(const HloModule& module, const string& label, - const string& directory_path) { + const string& directory_path, bool do_prefix) { Env* env = Env::Default(); TF_CHECK_OK(env->RecursivelyCreateDir(directory_path)); string prefix = StrCat(env->NowMicros()); - string path = JoinPath(directory_path, StrCat(prefix, "-", label, ".txt")); + string filename = + do_prefix ? StrCat(prefix, "-", label, ".txt") : StrCat(label, ".txt"); + string path = JoinPath(directory_path, filename); TF_CHECK_OK(WriteStringToFile(env, path, module.ToString())); } diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index 45fd46352f9..5f841da1f35 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -33,8 +33,12 @@ string DumpGraph(const HloComputation& computation, const string& label, // Dumps the HloModule::ToString() as a file into the provided directory path // suffixed with the provided label. +// +// If do_prefix is true, a timestamp will be prepended onto the label to +// construct a filename in the directory path; otherwise, the label is used +// as the filename directly. void DumpText(const HloModule& module, const string& label, - const string& directory_path); + const string& directory_path, bool do_prefix = true); // Abstract interface for classes that render DOT graphs. class GraphRendererInterface { diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index 691c87457c6..5ac9ec56817 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -27,6 +27,7 @@ option(tensorflow_BUILD_ALL_KERNELS "Build all OpKernels" ON) option(tensorflow_BUILD_CONTRIB_KERNELS "Build OpKernels from tensorflow/contrib/..." ON) option(tensorflow_BUILD_CC_TESTS "Build cc unit tests " OFF) option(tensorflow_BUILD_PYTHON_TESTS "Build python unit tests " OFF) +option(tensorflow_OPTIMIZE_FOR_NATIVE_ARCH "Enable compiler optimizations for the native processor architecture (if available)" ON) if (NOT WIN32) # Threads: defines CMAKE_THREAD_LIBS_INIT and adds -pthread compile option @@ -67,7 +68,15 @@ if(WIN32) endif() if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") - set(CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS} "-fno-exceptions -std=c++11") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-exceptions -std=c++11") +endif() + +if (tensorflow_OPTIMIZE_FOR_NATIVE_ARCH) + include(CheckCXXCompilerFlag) + CHECK_CXX_COMPILER_FLAG("-march=native" COMPILER_OPT_ARCH_NATIVE_SUPPORTED) + if (COMPILER_OPT_ARCH_NATIVE_SUPPORTED) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native") + endif() endif() # External dependencies diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 0d8f4785a4b..34168342cf3 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -448,6 +448,23 @@ cuda_py_tests( tags = ["nomsan"], # disable to avoid false positives from scipy. ) +cuda_py_tests( + name = "vector_student_t_test", + size = "medium", + srcs = ["python/kernel_tests/vector_student_t_test.py"], + additional_deps = [ + ":distributions_py", + ":distributions_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + cuda_py_tests( name = "uniform_test", size = "small", diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index c6522069e9c..01896c52440 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -93,6 +93,11 @@ representing the posterior or posterior predictive. @@kl @@RegisterKL + +## Utilities + +@@softplus_inverse + """ from __future__ import absolute_import from __future__ import division @@ -110,6 +115,7 @@ from tensorflow.contrib.distributions.python.ops.dirichlet import * from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import * from tensorflow.contrib.distributions.python.ops.distribution import * from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform +from tensorflow.contrib.distributions.python.ops.distribution_util import softplus_inverse from tensorflow.contrib.distributions.python.ops.exponential import * from tensorflow.contrib.distributions.python.ops.gamma import * from tensorflow.contrib.distributions.python.ops.inverse_gamma import * diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py index e7aa8165a7a..57c873f59e1 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import math + import numpy as np from scipy import special @@ -28,8 +29,11 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +import tensorflow.python.ops.nn_grad # pylint: disable=unused-import from tensorflow.python.platform import test @@ -479,5 +483,96 @@ class GenNewSeedTest(test.TestCase): self.assertTrue(distribution_util.gen_new_seed(None, "salt") is None) +# TODO(jvdillon): Merge this test back into: +# tensorflow/python/kernel_tests/softplus_op_test.py +# once TF core is accepting new ops. +class SoftplusTest(test.TestCase): + + def _npSoftplus(self, np_features): + np_features = np.asarray(np_features) + zero = np.asarray(0).astype(np_features.dtype) + return np.logaddexp(zero, np_features) + + def _testSoftplus(self, np_features, use_gpu=False): + np_features = np.asarray(np_features) + np_softplus = self._npSoftplus(np_features) + with self.test_session(use_gpu=use_gpu) as sess: + softplus = nn_ops.softplus(np_features) + softplus_inverse = distribution_util.softplus_inverse(softplus) + [tf_softplus, tf_softplus_inverse] = sess.run([ + softplus, softplus_inverse]) + self.assertAllCloseAccordingToType(np_softplus, tf_softplus) + rtol = {"float16": 0.07, "float32": 0.003, "float64": 0.002}.get( + str(np_features.dtype), 1e-6) + # This will test that we correctly computed the inverse by verifying we + # recovered the original input. + self.assertAllCloseAccordingToType( + np_features, tf_softplus_inverse, + atol=0., rtol=rtol) + self.assertAllEqual(np.ones_like(tf_softplus).astype(np.bool), + tf_softplus > 0) + + self.assertShapeEqual(np_softplus, softplus) + self.assertShapeEqual(np_softplus, softplus_inverse) + + self.assertAllEqual(np.ones_like(tf_softplus).astype(np.bool), + np.isfinite(tf_softplus)) + self.assertAllEqual(np.ones_like(tf_softplus_inverse).astype(np.bool), + np.isfinite(tf_softplus_inverse)) + + def testNumbers(self): + for t in [np.float16, np.float32, np.float64]: + lower = {np.float16: -15, np.float32: -50, np.float64: -50}.get(t, -100) + upper = {np.float16: 50, np.float32: 50, np.float64: 50}.get(t, 100) + self._testSoftplus( + np.array(np.linspace(lower, upper, int(1e3)).astype(t)).reshape( + [2, -1]), + use_gpu=False) + self._testSoftplus( + np.array(np.linspace(lower, upper, int(1e3)).astype(t)).reshape( + [2, -1]), + use_gpu=True) + log_eps = np.log(np.finfo(t).eps) + one = t(1) + ten = t(10) + self._testSoftplus( + [ + log_eps, log_eps - one, log_eps + one, log_eps - ten, + log_eps + ten, -log_eps, -log_eps - one, -log_eps + one, + -log_eps - ten, -log_eps + ten + ], + use_gpu=False) + self._testSoftplus( + [ + log_eps, log_eps - one, log_eps + one, log_eps - ten, + log_eps + ten - log_eps, -log_eps - one, -log_eps + one, + -log_eps - ten, -log_eps + ten + ], + use_gpu=True) + + def testGradient(self): + with self.test_session(): + x = constant_op.constant( + [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], + shape=[2, 5], + name="x") + y = nn_ops.softplus(x, name="softplus") + x_init = np.asarray( + [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]], + dtype=np.float32, + order="F") + err = gradient_checker.compute_gradient_error( + x, [2, 5], y, [2, 5], x_init_value=x_init) + print("softplus (float) gradient err = ", err) + self.assertLess(err, 1e-4) + + def testInverseSoftplusGradientNeverNan(self): + with self.test_session(): + # Note that this range contains both zero and inf. + x = constant_op.constant((10.**np.arange(-8, 6)).astype(np.float16)) + y = distribution_util.softplus_inverse(x).eval() + # Equivalent to `assertAllFalse` (if it existed). + self.assertAllEqual(np.zeros_like(y).astype(np.bool), np.isnan(y)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/vector_student_t_test.py b/tensorflow/contrib/distributions/python/kernel_tests/vector_student_t_test.py new file mode 100644 index 00000000000..72012059940 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/vector_student_t_test.py @@ -0,0 +1,283 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for MultivariateStudentsT Distribution.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from scipy import linalg +from scipy import special + +from tensorflow.contrib.distributions.python.ops.vector_student_t import _VectorStudentT +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class _FakeVectorStudentT(object): + """Fake scipy implementation for Multivariate Student's t-distribution. + + Technically we don't need to test the `Vector Student's t-distribution` since + its composed of only unit-tested parts. However this _FakeVectorStudentT + serves as something like an end-to-end test of the + `TransformedDistribution + Affine` API. + + Other `Vector*` implementations need only test new code. That we don't need + to test every Vector* distribution is good because there aren't SciPy + analogues and reimplementing everything in NumPy sort of defeats the point of + having the `TransformedDistribution + Affine` API. + """ + + def __init__(self, df, shift, scale_tril): + self._df = np.asarray(df) + self._shift = np.asarray(shift) + self._scale_tril = np.asarray(scale_tril) + + def log_prob(self, x): + def _compute(df, shift, scale_tril, x): + k = scale_tril.shape[-1] + ildj = np.sum(np.log(np.abs(np.diag(scale_tril))), axis=-1) + logz = ildj + k * (0.5 * np.log(df) + + 0.5 * np.log(np.pi) + + special.gammaln(0.5 * df) - + special.gammaln(0.5 * (df + 1.))) + y = linalg.solve_triangular(scale_tril, np.matrix(x - shift).T, + lower=True, overwrite_b=True) + logs = -0.5 * (df + 1.) * np.sum(np.log1p(y**2. / df), axis=-2) + return logs - logz + if not self._df.shape: + return _compute(self._df, self._shift, self._scale_tril, x) + return np.concatenate([ + [_compute(self._df[i], self._shift[i], self._scale_tril[i], x[:, i, :])] + for i in range(len(self._df))]).T + + def prob(self, x): + return np.exp(self.log_prob(x)) + + +class VectorStudentTTest(test.TestCase): + + def setUp(self): + self._rng = np.random.RandomState(42) + + def testProbStaticScalar(self): + with self.test_session(): + # Scalar batch_shape. + df = np.asarray(3., dtype=np.float32) + # Scalar batch_shape. + shift = np.asarray([1], dtype=np.float32) + scale_diag = np.asarray([2.], dtype=np.float32) + scale_tril = np.diag(scale_diag) + + expected_mst = _FakeVectorStudentT( + df=df, shift=shift, scale_tril=scale_tril) + + actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag, + validate_args=True) + x = 2. * self._rng.rand(4, 1).astype(np.float32) - 1. + + self.assertAllClose(expected_mst.log_prob(x), + actual_mst.log_prob(x).eval(), + rtol=0., atol=1e-5) + self.assertAllClose(expected_mst.prob(x), + actual_mst.prob(x).eval(), + rtol=0., atol=1e-5) + + def testProbStatic(self): + # Non-scalar batch_shape. + df = np.asarray([1., 2, 3], dtype=np.float32) + # Non-scalar batch_shape. + shift = np.asarray([[0., 0, 0], + [1, 2, 3], + [1, 0, 1]], + dtype=np.float32) + scale_diag = np.asarray([[1., 2, 3], + [2, 3, 4], + [4, 5, 6]], + dtype=np.float32) + scale_tril = np.concatenate([[np.diag(scale_diag[i])] + for i in range(len(scale_diag))]) + x = 2. * self._rng.rand(4, 3, 3).astype(np.float32) - 1. + + expected_mst = _FakeVectorStudentT( + df=df, shift=shift, scale_tril=scale_tril) + + with self.test_session(): + actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag, + validate_args=True) + self.assertAllClose(expected_mst.log_prob(x), + actual_mst.log_prob(x).eval(), + rtol=0., atol=1e-5) + self.assertAllClose(expected_mst.prob(x), + actual_mst.prob(x).eval(), + rtol=0., atol=1e-5) + + def testProbDynamic(self): + # Non-scalar batch_shape. + df = np.asarray([1., 2, 3], dtype=np.float32) + # Non-scalar batch_shape. + shift = np.asarray([[0., 0, 0], + [1, 2, 3], + [1, 0, 1]], + dtype=np.float32) + scale_diag = np.asarray([[1., 2, 3], + [2, 3, 4], + [4, 5, 6]], + dtype=np.float32) + scale_tril = np.concatenate([[np.diag(scale_diag[i])] + for i in range(len(scale_diag))]) + x = 2. * self._rng.rand(4, 3, 3).astype(np.float32) - 1. + + expected_mst = _FakeVectorStudentT( + df=df, shift=shift, scale_tril=scale_tril) + + with self.test_session(): + df_pl = array_ops.placeholder(dtypes.float32, name="df") + shift_pl = array_ops.placeholder(dtypes.float32, name="shift") + scale_diag_pl = array_ops.placeholder(dtypes.float32, name="scale_diag") + feed_dict = {df_pl: df, shift_pl: shift, scale_diag_pl: scale_diag} + actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag, + validate_args=True) + self.assertAllClose(expected_mst.log_prob(x), + actual_mst.log_prob(x).eval(feed_dict=feed_dict), + rtol=0., atol=1e-5) + self.assertAllClose(expected_mst.prob(x), + actual_mst.prob(x).eval(feed_dict=feed_dict), + rtol=0., atol=1e-5) + + def testProbScalarBaseDistributionNonScalarTransform(self): + # Scalar batch_shape. + df = np.asarray(2., dtype=np.float32) + # Non-scalar batch_shape. + shift = np.asarray([[0., 0, 0], + [1, 2, 3], + [1, 0, 1]], + dtype=np.float32) + scale_diag = np.asarray([[1., 2, 3], + [2, 3, 4], + [4, 5, 6]], + dtype=np.float32) + scale_tril = np.concatenate([[np.diag(scale_diag[i])] + for i in range(len(scale_diag))]) + x = 2. * self._rng.rand(4, 3, 3).astype(np.float32) - 1. + + expected_mst = _FakeVectorStudentT( + df=np.tile(df, len(scale_diag)), + shift=shift, + scale_tril=scale_tril) + + with self.test_session(): + actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag, + validate_args=True) + self.assertAllClose(expected_mst.log_prob(x), + actual_mst.log_prob(x).eval(), + rtol=0., atol=1e-5) + self.assertAllClose(expected_mst.prob(x), + actual_mst.prob(x).eval(), + rtol=0., atol=1e-5) + + def testProbScalarBaseDistributionNonScalarTransformDynamic(self): + # Scalar batch_shape. + df = np.asarray(2., dtype=np.float32) + # Non-scalar batch_shape. + shift = np.asarray([[0., 0, 0], + [1, 2, 3], + [1, 0, 1]], + dtype=np.float32) + scale_diag = np.asarray([[1., 2, 3], + [2, 3, 4], + [4, 5, 6]], + dtype=np.float32) + scale_tril = np.concatenate([[np.diag(scale_diag[i])] + for i in range(len(scale_diag))]) + x = 2. * self._rng.rand(4, 3, 3).astype(np.float32) - 1. + + expected_mst = _FakeVectorStudentT( + df=np.tile(df, len(scale_diag)), + shift=shift, + scale_tril=scale_tril) + + with self.test_session(): + df_pl = array_ops.placeholder(dtypes.float32, name="df") + shift_pl = array_ops.placeholder(dtypes.float32, name="shift") + scale_diag_pl = array_ops.placeholder(dtypes.float32, name="scale_diag") + feed_dict = {df_pl: df, shift_pl: shift, scale_diag_pl: scale_diag} + actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag, + validate_args=True) + self.assertAllClose(expected_mst.log_prob(x), + actual_mst.log_prob(x).eval(feed_dict=feed_dict), + rtol=0., atol=1e-5) + self.assertAllClose(expected_mst.prob(x), + actual_mst.prob(x).eval(feed_dict=feed_dict), + rtol=0., atol=1e-5) + + def testProbNonScalarBaseDistributionScalarTransform(self): + # Non-scalar batch_shape. + df = np.asarray([1., 2., 3.], dtype=np.float32) + # Scalar batch_shape. + shift = np.asarray([1, 2, 3], dtype=np.float32) + scale_diag = np.asarray([2, 3, 4], dtype=np.float32) + scale_tril = np.diag(scale_diag) + x = 2. * self._rng.rand(4, 3, 3).astype(np.float32) - 1. + + expected_mst = _FakeVectorStudentT( + df=df, + shift=np.tile(shift[None, :], [len(df), 1]), + scale_tril=np.tile(scale_tril[None, :, :], [len(df), 1, 1])) + + with self.test_session(): + actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag, + validate_args=True) + self.assertAllClose(expected_mst.log_prob(x), + actual_mst.log_prob(x).eval(), + rtol=0., atol=1e-5) + self.assertAllClose(expected_mst.prob(x), + actual_mst.prob(x).eval(), + rtol=0., atol=1e-5) + + def testProbNonScalarBaseDistributionScalarTransformDynamic(self): + # Non-scalar batch_shape. + df = np.asarray([1., 2., 3.], dtype=np.float32) + # Scalar batch_shape. + shift = np.asarray([1, 2, 3], dtype=np.float32) + scale_diag = np.asarray([2, 3, 4], dtype=np.float32) + scale_tril = np.diag(scale_diag) + + x = 2. * self._rng.rand(4, 3, 3).astype(np.float32) - 1. + + expected_mst = _FakeVectorStudentT( + df=df, + shift=np.tile(shift[None, :], [len(df), 1]), + scale_tril=np.tile(scale_tril[None, :, :], [len(df), 1, 1])) + + with self.test_session(): + df_pl = array_ops.placeholder(dtypes.float32, name="df") + shift_pl = array_ops.placeholder(dtypes.float32, name="shift") + scale_diag_pl = array_ops.placeholder(dtypes.float32, name="scale_diag") + feed_dict = {df_pl: df, shift_pl: shift, scale_diag_pl: scale_diag} + actual_mst = _VectorStudentT(df=df, shift=shift, scale_diag=scale_diag, + validate_args=True) + self.assertAllClose(expected_mst.log_prob(x), + actual_mst.log_prob(x).eval(feed_dict=feed_dict), + rtol=0., atol=1e-5) + self.assertAllClose(expected_mst.prob(x), + actual_mst.prob(x).eval(feed_dict=feed_dict), + rtol=0., atol=1e-5) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/ops/bijector.py b/tensorflow/contrib/distributions/python/ops/bijector.py index 2d1cc1d2b49..7e92f496773 100644 --- a/tensorflow/contrib/distributions/python/ops/bijector.py +++ b/tensorflow/contrib/distributions/python/ops/bijector.py @@ -53,10 +53,12 @@ import contextlib import itertools import math import re + import numpy as np import six from tensorflow.contrib import framework as contrib_framework +from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import operator_pd_cholesky from tensorflow.contrib.distributions.python.ops import operator_pd_diag from tensorflow.contrib.distributions.python.ops import operator_pd_identity @@ -75,6 +77,7 @@ from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops + __all__ = [ "Affine", "AffineLinearOperator", @@ -92,16 +95,6 @@ __all__ = [ ] -# TODO(jvdillon): deprecate this function once tf.expm1 exists. -def _expm1(x): - """Approximate exp{y}-1~=y for small |y|, and exp{y}-1 elsewhere.""" - # Recall, eps is smallest positive number such that 1 + eps != 1. - eps = np.finfo(x.dtype.base_dtype.as_numpy_dtype).eps - # Note we are careful to never send an NaN through ANY branch of where. - return array_ops.where(math_ops.less(math_ops.abs(x), eps), - x, math_ops.exp(x) - 1.) - - def _as_tensor(x, name): """Convenience to convert to `Tensor` or leave as `None`.""" return None if x is None else ops.convert_to_tensor(x, name=name) @@ -1271,7 +1264,7 @@ class PowerTransform(Bijector): return x, ildj # TODO(jvdillon): If large y accuracy is an issue, consider using # (y**self.power - 1.) / self.power when y >> 1. - x = _expm1(math_ops.log(y) * self.power) / self.power + x = math_ops.expm1(math_ops.log(y) * self.power) / self.power ildj = (self.power - 1.) * math_ops.reduce_sum( math_ops.log(y), reduction_indices=event_dims) @@ -1590,6 +1583,7 @@ class Affine(Bijector): `scale_diag` has shape [N1, N2, ... k, k], which represents a k x k lower triangular matrix. When `None` no `scale_tril` term is added to `scale`. + The upper triangular elements above the diagonal are ignored. scale_perturb_factor: Numeric `Tensor` representing factor matrix with last two dimensions of shape `(k, r)`. When `None`, no rank-r update is added to `scale`. @@ -2086,31 +2080,21 @@ class Softplus(Bijector): return nn_ops.softplus(x) def _inverse_and_inverse_log_det_jacobian(self, y): - # The most stable inverse of softplus is not the most obvious one. - # y = softplus(x) = Log[1 + exp{x}], (which means y > 0). - # ==> exp{y} = 1 + exp{x} (1) - # ==> x = Log[exp{y} - 1] (2) - # = Log[(exp{y} - 1) / exp{y}] + Log[exp{y}] - # = Log[(1 - exp{-y}) / 1] + Log[exp{y}] - # = Log[1 - exp{-y}] + y (3) - # (2) is the "obvious" inverse, but (3) is more stable than (2) for large y. - # For small y (e.g. y = 1e-10), (3) will become -inf since 1 - exp{-y} will - # be zero. To fix this, we use 1 - exp{-y} approx y for small y > 0. - # - # Stable inverse log det jacobian. + if self.shaper is None: + raise ValueError("Jacobian cannot be computed with unknown event_ndims") + _, _, event_dims = self.shaper.get_dims(y) + # Could also do: + # ildj = math_ops.reduce_sum(y - distribution_util.softplus_inverse(y), + # reduction_indices=event_dims) + # but the following is more numerically stable. Ie, # Y = Log[1 + exp{X}] ==> X = Log[exp{Y} - 1] # ==> dX/dY = exp{Y} / (exp{Y} - 1) # = 1 / (1 - exp{-Y}), # which is the most stable for large Y > 0. For small Y, we use # 1 - exp{-Y} approx Y. - if self.shaper is None: - raise ValueError("Jacobian cannot be computed with unknown event_ndims") - _, _, event_dims = self.shaper.get_dims(y) - log_one_minus_exp_neg = math_ops.log(-_expm1(-y)) - x = y + log_one_minus_exp_neg - ildj = -math_ops.reduce_sum( - log_one_minus_exp_neg, reduction_indices=event_dims) - return x, ildj + ildj = -math_ops.reduce_sum(math_ops.log(-math_ops.expm1(-y)), + reduction_indices=event_dims) + return distribution_util.softplus_inverse(y), ildj def _forward_log_det_jacobian(self, x): # pylint: disable=unused-argument if self.shaper is None: diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py index 71e42bc2145..745e327e270 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution_util.py +++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py @@ -564,6 +564,63 @@ def fill_lower_triangular(x, validate_args=False, name="fill_lower_triangular"): return y +# TODO(jvdillon): Merge this test back into: +# tensorflow/python/ops/softplus_op_test.py +# once TF core is accepting new ops. +def softplus_inverse(x, name=None): + """Computes the inverse softplus, i.e., x = softplus_inverse(softplus(x)). + + Mathematically this op is equivalent to: + + ```none + softplus_inverse = log(exp(x) - 1.) + ``` + + Args: + x: `Tensor`. Non-negative (not enforced), floating-point. + name: A name for the operation (optional). + + Returns: + `Tensor`. Has the same type/shape as input `x`. + """ + with ops.name_scope(name, "softplus_inverse", values=[x]): + x = ops.convert_to_tensor(x, name="x") + # We begin by deriving a more numerically stable softplus_inverse: + # x = softplus(y) = Log[1 + exp{y}], (which means x > 0). + # ==> exp{x} = 1 + exp{y} (1) + # ==> y = Log[exp{x} - 1] (2) + # = Log[(exp{x} - 1) / exp{x}] + Log[exp{x}] + # = Log[(1 - exp{-x}) / 1] + Log[exp{x}] + # = Log[1 - exp{-x}] + x (3) + # (2) is the "obvious" inverse, but (3) is more stable than (2) for large x. + # For small x (e.g. x = 1e-10), (3) will become -inf since 1 - exp{-x} will + # be zero. To fix this, we use 1 - exp{-x} approx x for small x > 0. + # + # In addition to the numerically stable derivation above, we clamp + # small/large values to be congruent with the logic in: + # tensorflow/core/kernels/softplus_op.h + # + # Finally, we set the input to one whenever the input is too large or too + # small. This ensures that no unchosen codepath is +/- inf. This is + # necessary to ensure the gradient doesn't get NaNs. Recall that the + # gradient of `where` behaves like `pred*pred_true + (1-pred)*pred_false` + # thus an `inf` in an unselected path results in `0*inf=nan`. We are careful + # to overwrite `x` with ones only when we will never actually use this + # value. Note that we use ones and not zeros since `log(expm1(0.)) = -inf`. + threshold = np.log(np.finfo(x.dtype.as_numpy_dtype).eps) + 2. + is_too_small = math_ops.less(x, np.exp(threshold)) + is_too_large = math_ops.greater(x, -threshold) + too_small_value = math_ops.log(x) + too_large_value = x + # This `where` will ultimately be a NOP because we won't select this + # codepath whenever we used the surrogate `ones_like`. + x = array_ops.where(math_ops.logical_or(is_too_small, is_too_large), + array_ops.ones_like(x), x) + y = x + math_ops.log(-math_ops.expm1(-x)) # == log(expm1(x)) + return array_ops.where(is_too_small, too_small_value, + array_ops.where(is_too_large, too_large_value, y)) + + class AppendDocstring(object): """Helper class to promote private subclass docstring to public counterpart. diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py new file mode 100644 index 00000000000..af166800634 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py @@ -0,0 +1,295 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Vector Student's t distribution classes.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distributions.python.ops import bijector as bijectors +from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.contrib.distributions.python.ops import student_t +from tensorflow.contrib.distributions.python.ops import transformed_distribution +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops + + +# TODO(jvdillon): Add unittests for this once we know where will put this code +# and how it will generally be used. In the interim this code is tested via the +# _VectorStudentT tests. +def _infer_shapes(scale_oppd, shift): + """Helper which returns batch_shape, event_shape from `Affine` properties. + + The `Affine` `Bijector` (roughly) computes `Y = scale @ X.T + shift`. This + function infers the `batch_shape` and `event_shape` from the `scale` and + `shift` terms. + + Args: + scale_oppd: Instance of OperatorPDBase subclass representing the `Affine` + `Bijector` scale matrix. + shift: `Tensor` representing the `shift` vector. + + Returns: + batch_shape: 1D, integer `Tensor` representing the shape of batch + dimensions. + event_shape: 1D, integer `Tensor` representing the shape of event + dimensions. + + Raises: + ValueError: if we are not able to infer batch/event shapes from the args. + """ + # Collect known static shape. + def _has_static_ndims(x): + return x is not None and x.get_shape().ndims is not None + if _has_static_ndims(scale_oppd) and _has_static_ndims(shift): + batch_shape = scale_oppd.get_batch_shape().merge_with( + shift.get_shape()[:-1]) + event_shape = scale_oppd.get_shape()[-1:].merge_with( + shift.get_shape()[-1:]) + elif _has_static_ndims(scale_oppd): + batch_shape = scale_oppd.get_batch_shape() + event_shape = scale_oppd.get_shape()[-1:] + elif _has_static_ndims(shift): + batch_shape = shift.get_shape()[:-1] + event_shape = shift.get_shape()[-1:] + else: + batch_shape = tensor_shape.TensorShape(None) + event_shape = tensor_shape.TensorShape(None) + + # Convert TensorShape to Tensors and see if we're done. + if batch_shape.is_fully_defined(): + batch_shape = constant_op.constant(batch_shape.as_list(), + dtype=dtypes.int32) + else: + batch_shape = None + if event_shape.is_fully_defined(): + event_shape = constant_op.constant(event_shape.as_list(), + dtype=dtypes.int32) + else: + event_shape = None + if batch_shape is not None and event_shape is not None: + return batch_shape, event_shape + + # Collect known dynamic shape. + if scale_oppd is not None: + shape = scale_oppd.shape() + elif shift is not None: + shape = array_ops.shape(shift) + else: + raise ValueError("unable to infer batch_shape, event_shape") + + # Fill in what we don't know. + if batch_shape is None: + batch_shape = array_ops.identity(shape[:-1], name="batch_shape") + if event_shape is None: + event_shape = array_ops.identity(shape[-1:], name="event_shape") + + return batch_shape, event_shape + + +class _VectorStudentT(transformed_distribution.TransformedDistribution): + """A vector version of Student's t-distribution on `R^k`. + + #### Mathematical details + + Write `S` for the scale matrix (in R^{k x k}) and `mu` for the mean (in R^k). + The PDF of this distribution is: + + ```none + f(x) = (1 + y y.T / df)**(-0.5 (df + 1)) / Z + where, + y(x) = inv(S) (x - mu) + Z = abs(det(S)) ( sqrt(df pi) Gamma(0.5 df) / Gamma(0.5 (df + 1)) )**k + ``` + + Notice that the matrix `S` has semantics more similar to standard deviation + than covariance. + + This distribution is an Affine transformation of iid + [Student's t-distributions]( + https://en.wikipedia.org/wiki/Student%27s_t-distribution) + and should not be confused with the [Multivate Student's t-distribution]( + https://en.wikipedia.org/wiki/Multivariate_t-distribution). The + traditional Multivariate Student's t-distribution is type of + [elliptical distribution]( + https://en.wikipedia.org/wiki/Elliptical_distribution); it has PDF: + + ```none + f(x) = (1 + y y.T / df)**(-0.5 (df + k)) / Z + where, + y(x) = inv(S) (x - mu) + Z = abs(det(S)) sqrt(df pi)**k Gamma(0.5 df) / Gamma(0.5 (df + k)) + ``` + + Notice that the Multivariate Student's t-distribution uses `k` where the + Vector Student's t-distribution has a `1`. Conversely the Vector version has a + broader application of the power-`k` in the normalization. + + #### Examples + + A single instance of a "Vector Student's t-distribution" is defined by a mean + vector of of length `k` and a scale matrix of shape `k x k`. + + Extra leading dimensions, if provided, allow for batches. + + ```python + ds = tf.contrib.distributions + + # Initialize a single 3-variate vector Student's t-distribution. + mu = [1., 2, 3] + chol = [[1., 0, 0.], + [1, 3, 0], + [1, 2, 3]] + vt = ds.VectorStudentT(df=2, shift=mu, scale_tril=chol) + + # Evaluate this on an observation in R^3, returning a scalar. + vt.prob([-1., 0, 1]) + + # Initialize a batch of two 3-variate vector Student's t-distributions. + mu = [[1., 2, 3], + [11, 22, 33]] + chol = ... # shape 2 x 3 x 3, lower triangular, positive diagonal. + vt = ds.VectorStudentT(shift=mu, scale_tril=chol) + + # Evaluate this on a two observations, each in R^3, returning a length two + # tensor. + x = [[-1, 0, 1], + [-11, 0, 11]] + vt.prob(x) + ``` + + For more examples of how to construct the `scale` matrix, see the + `bijector.Affine` docstring. + + """ + + def __init__(self, + df, + shift=None, + scale_identity_multiplier=None, + scale_diag=None, + scale_tril=None, + scale_perturb_factor=None, + scale_perturb_diag=None, + validate_args=False, + allow_nan_stats=True, + name="VectorStudentT"): + """Instantiates the vector Student's t-distributions on `R^k`. + + The `batch_shape` is the broadcast between `df.batch_shape` and + `Affine.batch_shape` where `Affine` is constructed from `shift` and + `scale_*` arguments. + + The `event_shape` is the event shape of `Affine.event_shape`. + + Args: + df: Numeric `Tensor`. The degrees of freedom of the distribution(s). + `df` must contain only positive values. + Must be scalar if `shift`, `scale_*` imply non-scalar batch_shape or + must have the same `batch_shape` implied by `shift`, `scale_*`. + shift: Numeric `Tensor`. If this is set to `None`, no `shift` is applied. + scale_identity_multiplier: floating point rank 0 `Tensor` representing a + scaling done to the identity matrix. + When `scale_identity_multiplier = scale_diag=scale_tril = None` then + `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is added + to `scale`. + scale_diag: Numeric `Tensor` representing the diagonal matrix. + `scale_diag` has shape [N1, N2, ... k], which represents a k x k + diagonal matrix. + When `None` no diagonal term is added to `scale`. + scale_tril: Numeric `Tensor` representing the diagonal matrix. + `scale_diag` has shape [N1, N2, ... k, k], which represents a k x k + lower triangular matrix. + When `None` no `scale_tril` term is added to `scale`. + The upper triangular elements above the diagonal are ignored. + scale_perturb_factor: Numeric `Tensor` representing factor matrix with + last two dimensions of shape `(k, r)`. + When `None`, no rank-r update is added to `scale`. + scale_perturb_diag: Numeric `Tensor` representing the diagonal matrix. + `scale_perturb_diag` has shape [N1, N2, ... r], which represents an + r x r Diagonal matrix. + When `None` low rank updates will take the form `scale_perturb_factor * + scale_perturb_factor.T`. + validate_args: `Boolean`, default `False`. Whether to validate input + with asserts. If `validate_args` is `False`, and the inputs are + invalid, correct behavior is not guaranteed. + allow_nan_stats: `Boolean`, default `True`. If `False`, raise an + exception if a statistic (e.g. mean/mode/etc...) is undefined for any + batch member If `True`, batch members with valid parameters leading to + undefined statistics will return NaN for this statistic. + name: The name to give Ops created by the initializer. + """ + parameters = locals() + parameters.pop("self") + graph_parents = [df, shift, scale_identity_multiplier, scale_diag, + scale_tril, scale_perturb_factor, scale_perturb_diag] + with ops.name_scope(name) as ns: + with ops.name_scope("init", values=graph_parents): + # The shape of the _VectorStudentT distribution is governed by the + # relationship between df.batch_shape and affine.batch_shape. In + # pseudocode the basic procedure is: + # if df.batch_shape is scalar: + # if affine.batch_shape is not scalar: + # # broadcast self._distribution.sample so + # # it has affine.batch_shape. + # self.batch_shape = affine.batch_shape + # else: + # if affine.batch_shape is scalar: + # # let affine broadcasting do its thing. + # self.batch_shape = df.batch_shape + # All of the above magic is actually handled by TransformedDistribution. + # Here we really only need to collect the affine.batch_shape and decide + # what we're going to pass in to TransformedDistribution's + # (override) batch_shape arg. + self._distribution = student_t.StudentT(df=df, mu=0., sigma=1.) + self._affine = bijectors.Affine( + shift=shift, + scale_identity_multiplier=scale_identity_multiplier, + scale_diag=scale_diag, + scale_tril=scale_tril, + scale_perturb_factor=scale_perturb_factor, + scale_perturb_diag=scale_perturb_diag, + validate_args=validate_args) + self._batch_shape, self._override_event_shape = _infer_shapes( + self.scale, self.shift) + self._override_batch_shape = distribution_util.pick_vector( + self._distribution.is_scalar_batch(), + self._batch_shape, + constant_op.constant([], dtype=dtypes.int32)) + super(_VectorStudentT, self).__init__( + distribution=self._distribution, + bijector=self._affine, + batch_shape=self._override_batch_shape, + event_shape=self._override_event_shape, + validate_args=validate_args, + name=ns) + + @property + def df(self): + """Degrees of freedom in these Student's t distribution(s).""" + return self._distribution.df + + @property + def shift(self): + """Locations of these Student's t distribution(s).""" + return self._affine.shift + + @property + def scale(self): + """Dense (batch) covariance matrix, if available.""" + return self._affine.scale diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py index d4ce003cdc3..f3134c0d699 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py @@ -229,7 +229,7 @@ class TransformerTest(test.TestCase): self.assertEqual(len(output), 1) self.assertIn(keys_sparse, output) with self.test_session(): - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertEqual(output[keys_sparse].values.dtype, dtypes.int64) self.assertAllEqual(output[keys_sparse].values.eval(), [1, 2, 0]) self.assertAllEqual(output[keys_sparse].indices.eval(), @@ -247,7 +247,7 @@ class TransformerTest(test.TestCase): output = feature_column_ops._Transformer(features).transform(keys_sparse) with self.test_session(): - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() # While the input is a dense Tensor, the output should be a SparseTensor. self.assertIsInstance(output, sparse_tensor.SparseTensor) self.assertEqual(output.dtype, dtypes.int64) @@ -316,7 +316,7 @@ class TransformerTest(test.TestCase): self.assertIn(weighted_ids, output) with self.test_session(): - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual(output[weighted_ids][0].dense_shape.eval(), ids_tensor.dense_shape.eval()) self.assertAllEqual(output[weighted_ids][0].indices.eval(), @@ -346,7 +346,7 @@ class TransformerTest(test.TestCase): self.assertEqual(len(output), 1) self.assertIn(vocab_sparse, output) with self.test_session(): - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64) self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0]) self.assertAllEqual(output[vocab_sparse].indices.eval(), @@ -368,7 +368,7 @@ class TransformerTest(test.TestCase): self.assertEqual(len(output), 1) self.assertIn(vocab_sparse, output) with self.test_session(): - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64) self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1]) self.assertAllEqual(output[vocab_sparse].indices.eval(), @@ -392,7 +392,7 @@ class TransformerTest(test.TestCase): self.assertEqual(len(output), 1) self.assertIn(vocab_sparse, output) with self.test_session(): - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64) self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0]) self.assertAllEqual(output[vocab_sparse].indices.eval(), @@ -414,7 +414,7 @@ class TransformerTest(test.TestCase): self.assertEqual(len(output), 1) self.assertIn(vocab_sparse, output) with self.test_session(): - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64) self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1]) self.assertAllEqual(output[vocab_sparse].indices.eval(), @@ -584,7 +584,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): ]) with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual(output.eval().shape, [3, 3 + 4 + 10]) def testRealValuedColumn(self): @@ -681,7 +681,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): [one_hot_column]) with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual([[0, 0, 10., 0], [0, 20., 0, 0], [30., 0, 40., 0]], output.eval()) @@ -699,7 +699,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]], output.eval()) @@ -717,7 +717,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]], output.eval()) @@ -751,7 +751,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): [one_hot_sparse]) with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual([3, 10], output.eval().shape) def testEmbeddingColumnSucceedsForDNN(self): @@ -857,7 +857,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): [embeded_sparse]) with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual(output.eval().shape, [2, 10]) def testEmbeddingColumnWithCrossedColumnSucceedsForDNN(self): @@ -908,7 +908,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): with self.assertRaisesRegexp( ValueError, "Error creating input layer for column: ids_weighted_by_weights"): - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() feature_column_ops.input_from_feature_columns(features, [weighted_ids]) def testCrossedColumnFailsForDNN(self): @@ -1015,7 +1015,7 @@ class CreateInputLayersForDNNsTest(test.TestCase): [embeded_sparse]) with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() # score: (sum of weights) self.assertAllEqual(output.eval(), [[10.], [50.], [0.]]) @@ -1208,7 +1208,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() model_input = sess.run(model_input_tensor) expected_input_shape = np.array([4, 3, 4]) @@ -1242,7 +1242,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() model_input = sess.run(model_input_tensor) expected_input_shape = np.array([4, 3, hash_buckets]) @@ -1272,7 +1272,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() model_input = sess.run(model_input_tensor) self.assertAllEqual(expected_input_shape, model_input.shape) @@ -1302,7 +1302,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): embedding_weights) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() model_input, gradients = sess.run([model_input_tensor, gradient_tensor]) expected_input_shape = [4, 3, embedding_dimension] @@ -1369,7 +1369,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase): with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() model_input = sess.run(model_input_tensor) expected_input_shape = [ @@ -1437,7 +1437,7 @@ class WeightedSumTest(test.TestCase): features, [weighted_ids], num_outputs=5) with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual(logits.eval().shape, [2, 5]) def testWeightedSparseColumnWithDenseInputTensor(self): @@ -1453,7 +1453,7 @@ class WeightedSumTest(test.TestCase): with self.test_session(): variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual(logits.eval().shape, [2, 5]) def testCrossedColumn(self): @@ -1507,7 +1507,7 @@ class WeightedSumTest(test.TestCase): features, [movies], num_outputs=1)) with self.test_session() as sess: variables_lib.initialize_all_variables().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() weights = column_to_variable[movies][0] self.assertEqual(weights.get_shape(), (3, 1)) @@ -1582,7 +1582,7 @@ class WeightedSumTest(test.TestCase): features, [age, language], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllClose(output.eval(), [[0.], [0.]]) @@ -1622,7 +1622,7 @@ class WeightedSumTest(test.TestCase): self.assertEqual(len(variables), 1) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllClose(output.eval(), [[0.], [0.]]) @@ -1686,7 +1686,7 @@ class WeightedSumTest(test.TestCase): features, [weighted_language], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllClose(output.eval(), [[0.], [0.]]) @@ -1714,7 +1714,7 @@ class WeightedSumTest(test.TestCase): features, [language], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() # score: 0.1 + language_weight['hindi'] + language_weight['english'] sess.run(bias.assign([0.1])) @@ -1737,7 +1737,7 @@ class WeightedSumTest(test.TestCase): features, [movies], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() weights = column_to_variable[movies][0] self.assertEqual(weights.get_shape(), (15, 1)) @@ -1771,7 +1771,7 @@ class WeightedSumTest(test.TestCase): features, [country_language], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() weights = column_to_variable[country_language][0] sess.run(weights.assign(weights + 0.4)) @@ -1795,7 +1795,7 @@ class WeightedSumTest(test.TestCase): features, [language_language], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() weights = column_to_variable[language_language][0] sess.run(weights.assign(weights + 0.4)) @@ -1828,7 +1828,7 @@ class WeightedSumTest(test.TestCase): features, [country_language], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() weights = column_to_variable[country_language][0] sess.run(weights.assign(weights + 0.4)) @@ -1869,7 +1869,7 @@ class WeightedSumTest(test.TestCase): scope=scope)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertEqual(2, len(column_to_variable[country])) self.assertEqual(3, len(column_to_variable[language])) @@ -1906,7 +1906,7 @@ class WeightedSumTest(test.TestCase): features, [country, age, incomes], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() incomes_weights = column_to_variable[incomes][0] sess.run(incomes_weights.assign([[0.1], [0.2], [0.3]])) @@ -1943,7 +1943,7 @@ class WeightedSumTest(test.TestCase): features, [country, age, height, incomes], num_outputs=5)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() height_weights = column_to_variable[height][0] sess.run( @@ -1973,7 +1973,7 @@ class WeightedSumTest(test.TestCase): features, [bucket], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() sess.run(column_to_variable[bucket][0].assign([[0.1], [0.2], [0.3], [0.4]])) @@ -2001,7 +2001,7 @@ class WeightedSumTest(test.TestCase): features, [bucket, country], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() # dimension = 2, bucket_size = 4, num_classes = 1 sess.run(column_to_variable[bucket][0].assign( @@ -2030,7 +2030,7 @@ class WeightedSumTest(test.TestCase): features, [bucket, country], num_outputs=5)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() # dimension = 2, bucket_size = 4, num_classes = 5 sess.run(column_to_variable[bucket][0].assign( @@ -2066,7 +2066,7 @@ class WeightedSumTest(test.TestCase): features, [country_price], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() weights = column_to_variable[country_price][0] sess.run(weights.assign(weights + 0.4)) @@ -2105,7 +2105,7 @@ class WeightedSumTest(test.TestCase): features, [country_language_price], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() weights = column_to_variable[country_language_price][0] sess.run(weights.assign(weights + 0.4)) @@ -2129,7 +2129,7 @@ class WeightedSumTest(test.TestCase): features, [product], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() product_weights = column_to_variable[product][0] sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]])) self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]]) @@ -2144,7 +2144,7 @@ class WeightedSumTest(test.TestCase): features, [product], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() product_weights = column_to_variable[product][0] sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]])) self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]]) @@ -2159,7 +2159,7 @@ class WeightedSumTest(test.TestCase): features, [product], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() product_weights = column_to_variable[product][0] sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]])) self.assertAllClose(output.eval(), [[0.6], [0.7]]) @@ -2180,7 +2180,7 @@ class WeightedSumTest(test.TestCase): features, [product], num_outputs=1)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() product_weights = column_to_variable[product][0] sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]])) self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]]) @@ -2192,7 +2192,7 @@ class WeightedSumTest(test.TestCase): features, [feature_column.real_valued_column("age")], num_outputs=3) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() sess.run(bias.assign([0.1, 0.2, 0.3])) self.assertAllClose(output.eval(), [[0.1, 0.2, 0.3], [0.1, 0.2, 0.3], [0.1, 0.2, 0.3], [0.1, 0.2, 0.3]]) @@ -2206,7 +2206,7 @@ class WeightedSumTest(test.TestCase): features, [column], num_outputs=3)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() weights = column_to_variable[column][0] self.assertEqual(weights.get_shape(), (1, 3)) sess.run(weights.assign([[0.01, 0.03, 0.05]])) @@ -2230,7 +2230,7 @@ class WeightedSumTest(test.TestCase): features, [column], num_outputs=3)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() weights = column_to_variable[column][0] self.assertEqual(weights.get_shape(), (5, 3)) sess.run( @@ -2256,7 +2256,7 @@ class WeightedSumTest(test.TestCase): features, [column], num_outputs=3)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() weights = column_to_variable[column][0] self.assertEqual(weights.get_shape(), (5, 3)) @@ -2296,7 +2296,7 @@ class WeightedSumTest(test.TestCase): features, [column], num_outputs=3)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() weights = column_to_variable[column][0] self.assertEqual(weights.get_shape(), (5, 3)) @@ -2325,7 +2325,7 @@ class WeightedSumTest(test.TestCase): features, [column], num_outputs=3)) with self.test_session() as sess: variables_lib.global_variables_initializer().run() - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() weights = column_to_variable[column][0] self.assertEqual(weights.get_shape(), (5, 3)) @@ -2390,7 +2390,7 @@ class ParseExampleTest(test.TestCase): self.assertIn(bucket, output) self.assertIn(wire_cast, output) with self.test_session(): - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual(output[bucket].eval(), [[2, 3, 0]]) self.assertAllEqual(output[wire_cast].indices.eval(), [[0, 0], [0, 1]]) self.assertAllEqual(output[wire_cast].values.eval(), [2, 0]) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py index 197806606fa..955b57e893f 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py @@ -160,7 +160,7 @@ class DynamicRnnEstimatorTest(test.TestCase): self.context_feature_columns) with self.test_session() as sess: sess.run(variables.global_variables_initializer()) - sess.run(data_flow_ops.initialize_all_tables()) + sess.run(data_flow_ops.tables_initializer()) sequence_input_val = sess.run(sequence_input) expected_shape = np.array([ 3, # expected batch size @@ -181,7 +181,7 @@ class DynamicRnnEstimatorTest(test.TestCase): # Obtain values of activations and final state. with session.Session() as sess: sess.run(variables.global_variables_initializer()) - sess.run(data_flow_ops.initialize_all_tables()) + sess.run(data_flow_ops.tables_initializer()) activations, final_state = sess.run([activations_t, final_state_t]) expected_activations_shape = np.array([3, 2, self.NUM_LABEL_COLUMNS]) diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index e3dc27e6460..1d363897228 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -1283,7 +1283,7 @@ class Estimator(BaseEstimator): with tf_session.Session('') as session: variables.initialize_local_variables() - data_flow_ops.initialize_all_tables() + data_flow_ops.tables_initializer() saver_for_restore = saver.Saver( variables.global_variables(), sharded=True) @@ -1291,7 +1291,7 @@ class Estimator(BaseEstimator): init_op = control_flow_ops.group( variables.local_variables_initializer(), - data_flow_ops.initialize_all_tables()) + data_flow_ops.tables_initializer()) # Perform the export builder = saved_model_builder.SavedModelBuilder(export_dir) diff --git a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py index 953a26d3f95..956b04ea029 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py +++ b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py @@ -22,8 +22,6 @@ import numpy as np from tensorflow.contrib.factorization.python.ops import clustering_ops from tensorflow.contrib.framework.python.ops import variables -from tensorflow.contrib.learn.python.learn import evaluable -from tensorflow.contrib.learn.python.learn import trainable from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators.model_fn import ModelFnOps from tensorflow.python.framework import ops @@ -42,7 +40,7 @@ KMEANS_PLUS_PLUS_INIT = clustering_ops.KMEANS_PLUS_PLUS_INIT # TODO(agarwal,ands): support sharded input. -class KMeansClustering(evaluable.Evaluable, trainable.Trainable): +class KMeansClustering(estimator.Estimator): """An Estimator fo rK-Means clustering.""" SCORES = 'scores' CLUSTER_IDX = 'cluster_idx' @@ -58,6 +56,7 @@ class KMeansClustering(evaluable.Evaluable, trainable.Trainable): random_seed=0, use_mini_batch=True, kmeans_plus_plus_num_retries=2, + relative_tolerance=None, config=None): """Creates a model for running KMeans training and inference. @@ -76,6 +75,9 @@ class KMeansClustering(evaluable.Evaluable, trainable.Trainable): additional points to draw from the current distribution before selecting the best. If a negative value is specified, a heuristic is used to sample O(log(num_to_sample)) additional points. + relative_tolerance: A relative tolerance of change in the loss between + iterations. Stops learning if the loss changes less than this amount. + Note that this may not work correctly if use_mini_batch=True. config: See Estimator """ self._num_clusters = num_clusters @@ -84,7 +86,8 @@ class KMeansClustering(evaluable.Evaluable, trainable.Trainable): self._random_seed = random_seed self._use_mini_batch = use_mini_batch self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries - self._estimator = estimator.Estimator( + self._relative_tolerance = relative_tolerance + super(KMeansClustering, self).__init__( model_fn=self._get_model_function(), model_dir=model_dir) class LossRelativeChangeHook(session_run_hook.SessionRunHook): @@ -119,76 +122,13 @@ class KMeansClustering(evaluable.Evaluable, trainable.Trainable): run_context.request_stop() self._prev_loss = loss - @property - def model_dir(self): - """See Evaluable.""" - return self._estimator.model_dir - - def fit(self, - input_fn=None, - steps=None, - monitors=None, - max_steps=None, - relative_tolerance=None): - """Trains a k-means clustering on x. - - Note: See Estimator for logic for continuous training and graph - construction across multiple calls to fit. - - Args: - input_fn: see Trainable.fit. - steps: see Trainable.fit. - monitors: see Trainable.fit. - max_steps: see Trainable.fit. - relative_tolerance: A relative tolerance of change in the loss between - iterations. Stops learning if the loss changes less than this amount. - Note that this may not work correctly if use_mini_batch=True. - - Returns: - Returns self. - """ - if relative_tolerance is not None: - if monitors is None: - monitors = [] - monitors.append(self.LossRelativeChangeHook(relative_tolerance)) - # Make sure that we will eventually terminate. - assert ((monitors is not None and len(monitors)) or (steps is not None) or - (max_steps is not None)) - self._estimator.fit(input_fn=input_fn, - steps=steps, - max_steps=max_steps, - monitors=monitors) - return self - - def evaluate(self, - input_fn=None, - feed_fn=None, - steps=None, - metrics=None, - name=None, - checkpoint_path=None, - hooks=None): - """See Evaluable.evaluate.""" - return self._estimator.evaluate( - input_fn=input_fn, - feed_fn=feed_fn, - steps=steps, - metrics=metrics, - name=name, - checkpoint_path=checkpoint_path, - hooks=hooks) - - def predict(self, input_fn=None, outputs=None, as_iterable=False): - """See BaseEstimator.predict.""" - - outputs = outputs or [KMeansClustering.CLUSTER_IDX] - assert isinstance(outputs, list) - results = self._estimator.predict( - input_fn=input_fn, outputs=outputs, as_iterable=as_iterable) - if len(outputs) == 1 and not as_iterable: - return results[outputs[0]] - else: - return results + def predict_cluster_idx(self, input_fn=None): + """Yields predicted cluster indices.""" + key = KMeansClustering.CLUSTER_IDX + results = super(KMeansClustering, self).predict( + input_fn=input_fn, outputs=[key]) + for result in results: + yield result[key] def score(self, input_fn=None, steps=None): """Predict total sum of distances to nearest clusters. @@ -223,14 +163,19 @@ class KMeansClustering(evaluable.Evaluable, trainable.Trainable): Array with same number of rows as x, and num_clusters columns, containing distances to the cluster centers. """ - return self.predict( + key = KMeansClustering.ALL_SCORES + results = super(KMeansClustering, self).predict( input_fn=input_fn, - outputs=[KMeansClustering.ALL_SCORES], + outputs=[key], as_iterable=as_iterable) + if not as_iterable: + return results[key] + else: + return results def clusters(self): """Returns cluster centers.""" - return self._estimator.get_variable_value(self.CLUSTERS) + return super(KMeansClustering, self).get_variable_value(self.CLUSTERS) def _parse_tensor_or_dict(self, features): if isinstance(features, dict): @@ -264,11 +209,16 @@ class KMeansClustering(evaluable.Evaluable, trainable.Trainable): KMeansClustering.CLUSTER_IDX: model_predictions[0], } eval_metric_ops = {KMeansClustering.SCORES: loss,} + if self._relative_tolerance is not None: + training_hooks = [self.LossRelativeChangeHook(self._relative_tolerance)] + else: + training_hooks = None return ModelFnOps( mode=mode, predictions=predictions, eval_metric_ops=eval_metric_ops, loss=loss, - train_op=training_op) + train_op=training_op, + training_hooks=training_hooks) return _model_fn diff --git a/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py b/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py index fb7c21c13a9..8364a57f326 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py @@ -40,6 +40,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.platform import benchmark from tensorflow.python.platform import flags from tensorflow.python.platform import test +from tensorflow.python.training import input as input_lib FLAGS = flags.FLAGS @@ -68,7 +69,7 @@ def make_random_points(centers, num_points, max_offset=20): class KMeansTestBase(test.TestCase): - def input_fn(self, batch_size=None, points=None): + def input_fn(self, batch_size=None, points=None, num_epochs=None): """Returns an input_fn that randomly selects batches from given points.""" batch_size = batch_size or self.batch_size points = points if points is not None else self.points @@ -77,14 +78,15 @@ class KMeansTestBase(test.TestCase): def _fn(): x = constant_op.constant(points) if batch_size == num_points: - return x, None + return input_lib.limit_epochs(x, num_epochs=num_epochs), None indices = random_ops.random_uniform( constant_op.constant([batch_size]), minval=0, maxval=num_points - 1, dtype=dtypes.int32, seed=10) - return array_ops.gather(x, indices), None + batch = array_ops.gather(x, indices) + return (input_lib.limit_epochs(batch, num_epochs=num_epochs), None) return _fn @@ -113,21 +115,23 @@ class KMeansTest(KMeansTestBase): self.num_points) self.true_score = np.add.reduce(self.scores) - self.kmeans = kmeans_lib.KMeansClustering( + def _kmeans(self, relative_tolerance=None): + return kmeans_lib.KMeansClustering( self.num_centers, initial_clusters=factorization.RANDOM_INIT, use_mini_batch=self.use_mini_batch, config=self.config(14), - random_seed=10) + random_seed=10, + relative_tolerance=relative_tolerance) def test_clusters(self): - kmeans = self.kmeans + kmeans = self._kmeans() kmeans.fit(input_fn=self.input_fn(), steps=1) clusters = kmeans.clusters() self.assertAllEqual(list(clusters.shape), [self.num_centers, self.num_dims]) def test_fit(self): - kmeans = self.kmeans + kmeans = self._kmeans() kmeans.fit(input_fn=self.input_fn(), steps=1) score1 = kmeans.score( input_fn=self.input_fn(batch_size=self.num_points), steps=1) @@ -146,20 +150,20 @@ class KMeansTest(KMeansTestBase): initial_clusters=factorization.RANDOM_INIT, use_mini_batch=self.use_mini_batch, config=run_config.RunConfig(tf_random_seed=14), - random_seed=12) + random_seed=12, + relative_tolerance=1e-4) kmeans.fit( input_fn=self.input_fn(), - # Force it to train forever until the monitor stops it. - steps=None, - relative_tolerance=1e-4) + # Force it to train until the relative tolerance monitor stops it. + steps=None) score = kmeans.score( input_fn=self.input_fn(batch_size=self.num_points), steps=1) self.assertNear(self.true_score, score, self.true_score * 0.005) def test_infer(self): - kmeans = self.kmeans - kmeans.fit(input_fn=self.input_fn(), relative_tolerance=1e-4) + kmeans = self._kmeans(relative_tolerance=1e-4) + kmeans.fit(input_fn=self.input_fn()) clusters = kmeans.clusters() # Make a small test set @@ -167,8 +171,8 @@ class KMeansTest(KMeansTestBase): points, true_assignments, true_offsets = make_random_points(clusters, num_points) # Test predict - assignments = kmeans.predict(input_fn=self.input_fn( - batch_size=num_points, points=points)) + assignments = list(kmeans.predict_cluster_idx(input_fn=self.input_fn( + batch_size=num_points, points=points, num_epochs=1))) self.assertAllEqual(assignments, true_assignments) # Test score @@ -260,7 +264,8 @@ class KMeansTestCosineDistance(KMeansTestBase): centers, axis=0), np.sort( self.true_centers, axis=0), atol=1e-2) - assignments = self.kmeans.predict(input_fn=self.input_fn()) + assignments = list(self.kmeans.predict_cluster_idx( + input_fn=self.input_fn(num_epochs=1))) self.assertAllClose( centers[assignments], self.true_centers[self.true_assignments], @@ -305,8 +310,11 @@ class KMeansTestCosineDistance(KMeansTestBase): self.assertAllClose( sorted(centers.tolist()), sorted(true_centers.tolist()), atol=1e-2) - assignments = kmeans.predict( - input_fn=lambda: (constant_op.constant(points), None)) + def _input_fn(): + return ( + input_lib.limit_epochs(constant_op.constant(points), num_epochs=1), + None) + assignments = list(kmeans.predict_cluster_idx(input_fn=_input_fn)) self.assertAllClose( centers[assignments], true_centers[true_assignments], atol=1e-2) diff --git a/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py b/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py index 8a6d0ef0183..7dc26781f94 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py +++ b/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py @@ -25,3 +25,4 @@ class PredictionKey(object): LOGISTIC = "logistic" SCORES = "scores" TOP_K = "top_k" + GENERIC = "output" diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py index 82984d87ed2..45a2dc18469 100644 --- a/tensorflow/contrib/learn/python/learn/graph_actions.py +++ b/tensorflow/contrib/learn/python/learn/graph_actions.py @@ -634,7 +634,7 @@ def _get_local_init_op(): ops.GraphKeys.LOCAL_INIT_OP) if local_init_op is None: op_list = [variables.local_variables_initializer(), - data_flow_ops.initialize_all_tables()] + data_flow_ops.tables_initializer()] if op_list: local_init_op = control_flow_ops.group(*op_list) ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op) @@ -881,7 +881,7 @@ def run_feeds_iter(output_dict, feed_dicts, restore_checkpoint_path=None): else: session.run(variables.global_variables_initializer()) session.run(variables.local_variables_initializer()) - session.run(data_flow_ops.initialize_all_tables()) + session.run(data_flow_ops.tables_initializer()) coord = coordinator.Coordinator() threads = None try: diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py index 7ab6aafdf39..d90b3d3f850 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py @@ -156,12 +156,12 @@ def read_keyed_batch_examples(file_pattern, file_pattern, batch_size, reader, - randomize_input, - num_epochs, - queue_capacity, - num_threads, - read_batch_size, - parse_fn, + randomize_input=randomize_input, + num_epochs=num_epochs, + queue_capacity=queue_capacity, + num_threads=num_threads, + read_batch_size=read_batch_size, + parse_fn=parse_fn, setup_shared_queue=False, name=name) @@ -225,12 +225,12 @@ def _read_keyed_batch_examples_shared_queue(file_pattern, file_pattern, batch_size, reader, - randomize_input, - num_epochs, - queue_capacity, - num_threads, - read_batch_size, - parse_fn, + randomize_input=randomize_input, + num_epochs=num_epochs, + queue_capacity=queue_capacity, + num_threads=num_threads, + read_batch_size=read_batch_size, + parse_fn=parse_fn, setup_shared_queue=True, name=name) @@ -265,7 +265,7 @@ def _get_file_names(file_pattern, randomize_input): def _get_examples(file_name_queue, reader, num_threads, read_batch_size, - parse_fn): + filter_fn, parse_fn): with ops.name_scope('read'): example_list = [] for _ in range(num_threads): @@ -274,6 +274,10 @@ def _get_examples(file_name_queue, reader, num_threads, read_batch_size, read_batch_size) else: keys, examples_proto = reader().read(file_name_queue) + if filter_fn: + mask = filter_fn(keys, examples_proto) + keys = array_ops.boolean_mask(keys, mask) + examples_proto = array_ops.boolean_mask(examples_proto, mask) if parse_fn: parsed_examples = parse_fn(examples_proto) # Map keys into example map because batch_join doesn't support @@ -296,6 +300,7 @@ def _read_keyed_batch_examples_helper(file_pattern, queue_capacity=10000, num_threads=1, read_batch_size=1, + filter_fn=None, parse_fn=None, setup_shared_queue=False, name=None): @@ -316,6 +321,9 @@ def _read_keyed_batch_examples_helper(file_pattern, num_threads: The number of threads enqueuing examples. read_batch_size: An int or scalar `Tensor` specifying the number of records to read at once + filter_fn: Filtering function, takes both keys as well `Example` Tensors + and returns a boolean mask of the same shape as the input Tensors to + be applied for filtering. If `None`, no filtering is done. parse_fn: Parsing function, takes `Example` Tensor returns parsed representation. If `None`, no parsing is done. setup_shared_queue: Whether to set up a shared queue for file names. @@ -366,7 +374,7 @@ def _read_keyed_batch_examples_helper(file_pattern, name=file_name_queue_scope) example_list = _get_examples(file_name_queue, reader, num_threads, - read_batch_size, parse_fn) + read_batch_size, filter_fn, parse_fn) enqueue_many = read_batch_size > 1 diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py index 98d5c5b5009..90d58dec14e 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py @@ -39,6 +39,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import io_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile @@ -681,6 +682,67 @@ class GraphIOTest(test.TestCase): coord.request_stop() coord.join(threads) + def test_keyed_features_filter(self): + gfile.Glob = self._orig_glob + lines = [ + '{"features": {"feature": {"age": {"int64_list": {"value": [2]}}}}}', + '{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}', + '{"features": {"feature": {"age": {"int64_list": {"value": [1]}}}}}', + '{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}', + '{"features": {"feature": {"age": {"int64_list": {"value": [3]}}}}}', + '{"features": {"feature": {"age": {"int64_list": {"value": [5]}}}}}' + ] + filename = self._create_temp_file("\n".join(lines)) + + batch_size = 2 + queue_capacity = 4 + name = "my_batch" + features = {"age": parsing_ops.FixedLenFeature([], dtypes_lib.int64)} + + def filter_fn(keys, examples_json): + del keys + serialized = parsing_ops.decode_json_example(examples_json) + examples = parsing_ops.parse_example(serialized, features) + return math_ops.less(examples["age"], 2) + + with ops.Graph().as_default() as g, self.test_session(graph=g) as session: + keys, inputs = graph_io._read_keyed_batch_examples_helper( + filename, + batch_size, + reader=io_ops.TextLineReader, + randomize_input=False, + num_epochs=1, + read_batch_size=batch_size, + queue_capacity=queue_capacity, + filter_fn=filter_fn, + name=name) + self.assertAllEqual((None,), keys.get_shape().as_list()) + self.assertAllEqual((None,), inputs.get_shape().as_list()) + session.run(variables.local_variables_initializer()) + + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(session, coord=coord) + # First batch of two filtered examples. + out_keys, out_vals = session.run((keys, inputs)) + self.assertAllEqual( + [filename.encode("utf-8") + b":2", filename.encode("utf-8") + b":3"], + out_keys) + self.assertAllEqual([lines[1].encode("utf-8"), lines[2].encode("utf-8")], + out_vals) + + # Second batch will only have one filtered example as that's the only + # remaining example that satisfies the filtering criterion. + out_keys, out_vals = session.run((keys, inputs)) + self.assertAllEqual([filename.encode("utf-8") + b":4"], out_keys) + self.assertAllEqual([lines[3].encode("utf-8")], out_vals) + + # Exhausted input. + with self.assertRaises(errors.OutOfRangeError): + session.run((keys, inputs)) + + coord.request_stop() + coord.join(threads) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/learn/python/learn/learn_runner.py b/tensorflow/contrib/learn/python/learn/learn_runner.py index 5b0000afc7b..e3cf27ebab4 100644 --- a/tensorflow/contrib/learn/python/learn/learn_runner.py +++ b/tensorflow/contrib/learn/python/learn/learn_runner.py @@ -88,20 +88,20 @@ def run(experiment_fn, output_dir, schedule=None): # Execute the schedule if not hasattr(experiment, schedule): logging.error('Schedule references non-existent task %s', schedule) - valid_tasks = [x for x in experiment.__dict__ - if callable(getattr(experiment, x))] + valid_tasks = [x for x in dir(experiment) + if not x.startswith('_') + and callable(getattr(experiment, x))] logging.error('Allowed values for this experiment are: %s', valid_tasks) - raise ValueError('Schedule references non-existent task %s', schedule) + raise ValueError('Schedule references non-existent task %s' % schedule) task = getattr(experiment, schedule) if not callable(task): logging.error('Schedule references non-callable member %s', schedule) - valid_tasks = [ - x for x in experiment.__dict__ - if callable(getattr(experiment, x)) and not x.startswith('_') - ] + valid_tasks = [x for x in dir(experiment) + if not x.startswith('_') + and callable(getattr(experiment, x))] logging.error('Allowed values for this experiment are: %s', valid_tasks) - raise TypeError('Schedule references non-callable member %s', schedule) + raise TypeError('Schedule references non-callable member %s' % schedule) return task() diff --git a/tensorflow/contrib/learn/python/learn/learn_runner_test.py b/tensorflow/contrib/learn/python/learn/learn_runner_test.py index 5404d26fead..31e0dd561d2 100644 --- a/tensorflow/contrib/learn/python/learn/learn_runner_test.py +++ b/tensorflow/contrib/learn/python/learn/learn_runner_test.py @@ -27,9 +27,11 @@ if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"): import ctypes sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL) +from tensorflow.contrib.learn.python.learn import evaluable # pylint: disable=g-import-not-at-top from tensorflow.contrib.learn.python.learn import experiment from tensorflow.contrib.learn.python.learn import learn_runner from tensorflow.contrib.learn.python.learn import run_config +from tensorflow.contrib.learn.python.learn import trainable from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging @@ -43,13 +45,22 @@ class TestExperiment(experiment.Experiment): self.default = default self.config = config - @property - def estimator(self): - - class Estimator(object): + class Estimator(evaluable.Evaluable, trainable.Trainable): config = self.config - return Estimator() + def model_dir(self): + raise NotImplementedError + + def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None, + monitors=None, max_steps=None): + raise NotImplementedError + + def evaluate(self, x=None, y=None, input_fn=None, feed_fn=None, + batch_size=None, steps=None, metrics=None, name=None, + checkpoint_path=None, hooks=None): + raise NotImplementedError + + super(TestExperiment, self).__init__(Estimator(), None, None) def local_run(self): return "local_run" diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py index 0d39d26c3cd..e0452c56a2f 100644 --- a/tensorflow/contrib/learn/python/learn/utils/export.py +++ b/tensorflow/contrib/learn/python/learn/utils/export.py @@ -66,13 +66,13 @@ def _export_graph(graph, saver, checkpoint_path, export_dir, with graph.as_default(): with tf_session.Session('') as session: variables.local_variables_initializer() - data_flow_ops.initialize_all_tables() + data_flow_ops.tables_initializer() saver.restore(session, checkpoint_path) export = exporter.Exporter(saver) export.init(init_op=control_flow_ops.group( variables.local_variables_initializer(), - data_flow_ops.initialize_all_tables()), + data_flow_ops.tables_initializer()), default_graph_signature=default_graph_signature, named_graph_signatures=named_graph_signatures, assets_collection=ops.get_collection( diff --git a/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py b/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py index 2cb7173d5a0..18bfdc61c6c 100644 --- a/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py @@ -29,11 +29,12 @@ from tensorflow.python.ops import parsing_ops # A return type allowing input_fns to return multiple values in a well- # defined way (analogous to ModelFnOps). # The expected return values are: -# features: a dict of string to Tensor, giving the features to be passed to -# the model. -# labels: a dict of string to Tensor, giving labels (aka targets) for training. -# default_inputs: a dict of string to Tensor, giving the input Tensors (if -# any) that this input_fn expects to be fed. +# features: a dict of string to `Tensor` or `SparseTensor`, giving the features +# to be passed to the model. +# labels: a dict of string to `Tensor` or `SparseTensor`, giving labels (aka +# targets) for training. +# default_inputs: a dict of string to `Tensor` or `SparseTensor`, giving the +# input placeholders (if any) that this input_fn expects to be fed. InputFnOps = collections.namedtuple('InputFnOps', ['features', 'labels', diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py index 5d7ba38446a..9e452d09056 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import copy import os import re import time @@ -128,10 +127,15 @@ def get_input_alternatives(input_ops): if not features: raise ValueError('Features must be defined.') + # TODO(b/34253951): reinstate the "features" input_signature. + # The "features" input_signature, as written, does not work with + # SparseTensors. It is simply commented out as a stopgap, pending discussion + # on the bug as to the correct solution. + # Add the "features" input_signature in any case. # Note defensive copy because model_fns alter the features dict. - input_alternatives[FEATURES_INPUT_ALTERNATIVE_KEY] = ( - copy.copy(features)) + # input_alternatives[FEATURES_INPUT_ALTERNATIVE_KEY] = ( + # copy.copy(features)) return input_alternatives, features @@ -163,6 +167,8 @@ def get_output_alternatives( # interpret the model as single-headed of unknown type. default_problem_type = constants.ProblemType.UNSPECIFIED default_outputs = model_fn_ops.predictions + if not isinstance(default_outputs, dict): + default_outputs = {prediction_key.PredictionKey.GENERIC: default_outputs} actual_default_output_alternative_key = DEFAULT_OUTPUT_ALTERNATIVE_KEY output_alternatives = {actual_default_output_alternative_key: (default_problem_type, default_outputs)} @@ -182,9 +188,10 @@ def build_all_signature_defs(input_alternatives, output_alternatives, in output_alternatives.items()} # Add the default SignatureDef - default_inputs = input_alternatives[DEFAULT_INPUT_ALTERNATIVE_KEY] + default_inputs = input_alternatives.get(DEFAULT_INPUT_ALTERNATIVE_KEY) if not default_inputs: - default_inputs = input_alternatives[FEATURES_INPUT_ALTERNATIVE_KEY] + raise ValueError('A default input_alternative must be provided.') + # default_inputs = input_alternatives[FEATURES_INPUT_ALTERNATIVE_KEY] # default outputs are guaranteed to exist above (default_problem_type, default_outputs) = ( output_alternatives[actual_default_output_alternative_key]) @@ -252,7 +259,7 @@ def garbage_collect_exports(export_dir_base, exports_to_keep): def make_export_strategy(export_input_fn, default_output_alternative_key='default', assets_extra=None, - export_as_text=False, + as_text=False, exports_to_keep=None): """Create an ExportStrategy for use with Experiment.""" @@ -263,7 +270,7 @@ def make_export_strategy(export_input_fn, export_input_fn, default_output_alternative_key=default_output_alternative_key, assets_extra=assets_extra, - export_as_text=export_as_text, + as_text=as_text, exports_to_keep=exports_to_keep) garbage_collect_exports(export_dir_base, exports_to_keep) diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py index 60ff1bc318a..955e14ae44f 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py @@ -22,11 +22,14 @@ import sys import tempfile import time -# TODO: #6568 Remove this hack that makes dlopen() not crash. +# pylint: disable=g-import-not-at-top + +# TODO(jart): #6568 Remove this hack that makes dlopen() not crash. if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"): import ctypes sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL) +from tensorflow.contrib.learn.python.learn import export_strategy as export_strategy_lib from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.contrib.learn.python.learn.utils import input_fn_utils @@ -87,9 +90,9 @@ class SavedModelExportUtilsTest(test.TestCase): self.assertEqual(input_alternatives[ saved_model_export_utils.DEFAULT_INPUT_ALTERNATIVE_KEY], "bogus default input dict") - self.assertEqual(input_alternatives[ - saved_model_export_utils.FEATURES_INPUT_ALTERNATIVE_KEY], - "bogus features dict") + # self.assertEqual(input_alternatives[ + # saved_model_export_utils.FEATURES_INPUT_ALTERNATIVE_KEY], + # "bogus features dict") def test_get_output_alternatives_explicit(self): provided_output_alternatives = { @@ -122,6 +125,21 @@ class SavedModelExportUtilsTest(test.TestCase): }) }, output_alternatives) + def test_get_output_alternatives_implicit_single(self): + prediction_tensor = constant_op.constant(["bogus"]) + model_fn_ops = model_fn.ModelFnOps( + model_fn.ModeKeys.INFER, + predictions=prediction_tensor, + output_alternatives=None) + + output_alternatives, _ = saved_model_export_utils.get_output_alternatives( + model_fn_ops) + self.assertEqual({ + "default_output_alternative": (constants.ProblemType.UNSPECIFIED, { + "output": prediction_tensor + }) + }, output_alternatives) + def test_build_all_signature_defs(self): input_features = constant_op.constant(["10"]) input_example = constant_op.constant(["11"]) @@ -168,20 +186,56 @@ class SavedModelExportUtilsTest(test.TestCase): signature_def_utils.predict_signature_def({ "input": input_example }, {"output": output_3}), - "features_input_alternative:head-1": - signature_def_utils.regression_signature_def(input_features, - output_1), - "features_input_alternative:head-2": - signature_def_utils.classification_signature_def(input_features, - output_2, None), - "features_input_alternative:head-3": - signature_def_utils.predict_signature_def({ - "input": input_features - }, {"output": output_3}), + # "features_input_alternative:head-1": + # signature_def_utils.regression_signature_def(input_features, + # output_1), + # "features_input_alternative:head-2": + # signature_def_utils.classification_signature_def(input_features, + # output_2, None), + # "features_input_alternative:head-3": + # signature_def_utils.predict_signature_def({ + # "input": input_features + # }, {"output": output_3}), } self.assertDictEqual(expected_signature_defs, signature_defs) + def test_build_all_signature_defs_legacy_input_fn_not_supported(self): + """Tests that legacy input_fn returning (features, labels) raises error. + + serving_input_fn must return InputFnOps including a default input + alternative. + """ + input_features = constant_op.constant(["10"]) + input_ops = ({"features": input_features}, None) + input_alternatives, _ = ( + saved_model_export_utils.get_input_alternatives(input_ops)) + output_1 = constant_op.constant(["1"]) + output_2 = constant_op.constant(["2"]) + output_3 = constant_op.constant(["3"]) + provided_output_alternatives = { + "head-1": (constants.ProblemType.LINEAR_REGRESSION, { + "some_output_1": output_1 + }), + "head-2": (constants.ProblemType.CLASSIFICATION, { + "some_output_2": output_2 + }), + "head-3": (constants.ProblemType.UNSPECIFIED, { + "some_output_3": output_3 + }), + } + model_fn_ops = model_fn.ModelFnOps( + model_fn.ModeKeys.INFER, + predictions={"some_output": constant_op.constant(["4"])}, + output_alternatives=provided_output_alternatives) + output_alternatives, _ = (saved_model_export_utils.get_output_alternatives( + model_fn_ops, "head-1")) + + with self.assertRaisesRegexp( + ValueError, "A default input_alternative must be provided"): + saved_model_export_utils.build_all_signature_defs( + input_alternatives, output_alternatives, "head-1") + def test_get_timestamped_export_dir(self): export_dir_base = tempfile.mkdtemp() + "export/" export_dir_1 = saved_model_export_utils.get_timestamped_export_dir( @@ -227,6 +281,19 @@ class SavedModelExportUtilsTest(test.TestCase): self.assertTrue(gfile.Exists(export_dir_3)) self.assertTrue(gfile.Exists(export_dir_4)) + def test_make_export_strategy(self): + """Only tests that an ExportStrategy instance is created.""" + def _export_input_fn(): + return array_ops.constant([1]), None + export_strategy = saved_model_export_utils.make_export_strategy( + export_input_fn=_export_input_fn, + default_output_alternative_key="default", + assets_extra={"from/path": "to/path"}, + as_text=False, + exports_to_keep=5) + self.assertTrue( + isinstance(export_strategy, export_strategy_lib.ExportStrategy)) + def _create_test_export_dir(export_dir_base): export_dir = saved_model_export_utils.get_timestamped_export_dir( diff --git a/tensorflow/contrib/linalg/__init__.py b/tensorflow/contrib/linalg/__init__.py index 661cab13dae..f4e1c6d7197 100644 --- a/tensorflow/contrib/linalg/__init__.py +++ b/tensorflow/contrib/linalg/__init__.py @@ -31,6 +31,7 @@ Subclasses of `LinearOperator` provide a access to common methods on a @@LinearOperatorDiag @@LinearOperatorIdentity +@@LinearOperatorScaledIdentity @@LinearOperatorMatrix @@LinearOperatorTriL diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_identity_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_identity_test.py index 20e076749c0..36a255f3d50 100644 --- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_identity_test.py +++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_identity_test.py @@ -33,7 +33,7 @@ random_seed.set_random_seed(23) rng = np.random.RandomState(2016) -class LinearOperatorIdentitytest( +class LinearOperatorIdentityTest( linear_operator_test_util.SquareLinearOperatorDerivedClassTest): """Most tests done in the base class LinearOperatorDerivedClassTest.""" @@ -251,5 +251,184 @@ class LinearOperatorIdentitytest( ) +class LinearOperatorScaledIdentityTest( + linear_operator_test_util.SquareLinearOperatorDerivedClassTest): + """Most tests done in the base class LinearOperatorDerivedClassTest.""" + + @property + def _dtypes_to_test(self): + # TODO(langmore) Test tf.float16 once tf.matrix_solve works in + # 16bit. + return [dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128] + + def _operator_and_mat_and_feed_dict(self, shape, dtype, use_placeholder): + shape = list(shape) + assert shape[-1] == shape[-2] + + batch_shape = shape[:-2] + num_rows = shape[-1] + + # Uniform values that are at least length 1 from the origin. Allows the + # operator to be well conditioned. + # Shape batch_shape + multiplier = linear_operator_test_util.random_sign_uniform( + shape=batch_shape, minval=1., maxval=2., dtype=dtype) + + operator = linalg_lib.LinearOperatorScaledIdentity(num_rows, multiplier) + + # Nothing to feed since LinearOperatorScaledIdentity takes no Tensor args. + if use_placeholder: + multiplier_ph = array_ops.placeholder(dtype=dtype) + multiplier = multiplier.eval() + operator = linalg_lib.LinearOperatorScaledIdentity( + num_rows, multiplier_ph) + feed_dict = {multiplier_ph: multiplier} + else: + feed_dict = None + + multiplier_matrix = array_ops.expand_dims( + array_ops.expand_dims(multiplier, -1), -1) + mat = multiplier_matrix * linalg_ops.eye( + num_rows, batch_shape=batch_shape, dtype=dtype) + + return operator, mat, feed_dict + + def test_assert_positive_definite_does_not_raise_when_positive(self): + with self.test_session(): + operator = linalg_lib.LinearOperatorScaledIdentity( + num_rows=2, multiplier=1.) + operator.assert_positive_definite().run() # Should not fail + + def test_assert_positive_definite_raises_when_negative(self): + with self.test_session(): + operator = linalg_lib.LinearOperatorScaledIdentity( + num_rows=2, multiplier=-1.) + with self.assertRaisesOpError("not positive definite"): + operator.assert_positive_definite().run() + + def test_assert_non_singular_does_not_raise_when_non_singular(self): + with self.test_session(): + operator = linalg_lib.LinearOperatorScaledIdentity( + num_rows=2, multiplier=[1., 2., 3.]) + operator.assert_non_singular().run() # Should not fail + + def test_assert_non_singular_raises_when_singular(self): + with self.test_session(): + operator = linalg_lib.LinearOperatorScaledIdentity( + num_rows=2, multiplier=[1., 2., 0.]) + with self.assertRaisesOpError("was singular"): + operator.assert_non_singular().run() + + def test_assert_self_adjoint_does_not_raise_when_self_adjoint(self): + with self.test_session(): + operator = linalg_lib.LinearOperatorScaledIdentity( + num_rows=2, multiplier=[1. + 0J]) + operator.assert_self_adjoint().run() # Should not fail + + def test_assert_self_adjoint_raises_when_not_self_adjoint(self): + with self.test_session(): + operator = linalg_lib.LinearOperatorScaledIdentity( + num_rows=2, multiplier=[1. + 1J]) + with self.assertRaisesOpError("not self-adjoint"): + operator.assert_self_adjoint().run() + + def test_float16_apply(self): + # float16 cannot be tested by base test class because tf.matrix_solve does + # not work with float16. + with self.test_session(): + multiplier = rng.rand(3).astype(np.float16) + operator = linalg_lib.LinearOperatorScaledIdentity( + num_rows=2, multiplier=multiplier) + x = rng.randn(2, 3).astype(np.float16) + y = operator.apply(x) + self.assertAllClose(multiplier[..., None, None] * x, y.eval()) + + def test_non_scalar_num_rows_raises_static(self): + # Many "test_...num_rows" tests are performed in LinearOperatorIdentity. + with self.assertRaisesRegexp(ValueError, "must be a 0-D Tensor"): + linalg_lib.LinearOperatorScaledIdentity( + num_rows=[2], multiplier=123.) + + def test_wrong_matrix_dimensions_raises_static(self): + operator = linalg_lib.LinearOperatorScaledIdentity( + num_rows=2, multiplier=2.2) + x = rng.randn(3, 3).astype(np.float32) + with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"): + operator.apply(x) + + def test_wrong_matrix_dimensions_raises_dynamic(self): + num_rows = array_ops.placeholder(dtypes.int32) + x = array_ops.placeholder(dtypes.float32) + + with self.test_session(): + operator = linalg_lib.LinearOperatorScaledIdentity( + num_rows, multiplier=[1., 2], assert_proper_shapes=True) + y = operator.apply(x) + with self.assertRaisesOpError("Incompatible.*dimensions"): + y.eval(feed_dict={num_rows: 2, x: rng.rand(3, 3)}) + + def test_broadcast_apply_and_solve(self): + # These cannot be done in the automated (base test class) tests since they + # test shapes that tf.batch_matmul cannot handle. + # In particular, tf.batch_matmul does not broadcast. + with self.test_session() as sess: + # Given this x and LinearOperatorScaledIdentity shape of (2, 1, 3, 3), the + # broadcast shape of operator and 'x' is (2, 2, 3, 4) + x = random_ops.random_normal(shape=(1, 2, 3, 4)) + + # operator is 2.2 * identity (with a batch shape). + operator = linalg_lib.LinearOperatorScaledIdentity( + num_rows=3, multiplier=2.2 * array_ops.ones((2, 1))) + + # Batch matrix of zeros with the broadcast shape of x and operator. + zeros = array_ops.zeros(shape=(2, 2, 3, 4), dtype=x.dtype) + + # Test apply + expected = x * 2.2 + zeros + operator_apply = operator.apply(x) + self.assertAllEqual(operator_apply.get_shape(), expected.get_shape()) + self.assertAllClose(*sess.run([operator_apply, expected])) + + # Test solve + expected = x / 2.2 + zeros + operator_solve = operator.solve(x) + self.assertAllEqual(operator_solve.get_shape(), expected.get_shape()) + self.assertAllClose(*sess.run([operator_solve, expected])) + + def test_broadcast_apply_and_solve_scalar_scale_multiplier(self): + # These cannot be done in the automated (base test class) tests since they + # test shapes that tf.batch_matmul cannot handle. + # In particular, tf.batch_matmul does not broadcast. + with self.test_session() as sess: + # Given this x and LinearOperatorScaledIdentity shape of (3, 3), the + # broadcast shape of operator and 'x' is (1, 2, 3, 4), which is the same + # shape as x. + x = random_ops.random_normal(shape=(1, 2, 3, 4)) + + # operator is 2.2 * identity (with a batch shape). + operator = linalg_lib.LinearOperatorScaledIdentity( + num_rows=3, multiplier=2.2) + + # Test apply + expected = x * 2.2 + operator_apply = operator.apply(x) + self.assertAllEqual(operator_apply.get_shape(), expected.get_shape()) + self.assertAllClose(*sess.run([operator_apply, expected])) + + # Test solve + expected = x / 2.2 + operator_solve = operator.solve(x) + self.assertAllEqual(operator_solve.get_shape(), expected.get_shape()) + self.assertAllClose(*sess.run([operator_solve, expected])) + + def test_is_x_flags(self): + operator = linalg_lib.LinearOperatorScaledIdentity( + num_rows=2, multiplier=1., + is_positive_definite=False, is_non_singular=True) + self.assertFalse(operator.is_positive_definite) + self.assertTrue(operator.is_non_singular) + self.assertTrue(operator.is_self_adjoint is None) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py index 1b1c7fb3978..4eac01092f1 100644 --- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py +++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib import linalg as linalg_lib from tensorflow.contrib.linalg.python.ops import linear_operator_util from tensorflow.python.framework import ops @@ -26,6 +28,7 @@ from tensorflow.python.platform import test linalg = linalg_lib random_seed.set_random_seed(23) +rng = np.random.RandomState(0) class AssertZeroImagPartTest(test.TestCase): @@ -88,5 +91,33 @@ class AssertNoEntriesWithModulusZeroTest(test.TestCase): z, message="ABC123").run() +class DomainDimensionStubOperator(object): + + def __init__(self, domain_dimension): + self._domain_dimension = ops.convert_to_tensor(domain_dimension) + + def domain_dimension_dynamic(self): + return self._domain_dimension + + +class AssertCompatibleMatrixDimensionsTest(test.TestCase): + + def test_compatible_dimensions_do_not_raise(self): + with self.test_session(): + x = ops.convert_to_tensor(rng.rand(2, 3, 4)) + operator = DomainDimensionStubOperator(3) + # Should not raise + linear_operator_util.assert_compatible_matrix_dimensions( + operator, x).run() + + def test_incompatible_dimensions_raise(self): + with self.test_session(): + x = ops.convert_to_tensor(rng.rand(2, 4, 4)) + operator = DomainDimensionStubOperator(3) + with self.assertRaisesOpError("Incompatible matrix dimensions"): + linear_operator_util.assert_compatible_matrix_dimensions( + operator, x).run() + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_identity.py b/tensorflow/contrib/linalg/python/ops/linear_operator_identity.py index 8ea3894bdab..3304698ec67 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_identity.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_identity.py @@ -29,12 +29,52 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops -__all__ = ["LinearOperatorIdentity",] +__all__ = [ + "LinearOperatorIdentity", + "LinearOperatorScaledIdentity", +] -class LinearOperatorIdentity(linear_operator.LinearOperator): +class BaseLinearOperatorIdentity(linear_operator.LinearOperator): + + def _check_num_rows_possibly_add_asserts(self): + """Static check of init arg `num_rows`, possibly add asserts.""" + # Possibly add asserts. + if self._assert_proper_shapes: + self._num_rows = control_flow_ops.with_dependencies( + [ + check_ops.assert_rank( + self._num_rows, + 0, + message="Argument num_rows must be a 0-D Tensor."), + check_ops.assert_non_negative( + self._num_rows, + message="Argument num_rows must be non-negative."), + ], + self._num_rows) + + # Static checks. + if not self._num_rows.dtype.is_integer: + raise TypeError("Argument num_rows must be integer type. Found:" + " %s" % self._num_rows) + + num_rows_static = self._num_rows_static + + if num_rows_static is None: + return # Cannot do any other static checks. + + if num_rows_static.ndim != 0: + raise ValueError("Argument num_rows must be a 0-D Tensor. Found:" + " %s" % num_rows_static) + + if num_rows_static < 0: + raise ValueError("Argument num_rows must be non-negative. Found:" + " %s" % num_rows_static) + + +class LinearOperatorIdentity(BaseLinearOperatorIdentity): """`LinearOperator` acting like a [batch] square identity matrix. This operator acts like a [batch] identity matrix `A` with shape @@ -273,12 +313,10 @@ class LinearOperatorIdentity(linear_operator.LinearOperator): def _apply(self, x, adjoint=False): # Note that adjoint has no effect since this matrix is self-adjoint. - if x.dtype != self.dtype: - raise TypeError( - "Expected argument 'x' to have dtype %s. Found: %s" - % (self.dtype, x)) if self._assert_proper_shapes: - x = self._assert_compatible_matrix_dimensions(x) + aps = linear_operator_util.assert_compatible_matrix_dimensions( + self, x) + x = control_flow_ops.with_dependencies([aps], x) return self._possibly_broadcast_batch_shape(x) def _determinant(self): @@ -290,12 +328,6 @@ class LinearOperatorIdentity(linear_operator.LinearOperator): def _solve(self, rhs, adjoint=False): return self._apply(rhs) - def _to_dense(self): - return linalg_ops.eye( - num_rows=self.domain_dimension_dynamic(), - batch_shape=self.batch_shape_dynamic(), - dtype=self.dtype) - def add_to_tensor(self, mat, name="add_to_tensor"): """Add matrix represented by this operator to `mat`. Equiv to `I + mat`. @@ -381,14 +413,238 @@ class LinearOperatorIdentity(linear_operator.LinearOperator): raise ValueError("Argument batch_shape must be non-negative. Found:" "%s" % self._batch_shape_static) - def _assert_compatible_matrix_dimensions(self, x): - """Check that an argument to solve/apply has proper domain dimension.""" - # Static checks are done in the base class. Only dynamic asserts here. - assert_same_dd = check_ops.assert_equal( - array_ops.shape(x)[-2], - self.domain_dimension_dynamic(), - message=( - "Incompatible matrix dimensions. " - "shape[-2] of argument to be the same as this operator")) - return control_flow_ops.with_dependencies([assert_same_dd], x) +class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity): + """`LinearOperator` acting like a scaled [batch] identity matrix `A = c I`. + + This operator acts like a scaled [batch] identity matrix `A` with shape + `[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a + batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is + a scaled version of the `N x N` identity matrix. + + `LinearOperatorIdentity` is initialized with `num_rows`, and a `multiplier` + (a `Tensor`) of shape `[B1,...,Bb]`. `N` is set to `num_rows`, and the + `multiplier` determines the scale for each batch member. + + ```python + # Create a 2 x 2 scaled identity matrix. + operator = LinearOperatorIdentity(num_rows=2, multiplier=3.) + + operator.to_dense() + ==> [[3., 0.] + [0., 3.]] + + operator.shape + ==> [2, 2] + + operator.log_determinant() + ==> 2 * Log[3] + + x = ... Shape [2, 4] Tensor + operator.apply(x) + ==> 3 * x + + y = tf.random_normal(shape=[3, 2, 4]) + # Note that y.shape is compatible with operator.shape because operator.shape + # is broadcast to [3, 2, 2]. + x = operator.solve(y) + ==> 3 * x + + # Create a 2-batch of 2x2 identity matrices + operator = LinearOperatorIdentity(num_rows=2, multiplier=5.) + operator.to_dense() + ==> [[[5., 0.] + [0., 5.]], + [[5., 0.] + [0., 5.]]] + + x = ... Shape [2, 2, 3] + operator.apply(x) + ==> 5 * x + + # Here the operator and x have different batch_shape, and are broadcast. + x = ... Shape [1, 2, 3] + operator.apply(x) + ==> 5 * x + ``` + + ### Shape compatibility + + This operator acts on [batch] matrix with compatible shape. + `x` is a batch matrix with compatible shape for `apply` and `solve` if + + ``` + operator.shape = [B1,...,Bb] + [N, N], with b >= 0 + x.shape = [C1,...,Cc] + [N, R], + and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd] + ``` + + ### Performance + + * `operator.apply(x)` is `O(D1*...*Dd*N*R)` + * `operator.solve(x)` is `O(D1*...*Dd*N*R)` + * `operator.determinant()` is `O(D1*...*Dd)` + + #### Matrix property hints + + This `LinearOperator` is initialized with boolean flags of the form `is_X`, + for `X = non_singular, self_adjoint, positive_definite`. + These have the following meaning + * If `is_X == True`, callers should expect the operator to have the + property `X`. This is a promise that should be fulfilled, but is *not* a + runtime assert. For example, finite floating point precision may result + in these promises being violated. + * If `is_X == False`, callers should expect the operator to not have `X`. + * If `is_X == None` (the default), callers should have no expectation either + way. + """ + + def __init__(self, + num_rows, + multiplier, + is_non_singular=None, + is_self_adjoint=None, + is_positive_definite=None, + assert_proper_shapes=False, + name="LinearOperatorScaledIdentity"): + """Initialize a `LinearOperatorScaledIdentity`. + + The `LinearOperatorScaledIdentity` is initialized with `num_rows`, which + determines the size of each identity matrix, and a `multiplier`, + which defines `dtype`, batch shape, and scale of each matrix. + + This operator is able to broadcast the leading (batch) dimensions. + + Args: + num_rows: Scalar non-negative integer `Tensor`. Number of rows in the + corresponding identity matrix. + multiplier: `Tensor` of shape `[B1,...,Bb]`, or `[]` (a scalar). + is_non_singular: Expect that this operator is non-singular. + is_self_adjoint: Expect that this operator is equal to its hermitian + transpose. + is_positive_definite: Expect that this operator is positive definite. + assert_proper_shapes: Python `bool`. If `False`, only perform static + checks that initialization and method arguments have proper shape. + If `True`, and static checks are inconclusive, add asserts to the graph. + name: A name for this `LinearOperator` + + Raises: + ValueError: If `num_rows` is determined statically to be non-scalar, or + negative. + """ + self._assert_proper_shapes = assert_proper_shapes + + with ops.name_scope(name, values=[multiplier, num_rows]): + self._multiplier = ops.convert_to_tensor(multiplier, name="multiplier") + + super(LinearOperatorScaledIdentity, self).__init__( + dtype=self._multiplier.dtype, + is_non_singular=is_non_singular, + is_self_adjoint=is_self_adjoint, + is_positive_definite=is_positive_definite, + name=name) + + # Shape [B1,...Bb, 1, 1] + self._multiplier_matrix = array_ops.expand_dims( + array_ops.expand_dims(self.multiplier, -1), -1) + self._multiplier_matrix_conj = math_ops.conj( + self._multiplier_matrix) + self._abs_multiplier = math_ops.abs(self.multiplier) + + self._num_rows = linear_operator_util.shape_tensor( + num_rows, name="num_rows") + self._num_rows_static = tensor_util.constant_value(self._num_rows) + self._check_num_rows_possibly_add_asserts() + self._num_rows_cast_to_dtype = math_ops.cast(self._num_rows, self.dtype) + self._num_rows_cast_to_real_dtype = math_ops.cast( + self._num_rows, self.dtype.real_dtype) + + def _shape(self): + matrix_shape = tensor_shape.TensorShape( + (self._num_rows_static, self._num_rows_static)) + + batch_shape = self.multiplier.get_shape() + return batch_shape.concatenate(matrix_shape) + + def _shape_dynamic(self): + matrix_shape = array_ops.stack( + (self._num_rows, self._num_rows), axis=0) + + batch_shape = array_ops.shape(self.multiplier) + return array_ops.concat((batch_shape, matrix_shape), 0) + + def _assert_non_singular(self): + return check_ops.assert_positive( + math_ops.abs(self.multiplier), + message="LinearOperator was singular") + + def _assert_positive_definite(self): + return check_ops.assert_positive( + math_ops.real(self.multiplier), + message="LinearOperator was not positive definite.") + + def _assert_self_adjoint(self): + imag_multiplier = math_ops.imag(self.multiplier) + return check_ops.assert_equal( + array_ops.zeros_like(imag_multiplier), + imag_multiplier, + message="LinearOperator was not self-adjoint") + + def _apply(self, x, adjoint=False): + if adjoint: + matrix = self._multiplier_matrix_conj + else: + matrix = self._multiplier_matrix + if self._assert_proper_shapes: + aps = linear_operator_util.assert_compatible_matrix_dimensions( + self, x) + x = control_flow_ops.with_dependencies([aps], x) + return x * matrix + + def _determinant(self): + return self.multiplier ** self._num_rows_cast_to_dtype + + def _log_abs_determinant(self): + return self._num_rows_cast_to_real_dtype * math_ops.log( + self._abs_multiplier) + + def _solve(self, rhs, adjoint=False): + if adjoint: + matrix = self._multiplier_matrix_conj + else: + matrix = self._multiplier_matrix + if self._assert_proper_shapes: + aps = linear_operator_util.assert_compatible_matrix_dimensions( + self, rhs) + rhs = control_flow_ops.with_dependencies([aps], rhs) + return rhs / matrix + + def add_to_tensor(self, mat, name="add_to_tensor"): + """Add matrix represented by this operator to `mat`. Equiv to `I + mat`. + + Args: + mat: `Tensor` with same `dtype` and shape broadcastable to `self`. + name: A name to give this `Op`. + + Returns: + A `Tensor` with broadcast shape and same `dtype` as `self`. + """ + with self._name_scope(name, values=[mat]): + # Shape [B1,...,Bb, 1] + multiplier_vector = array_ops.expand_dims(self.multiplier, -1) + + # Shape [C1,...,Cc, M, M] + mat = ops.convert_to_tensor(mat, name="mat") + + # Shape [C1,...,Cc, M] + mat_diag = array_ops.matrix_diag_part(mat) + + # multiplier_vector broadcasts here. + new_diag = multiplier_vector + mat_diag + + return array_ops.matrix_set_diag(mat, new_diag) + + @property + def multiplier(self): + """The [batch] scalar `Tensor`, `c` in `cI`.""" + return self._multiplier diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_util.py b/tensorflow/contrib/linalg/python/ops/linear_operator_util.py index 09503ec12fb..44092f0c062 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_util.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_util.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops @@ -67,6 +68,32 @@ def assert_zero_imag_part(x, message=None, name="assert_zero_imag_part"): return check_ops.assert_equal(zero, math_ops.imag(x), message=message) +def assert_compatible_matrix_dimensions(operator, x): + """Assert that an argument to solve/apply has proper domain dimension. + + If `operator.shape[-2:] = [M, N]`, and `x.shape[-2:] = [Q, R]`, then + `operator.apply(x)` is defined only if `N = Q`. This `Op` returns an + `Assert` that "fires" if this is not the case. Static checks are already + done by the base class `LinearOperator`. + + Args: + operator: `LinearOperator`. + x: `Tensor`. + + Returns: + `Assert` `Op`. + """ + # Static checks are done in the base class. Only dynamic asserts here. + assert_same_dd = check_ops.assert_equal( + array_ops.shape(x)[-2], + operator.domain_dimension_dynamic(), + message=( + "Incompatible matrix dimensions. " + "shape[-2] of argument to be the same as this operator")) + + return assert_same_dd + + def shape_tensor(shape, name=None): """Convert Tensor using default type, unless empty list or tuple.""" # Works just like random_ops._ShapeTensor. diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index d1e53f4a663..5eebe06cf78 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -795,7 +795,7 @@ def string_to_index_table_from_file(vocabulary_file=None, The bucket ID range is `[vocabulary size, vocabulary size + num_oov_buckets]`. The underlying table must be initialized by calling - `tf.initialize_all_tables.run()` or `table.init.run()` once. + `tf.tables_initializer.run()` or `table.init.run()` once. Sample Usages: @@ -813,7 +813,7 @@ def string_to_index_table_from_file(vocabulary_file=None, vocabulary_file="test.txt", num_oov_buckets=1) ids = table.lookup(features) ... - tf.initialize_all_tables().run() + tf.tables_initializer().run() ids.eval() ==> [0, 1, 3, 2] # where 3 is the out-of-vocabulary bucket ``` @@ -893,7 +893,7 @@ def string_to_index_table_from_tensor(mapping, The bucket ID range is `[mapping size, mapping size + num_oov_buckets]`. The underlying table must be initialized by calling - `tf.initialize_all_tables.run()` or `table.init.run()` once. + `tf.tables_initializer.run()` or `table.init.run()` once. Elements in `mapping` cannot have duplicates, otherwise when executing the table initializer op, it will throw a `FailedPreconditionError`. @@ -907,7 +907,7 @@ def string_to_index_table_from_tensor(mapping, features = tf.constant(["emerson", "lake", "and", "palmer"]) ids = table.lookup(features) ... - tf.initialize_all_tables().run() + tf.tables_initializer().run() ids.eval() ==> [0, 1, 4, 2] ``` @@ -975,7 +975,7 @@ def string_to_index(tensor, mapping, default_value=-1, name=None): will throw a FailedPreconditionError. The underlying table must be initialized by calling - `tf.initialize_all_tables.run()` once. + `tf.tables_initializer.run()` once. For example: @@ -985,7 +985,7 @@ def string_to_index(tensor, mapping, default_value=-1, name=None): ids = tf.contrib.lookup.string_to_index( feats, mapping=mapping_strings, default_value=-1) ... - tf.initialize_all_tables().run() + tf.tables_initializer().run() ids.eval() ==> [0, 1, -1, 2] ``` @@ -1022,7 +1022,7 @@ def index_to_string_table_from_file(vocabulary_file, (an out-of-vocabulary entry) is assigned the `default_value` The underlying table must be initialized by calling - `tf.initialize_all_tables.run()` or `table.init.run()` once. + `tf.tables_initializer.run()` or `table.init.run()` once. Sample Usages: @@ -1040,7 +1040,7 @@ def index_to_string_table_from_file(vocabulary_file, vocabulary_file="test.txt", default_value="UNKNOWN") values = table.lookup(indices) ... - tf.initialize_all_tables().run() + tf.tables_initializer().run() values.eval() ==> ["lake", "UNKNOWN"] ``` @@ -1096,7 +1096,7 @@ def index_to_string_table_from_tensor(mapping, default_value="UNK", name=None): (an out-of-vocabulary entry) is assigned the `default_value` The underlying table must be initialized by calling - `tf.initialize_all_tables.run()` or `table.init.run()` once. + `tf.tables_initializer.run()` or `table.init.run()` once. Elements in `mapping` cannot have duplicates, otherwise when executing the table initializer op, it will throw a `FailedPreconditionError`. @@ -1110,7 +1110,7 @@ def index_to_string_table_from_tensor(mapping, default_value="UNK", name=None): mapping_string, default_value="UNKNOWN") values = table.lookup(indices) ... - tf.initialize_all_tables().run() + tf.tables_initializer().run() values.eval() ==> ["lake", "UNKNOWN"] ``` @@ -1159,7 +1159,7 @@ def index_to_string(tensor, mapping, default_value="UNK", name=None): (an out-of-vocabulary entry) is assigned the `default_value` The underlying table must be initialized by calling - `tf.initialize_all_tables.run()` once. + `tf.tables_initializer.run()` once. For example: @@ -1169,7 +1169,7 @@ def index_to_string(tensor, mapping, default_value="UNK", name=None): values = tf.contrib.lookup.index_to_string( indices, mapping=mapping_string, default_value="UNKNOWN") ... - tf.initialize_all_tables().run() + tf.tables_initializer().run() values.eval() ==> ["lake", "UNKNOWN"] ``` diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 1ec6a231c8a..15c318b6ef7 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -125,7 +125,7 @@ class HashTableOpTest(test.TestCase): table3 = lookup_ops.HashTable( lookup_ops.KeyValueTensorInitializer(keys, values), default_val) - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual(3, table1.size().eval()) self.assertAllEqual(3, table2.size().eval()) self.assertAllEqual(3, table3.size().eval()) @@ -1148,7 +1148,7 @@ class StringToIndexTableFromFile(test.TestCase): ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) def test_string_to_index_table_from_file_with_default_value(self): @@ -1160,7 +1160,7 @@ class StringToIndexTableFromFile(test.TestCase): ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual((1, 2, default_value), ids.eval()) def test_string_to_index_table_from_file_with_oov_buckets(self): @@ -1172,7 +1172,7 @@ class StringToIndexTableFromFile(test.TestCase): constant_op.constant(["salad", "surgery", "tarkus", "toccata"])) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual( ( 1, # From vocabulary file. @@ -1195,7 +1195,7 @@ class StringToIndexTableFromFile(test.TestCase): ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual((1, -1, -1), ids.eval()) self.assertEqual(2, table.size().eval()) @@ -1222,7 +1222,7 @@ class StringToIndexTableFromFile(test.TestCase): ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual((1, 2, -1), ids.eval()) self.assertEqual(3, table.size().eval()) @@ -1255,7 +1255,7 @@ class StringToIndexTableFromTensor(test.TestCase): ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) def test_string_to_index_table_from_tensor_with_default_value(self): @@ -1266,7 +1266,7 @@ class StringToIndexTableFromTensor(test.TestCase): ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) self.assertRaises(errors_impl.OpError, ids.eval) - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual((1, 2, default_value), ids.eval()) def test_string_to_index_table_from_tensor_with_only_oov_buckets(self): @@ -1301,7 +1301,7 @@ class StringToIndexTest(test.TestCase): indices = lookup_ops.string_to_index(feats, mapping=mapping_strings) self.assertRaises(errors_impl.OpError, indices.eval) - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual((1, 2, -1), indices.eval()) @@ -1312,7 +1312,7 @@ class StringToIndexTest(test.TestCase): indices = lookup_ops.string_to_index(feats, mapping=mapping_strings) self.assertRaises(errors_impl.OpError, - data_flow_ops.initialize_all_tables().run) + data_flow_ops.tables_initializer().run) def test_string_to_index_with_default_value(self): default_value = -42 @@ -1323,7 +1323,7 @@ class StringToIndexTest(test.TestCase): feats, mapping=mapping_strings, default_value=default_value) self.assertRaises(errors_impl.OpError, indices.eval) - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual((1, 2, default_value), indices.eval()) @@ -1342,7 +1342,7 @@ class IndexToStringTableFromFileTest(test.TestCase): vocabulary_file=vocabulary_file) features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64)) self.assertRaises(errors_impl.OpError, features.eval) - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), features.eval()) @@ -1354,7 +1354,7 @@ class IndexToStringTableFromFileTest(test.TestCase): vocabulary_file=vocabulary_file, default_value=default_value) features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) self.assertRaises(errors_impl.OpError, features.eval) - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual((b"salad", b"surgery", default_value), features.eval()) @@ -1368,7 +1368,7 @@ class IndexToStringTableFromFileTest(test.TestCase): default_value=default_value) features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) self.assertRaises(errors_impl.OpError, features.eval) - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual((b"salad", default_value, default_value), features.eval()) @@ -1380,7 +1380,7 @@ class IndexToStringTableFromFileTest(test.TestCase): features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) self.assertRaises(errors_impl.OpError, features.eval) - init = data_flow_ops.initialize_all_tables() + init = data_flow_ops.tables_initializer() self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "Invalid vocab_size", init.run) @@ -1392,7 +1392,7 @@ class IndexToStringTableFromFileTest(test.TestCase): features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) self.assertRaises(errors_impl.OpError, features.eval) - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual((b"salad", b"surgery", b"UNK"), features.eval()) @@ -1407,7 +1407,7 @@ class IndexToStringTableFromTensorTest(test.TestCase): indices = constant_op.constant([0, 1, 2, 3], dtypes.int64) features = table.lookup(indices) self.assertRaises(errors_impl.OpError, features.eval) - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), features.eval()) @@ -1419,7 +1419,7 @@ class IndexToStringTableFromTensorTest(test.TestCase): mapping=mapping_strings) indices = constant_op.constant([0, 1, 4], dtypes.int64) features = table.lookup(indices) - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual((b"hello", b"hello", b"UNK"), features.eval()) def test_index_to_string_with_default_value(self): @@ -1432,7 +1432,7 @@ class IndexToStringTableFromTensorTest(test.TestCase): features = table.lookup(indices) self.assertRaises(errors_impl.OpError, features.eval) - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual((b"salad", b"surgery", default_value), features.eval()) @@ -1446,7 +1446,7 @@ class IndexToStringTest(test.TestCase): feats = lookup_ops.index_to_string(indices, mapping=mapping_strings) self.assertRaises(errors_impl.OpError, feats.eval) - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), feats.eval()) @@ -1456,11 +1456,11 @@ class IndexToStringTest(test.TestCase): mapping_strings = constant_op.constant(["hello", "hello"]) indices = constant_op.constant([0, 1, 4], dtypes.int64) feats = lookup_ops.index_to_string(indices, mapping=mapping_strings) - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual((b"hello", b"hello", b"UNK"), feats.eval()) self.assertRaises(errors_impl.OpError, - data_flow_ops.initialize_all_tables().run) + data_flow_ops.tables_initializer().run) def test_index_to_string_with_default_value(self): default_value = b"NONE" @@ -1471,7 +1471,7 @@ class IndexToStringTest(test.TestCase): indices, mapping=mapping_strings, default_value=default_value) self.assertRaises(errors_impl.OpError, feats.eval) - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() self.assertAllEqual((b"salad", b"surgery", default_value), feats.eval()) @@ -1615,7 +1615,7 @@ class InitializeTableFromFileOpTest(test.TestCase): default_value, shared_name=shared_name) - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() input_string = constant_op.constant(["brain", "salad", "tank"]) @@ -1847,7 +1847,7 @@ class IdTableWithHashBucketsTest(test.TestCase): hasher_spec=lookup_ops.StrongHashSpec((1, 2)), name="table2") - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() input_string = constant_op.constant( ["fruit", "brain", "salad", "surgery", "UNK"]) @@ -1933,7 +1933,7 @@ class IdTableWithHashBucketsTest(test.TestCase): default_value2), oov_buckets) - data_flow_ops.initialize_all_tables().run() + data_flow_ops.tables_initializer().run() input_string_1 = constant_op.constant( ["brain", "salad", "surgery", "UNK"]) diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile index 0d107b646c7..284a5894cf2 100644 --- a/tensorflow/contrib/makefile/Makefile +++ b/tensorflow/contrib/makefile/Makefile @@ -137,7 +137,7 @@ $(shell mkdir -p $(DEPDIR) >/dev/null) # Settings for the target compiler. CXX := $(CC_PREFIX) gcc -OPTFLAGS := -O2 +OPTFLAGS := -O2 -march=native CXXFLAGS := --std=c++11 -DIS_SLIM_BUILD -fno-exceptions -DNDEBUG $(OPTFLAGS) LDFLAGS := \ -L/usr/local/lib diff --git a/tensorflow/contrib/session_bundle/bundle_shim.cc b/tensorflow/contrib/session_bundle/bundle_shim.cc index fd0ea9a2fb7..9c7cdf192da 100644 --- a/tensorflow/contrib/session_bundle/bundle_shim.cc +++ b/tensorflow/contrib/session_bundle/bundle_shim.cc @@ -197,14 +197,13 @@ Status ConvertNamedSignaturesToSignatureDef(const Signatures& signatures, continue; } } - SignatureDef signature_def; const Signature signature = it_named_signature.second; if (IsRegressionSignature(signature)) { (*meta_graph_def->mutable_signature_def())[key] = BuildRegressionSignatureDef(signature.regression_signature()); } else if (IsClassificationSignature(signature)) { - (*meta_graph_def->mutable_signature_def())[key] = signature_def; - BuildClassificationSignatureDef(signature.classification_signature()); + (*meta_graph_def->mutable_signature_def())[key] = + BuildClassificationSignatureDef(signature.classification_signature()); } else { LOG(WARNING) << "Named signature up-conversion to SignatureDef is only supported " diff --git a/tensorflow/contrib/session_bundle/bundle_shim_test.cc b/tensorflow/contrib/session_bundle/bundle_shim_test.cc index fb367beb0f9..dd318d08ace 100644 --- a/tensorflow/contrib/session_bundle/bundle_shim_test.cc +++ b/tensorflow/contrib/session_bundle/bundle_shim_test.cc @@ -213,18 +213,32 @@ TEST(BundleShimTest, DefaultSignatureGeneric) { EXPECT_EQ(0, meta_graph_def.signature_def_size()); } +// Helper function to validate that the SignatureDef found in the MetaGraphDef +// with the provided key has the expected string representation. +void ValidateSignatureDef(const MetaGraphDef& meta_graph_def, const string& key, + const string& expected_string_signature_def) { + tensorflow::SignatureDef expected_signature; + CHECK(protobuf::TextFormat::ParseFromString(expected_string_signature_def, + &expected_signature)); + auto iter = meta_graph_def.signature_def().find(key); + ASSERT_TRUE(iter != meta_graph_def.signature_def().end()); + EXPECT_EQ(expected_signature.DebugString(), iter->second.DebugString()); +} + TEST(BundleShimTest, NamedRegressionSignatures) { Signatures signatures; - RegressionSignature* inputs_regression_signature = - (*signatures.mutable_named_signatures())[kRegressInputs] + RegressionSignature* foo_regression_signature = + (*signatures.mutable_named_signatures())["foo"] .mutable_regression_signature(); - inputs_regression_signature->mutable_input()->set_tensor_name("foo-input"); + foo_regression_signature->mutable_input()->set_tensor_name("foo-input"); + foo_regression_signature->mutable_output()->set_tensor_name("foo-output"); - RegressionSignature* outputs_regression_signature = - (*signatures.mutable_named_signatures())[kRegressOutputs] + RegressionSignature* bar_regression_signature = + (*signatures.mutable_named_signatures())["bar"] .mutable_regression_signature(); - outputs_regression_signature->mutable_output()->set_tensor_name("foo-output"); + bar_regression_signature->mutable_input()->set_tensor_name("bar-input"); + bar_regression_signature->mutable_output()->set_tensor_name("bar-output"); MetaGraphDef meta_graph_def; (*meta_graph_def.mutable_collection_def())[kSignaturesKey] @@ -232,7 +246,36 @@ TEST(BundleShimTest, NamedRegressionSignatures) { ->add_value() ->PackFrom(signatures); ConvertSignaturesToSignatureDefs(&meta_graph_def); - EXPECT_EQ(2, meta_graph_def.signature_def_size()); + ASSERT_EQ(2, meta_graph_def.signature_def_size()); + + ValidateSignatureDef(meta_graph_def, "foo", + "inputs { " + " key: \"inputs\" " + " value { " + "name: \"foo-input\" " + " } " + "} " + "outputs { " + " key: \"outputs\" " + " value { " + " name: \"foo-output\" " + " } " + "} " + "method_name: \"tensorflow/serving/regress\" "); + ValidateSignatureDef(meta_graph_def, "bar", + "inputs { " + " key: \"inputs\" " + " value { " + "name: \"bar-input\" " + " } " + "} " + "outputs { " + " key: \"outputs\" " + " value { " + " name: \"bar-output\" " + " } " + "} " + "method_name: \"tensorflow/serving/regress\" "); } TEST(BundleShimTest, NamedClassificationSignatures) { @@ -257,7 +300,36 @@ TEST(BundleShimTest, NamedClassificationSignatures) { ->add_value() ->PackFrom(signatures); ConvertSignaturesToSignatureDefs(&meta_graph_def); - EXPECT_EQ(2, meta_graph_def.signature_def_size()); + ASSERT_EQ(2, meta_graph_def.signature_def_size()); + + ValidateSignatureDef(meta_graph_def, "foo", + "inputs { " + " key: \"inputs\" " + " value { " + "name: \"foo-input\" " + " } " + "} " + "outputs { " + " key: \"classes\" " + " value { " + " name: \"foo-classes\" " + " } " + "} " + "method_name: \"tensorflow/serving/classify\" "); + ValidateSignatureDef(meta_graph_def, "bar", + "inputs { " + " key: \"inputs\" " + " value { " + "name: \"bar-input\" " + " } " + "} " + "outputs { " + " key: \"scores\" " + " value { " + " name: \"bar-scores\" " + " } " + "} " + "method_name: \"tensorflow/serving/classify\" "); } // Checks the Predict SignatureDef created when the named signatures have diff --git a/tensorflow/contrib/slim/python/slim/learning.py b/tensorflow/contrib/slim/python/slim/learning.py index 415b7665127..b87342873e4 100644 --- a/tensorflow/contrib/slim/python/slim/learning.py +++ b/tensorflow/contrib/slim/python/slim/learning.py @@ -627,7 +627,7 @@ def train(train_op, init_feed_dict: A feed dictionary to use when executing the `init_op`. local_init_op: The local initialization operation. If left to its default value, then the session is initialized by calling - `tf.local_variables_initializer()` and `tf.initialize_all_tables()`. + `tf.local_variables_initializer()` and `tf.tables_initializer()`. init_fn: An optional callable to be executed after `init_op` is called. The callable must accept one argument, the session being initialized. ready_op: Operation to check if the model is ready to use. If left to its @@ -697,7 +697,7 @@ def train(train_op, if local_init_op == _USE_DEFAULT: local_init_op = control_flow_ops.group( tf_variables.local_variables_initializer(), - data_flow_ops.initialize_all_tables()) + data_flow_ops.tables_initializer()) if sync_optimizer is not None and isinstance( sync_optimizer, sync_replicas_optimizer.SyncReplicasOptimizer): diff --git a/tensorflow/contrib/training/python/training/bucket_ops.py b/tensorflow/contrib/training/python/training/bucket_ops.py index 69a4ba36e17..1a2ed1f5b3c 100644 --- a/tensorflow/contrib/training/python/training/bucket_ops.py +++ b/tensorflow/contrib/training/python/training/bucket_ops.py @@ -115,8 +115,10 @@ def bucket(tensors, tensors: The list or dictionary of tensors, representing a single element, to bucket. Nested lists are not supported. which_bucket: An `int32` scalar Tensor taking a value in `[0, num_buckets)`. - batch_size: The new batch size pulled from the queue - (python int or int32 scalar). + batch_size: The new batch size pulled from the queue (all queues will have + the same size). If a list is passed in then each bucket will have a + different batch_size. + (python int, int32 scalar or iterable of integers of length num_buckets). num_buckets: A python integer, the number of buckets. num_threads: An integer. The number of threads enqueuing `tensors`. capacity: An integer. The maximum number of minibatches in the top queue, @@ -145,8 +147,17 @@ def bucket(tensors, Raises: ValueError: If the `shapes` are not specified, and cannot be - inferred from the elements of `tensors`. + inferred from the elements of `tensors` or if batch_size is a sequence + but it's length != num_buckets. """ + batch_size_per_bucket = False + if isinstance(batch_size, (list, tuple)): + batch_size_per_bucket = True + if len(batch_size) != num_buckets: + raise ValueError( + "If batch_size is a list it must have num_buckets elements") + else: + batch_size = [batch_size] * num_buckets tensor_list = _as_tensor_list(tensors) with ops.name_scope(name, "bucket", tensor_list) as name: tensor_list = _validate_bucket(tensor_list) @@ -154,11 +165,12 @@ def bucket(tensors, tensor_list, enqueue_many=False, keep_input=constant_op.constant(True)) # Round-trip batch_size to a tensor, and possibly back - batch_size = ops.convert_to_tensor( - batch_size, dtype=dtypes.int32, name="batch_size") - static_batch_size = tensor_util.constant_value(batch_size) - batch_size = (static_batch_size if static_batch_size is not None else - batch_size) + for i, bucket_batch_size in enumerate(batch_size): + bucket_batch_size = ops.convert_to_tensor( + bucket_batch_size, dtype=dtypes.int32, name="batch_size") + static_batch_size = tensor_util.constant_value(bucket_batch_size) + batch_size[i] = (static_batch_size if static_batch_size is not None else + bucket_batch_size) types = _dtypes([tensor_list]) shapes = _shapes([tensor_list], shapes, enqueue_many=False) @@ -179,8 +191,9 @@ def bucket(tensors, shared_name=shared_name_i, name="bucket_queue_%d" % i)) - maybe_static_batch_size = (None if allow_smaller_final_batch else - static_batch_size) + maybe_static_batch_size = ( + None if (allow_smaller_final_batch or batch_size_per_bucket) + else static_batch_size) bucket_shapes = [ tensor_shape.vector(maybe_static_batch_size).concatenate(s) @@ -229,9 +242,9 @@ def bucket(tensors, enqueues_to_top = [ top_queue.enqueue( [constant_op.constant(i)] + which_dequeue(q)( - batch_size, name="read_bucket_%d" % i), + bs, name="read_bucket_%d" % i), name="enqueue_from_bucket_%d" % i) - for i, q in enumerate(bucket_queues) + for i, (q, bs) in enumerate(zip(bucket_queues, batch_size)) ] for i, q in enumerate(bucket_queues): @@ -284,8 +297,10 @@ def bucket_by_sequence_length(input_length, input_length: `int32` scalar `Tensor`, the sequence length of tensors. tensors: The list or dictionary of tensors, representing a single element, to bucket. Nested lists are not supported. - batch_size: The new batch size pulled from the queue - (python int or int32 scalar). + batch_size: The new batch size pulled from the queue (all queues will have + the same size). If a list is passed in then each bucket will have a + different batch_size. + (python int, int32 scalar or iterable of integers of length num_buckets). bucket_boundaries: int list, increasing non-negative numbers. The edges of the buckets to use when bucketing tensors. Two extra buckets are created, one for `input_length < bucket_boundaries[0]` and @@ -317,7 +332,8 @@ def bucket_by_sequence_length(input_length, Raises: TypeError: if `bucket_boundaries` is not a list of python integers. ValueError: if `bucket_boundaries` is empty or contains non-increasing - values. + values or if batch_size is a list and it's length doesn't equal the number + of buckets. """ tensor_list = _as_tensor_list(tensors) if not isinstance(bucket_boundaries, (list, tuple)): diff --git a/tensorflow/contrib/training/python/training/bucket_ops_test.py b/tensorflow/contrib/training/python/training/bucket_ops_test.py index afceb5d6881..ca7074256b8 100644 --- a/tensorflow/contrib/training/python/training/bucket_ops_test.py +++ b/tensorflow/contrib/training/python/training/bucket_ops_test.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -139,6 +140,51 @@ class BucketTest(test.TestCase): self.assertAllEqual(expected_unk_int64, bucketed_values[1][1][resort]) self.assertAllEqual(expected_vec3_str, bucketed_values[1][2][resort]) + def testBatchSizePerBucket(self): + which_bucket = control_flow_ops.cond(self.scalar_int < 5, + lambda: constant_op.constant(0), + lambda: constant_op.constant(1)) + batch_sizes = [5, 10] + bucketed_dynamic = bucket_ops.bucket( + tensors=[self.scalar_int, self.unk_int64, self.vec3_str], + which_bucket=which_bucket, + num_buckets=2, + batch_size=batch_sizes, + num_threads=1, + dynamic_pad=True) + # Check shape inference on bucketing outputs + self.assertAllEqual( + [[None], [None, None], [None, 3]], + [out.get_shape().as_list() for out in bucketed_dynamic[1]]) + with self.test_session() as sess: + for v in range(15): + self.enqueue_inputs(sess, { + self.scalar_int_feed: v, + self.unk_int64_feed: v * [v], + self.vec3_str_feed: 3 * [str(v)] + }) + self.start_queue_runners(sess) + + # Get two minibatches (one with small values, one with large). + bucketed_values_0 = sess.run(bucketed_dynamic) + bucketed_values_1 = sess.run(bucketed_dynamic) + + # Figure out which output has the small values + if bucketed_values_0[0] < 5: + bucketed_values_large, bucketed_values_small = (bucketed_values_1, + bucketed_values_0) + else: + bucketed_values_small, bucketed_values_large = (bucketed_values_0, + bucketed_values_1) + + # Ensure bucket 0 was used for all minibatch entries. + self.assertAllEqual(0, bucketed_values_small[0]) + self.assertAllEqual(1, bucketed_values_large[0]) + + # Check that the batch sizes differ per bucket + self.assertEqual(5, len(bucketed_values_small[1][0])) + self.assertEqual(10, len(bucketed_values_large[1][0])) + def testEvenOddBuckets(self): which_bucket = (self.scalar_int % 2) bucketed_dynamic = bucket_ops.bucket( diff --git a/tensorflow/contrib/training/python/training/training.py b/tensorflow/contrib/training/python/training/training.py index 7243846105a..448fabc1ef3 100644 --- a/tensorflow/contrib/training/python/training/training.py +++ b/tensorflow/contrib/training/python/training/training.py @@ -284,10 +284,10 @@ def add_gradients_summaries(grads_and_vars): else: grad_values = grad summaries.append( - summary.histogram_summary(var.op.name + ':gradient', grad_values)) + summary.histogram(var.op.name + '_gradient', grad_values)) summaries.append( - summary.histogram_summary(var.op.name + ':gradient_norm', - clip_ops.global_norm([grad_values]))) + summary.histogram(var.op.name + '_gradient_norm', + clip_ops.global_norm([grad_values]))) else: logging.info('Var %s has no gradient', var.op.name) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index d3ffd692b28..c27cc488058 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -64,6 +64,7 @@ load( "//tensorflow:tensorflow.bzl", "if_android", "if_ios", + "if_not_mobile", "if_not_windows", "tf_copts", "tf_cc_test", @@ -100,6 +101,8 @@ load( "tf_additional_test_deps", "tf_additional_test_srcs", "tf_kernel_tests_linkstatic", + "tf_additional_cloud_op_deps", + "tf_additional_cloud_kernel_deps", ) load( "//tensorflow/core:platform/default/build_config_root.bzl", @@ -454,6 +457,16 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "cloud_ops_op_lib", + srcs = ["ops/cloud_ops.cc"], + copts = tf_copts(), + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [":framework"], + alwayslink = 1, +) + cc_library( name = "ops", visibility = ["//visibility:public"], @@ -484,7 +497,7 @@ cc_library( ":training_ops_op_lib", ":user_ops_op_lib", ":word2vec_ops", - ], + ] + tf_additional_cloud_op_deps(), alwayslink = 1, ) @@ -606,7 +619,7 @@ cc_library( "//tensorflow/core/kernels:string", "//tensorflow/core/kernels:training_ops", "//tensorflow/core/kernels:word2vec_kernels", - ] + if_not_windows([ + ] + tf_additional_cloud_kernel_deps() + if_not_windows([ "//tensorflow/core/kernels:fact_op", "//tensorflow/core/kernels:array_not_windows", "//tensorflow/core/kernels:math_not_windows", diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h index 969f6a5b764..07c6bdd6831 100644 --- a/tensorflow/core/common_runtime/device.h +++ b/tensorflow/core/common_runtime/device.h @@ -111,10 +111,10 @@ class Device : public DeviceBase { // // 'library' provides access to the function library which is shared // between all device partitions. - // 'graphdef' supplies the partition of the graph assigned to this + // 'graph' supplies the partition of the graph assigned to this // device. virtual Status MaybeRewriteGraph(const FunctionDefLibrary& /*library*/, - GraphDef* /*graphdef*/) { + std::unique_ptr<Graph>* /*graph*/) { return Status::OK(); } diff --git a/tensorflow/core/common_runtime/device_mgr.h b/tensorflow/core/common_runtime/device_mgr.h index d41931eed8f..bb1ed726408 100644 --- a/tensorflow/core/common_runtime/device_mgr.h +++ b/tensorflow/core/common_runtime/device_mgr.h @@ -34,6 +34,7 @@ class DeviceAttributes; class DeviceMgr { public: + // Takes ownership of each device in 'devices'. // TODO(zhifengc): Other initialization information. explicit DeviceMgr(const std::vector<Device*>& devices); ~DeviceMgr(); diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 7dc6db682e7..85ce9d772a1 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/memory_types.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/common_runtime/simple_placer.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/framework/function.h" @@ -940,7 +941,7 @@ Status DirectSession::GetOrCreateExecutors( GraphOptimizer optimizer(optimizer_opts); for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) { const string& partition_name = iter->first; - Graph* partition_graph = iter->second.get(); + std::unique_ptr<Graph>& partition_graph = iter->second; const int graph_def_version = partition_graph->versions().producer(); Device* device; @@ -980,24 +981,23 @@ Status DirectSession::GetOrCreateExecutors( }; params.node_outputs_cb = node_outputs_callback_; - partition_graph = iter->second.release(); - optimizer.Optimize(lib, options_.env, device, &partition_graph); + optimizer.Optimize(lib, options_.env, device, &iter->second); // EXPERIMENTAL: tfdbg inserts debug nodes (i.e., probes) to the graph if (run_state_args->debugger_state) { TF_RETURN_IF_ERROR(run_state_args->debugger_state->DecorateGraphForDebug( - partition_graph, params.device)); + partition_graph.get(), params.device)); } - iter->second.reset(partition_graph); TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()), - device->name(), partition_graph)); + device->name(), + partition_graph.get())); // NewLocalExecutor takes ownership of partition_graph. - item->graph = partition_graph; + item->graph = partition_graph.get(); item->executor = nullptr; Executor* executor; TF_RETURN_IF_ERROR( - NewLocalExecutor(params, iter->second.release(), &executor)); + NewLocalExecutor(params, partition_graph.release(), &executor)); item->executor.reset(executor); } @@ -1118,12 +1118,31 @@ Status DirectSession::CreateGraphs( } } - Status s; - for (auto&& partition : partitions) { - const string& partition_name = partition.first; + for (const auto& partition : partitions) { + std::unique_ptr<Graph> device_graph( + new Graph(client_graph->flib_def.get())); + GraphConstructorOptions device_opts; + // There are internal operations (e.g., send/recv) that we now allow. + device_opts.allow_internal_ops = true; + device_opts.expect_device_spec = true; + TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second, + device_graph.get())); + outputs->emplace(partition.first, std::move(device_graph)); + } - GraphDef* graph_def = &partition.second; - VLOG(2) << "Created " << ProtoDebugString(*graph_def) << " for " + GraphOptimizationPassOptions optimization_options; + optimization_options.session_options = &options_; + optimization_options.flib_def = client_graph->flib_def.get(); + optimization_options.partition_graphs = outputs; + TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( + OptimizationPassRegistry::POST_PARTITIONING, optimization_options)); + + Status s; + for (auto& partition : *outputs) { + const string& partition_name = partition.first; + std::unique_ptr<Graph>* graph = &partition.second; + + VLOG(2) << "Created " << DebugString(graph->get()) << " for " << partition_name; // Give the device an opportunity to rewrite its subgraph. @@ -1134,20 +1153,10 @@ Status DirectSession::CreateGraphs( // may be possible use cases where a device may want to modify // function definitions - in which case the library would need to be // replicated per device. - s = d->MaybeRewriteGraph(client_graph->flib_def->ToProto(), graph_def); + s = d->MaybeRewriteGraph(client_graph->flib_def->ToProto(), graph); if (!s.ok()) { break; } - std::unique_ptr<Graph> device_graph( - new Graph(client_graph->flib_def.get())); - GraphConstructorOptions device_opts; - // There are internal operations (e.g., send/recv) that we now - // allow. - device_opts.allow_internal_ops = true; - device_opts.expect_device_spec = true; - TF_RETURN_IF_ERROR( - ConvertGraphDefToGraph(device_opts, *graph_def, device_graph.get())); - outputs->emplace(partition_name, std::move(device_graph)); } *flib_def = std::move(client_graph->flib_def); return s; diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 11e0b3a0421..cb7e1a40ceb 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -463,7 +463,7 @@ void DumpGraph(StringPiece label, const Graph* g) { } } -void OptimizeGraph(FunctionLibraryRuntime* lib, Graph** g) { +void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g) { OptimizerOptions opts; opts.set_do_common_subexpression_elimination(true); opts.set_do_function_inlining(true); @@ -475,16 +475,12 @@ void OptimizeGraph(FunctionLibraryRuntime* lib, Graph** g) { Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) { const FunctionBody* fbody = GetFunctionBody(handle); CHECK_NOTNULL(fbody); - Graph* g = new Graph(lib_def_); - CopyGraph(*fbody->graph, g); + std::unique_ptr<Graph> g(new Graph(lib_def_)); + CopyGraph(*fbody->graph, g.get()); optimizer_.Optimize(this, env(), device(), &g); - auto s = EnsureMemoryTypes(DeviceType(device()->device_type()), - device()->name(), g); - if (!s.ok()) { - delete g; - return Status::OK(); - } + TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device()->device_type()), + device()->name(), g.get())); // Creates an executor based on the g. This must be done without // holding mu_ because create_kernel_ calls back into the library. @@ -495,11 +491,12 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) { params.delete_kernel = [](OpKernel* kernel) { DeleteNonCachedKernel(kernel); }; + Graph* graph = g.get(); Executor* exec; - TF_RETURN_IF_ERROR(NewLocalExecutor(params, g, &exec)); + TF_RETURN_IF_ERROR(NewLocalExecutor(params, g.release(), &exec)); *item = new Item; - (*item)->graph = g; + (*item)->graph = graph; (*item)->exec = exec; return Status::OK(); } diff --git a/tensorflow/core/common_runtime/function.h b/tensorflow/core/common_runtime/function.h index 3a560d0eafc..7cfe6946735 100644 --- a/tensorflow/core/common_runtime/function.h +++ b/tensorflow/core/common_runtime/function.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMMON_RUNTIME_FUNCTION_H_ #include <functional> +#include <memory> #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -126,7 +127,7 @@ void DumpGraph(StringPiece label, const Graph* g); // OptimizeGraph mutates **g extensively and replaces '*g' with a // complete copy. Therefore, the caller should not keep any references // to nodes *g. -void OptimizeGraph(FunctionLibraryRuntime* lib, Graph** g); +void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g); // Convert the Graph of a function to a GraphDef. // diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index b72d2bf12de..f86a8ed5dc0 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -342,9 +342,9 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) { TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) { Init({test::function::XTimesTwo(), test::function::XTimesFour(), test::function::XTimes16()}); - Graph* g = GetFuncBody("XTimes16", {{"T", DT_FLOAT}}); + std::unique_ptr<Graph> g(GetFuncBody("XTimes16", {{"T", DT_FLOAT}})); ASSERT_TRUE(g != nullptr); - ExpandInlineFunctions(lib_, g); + ExpandInlineFunctions(lib_, g.get()); OptimizeGraph(lib_, &g); const char* e0 = R"P( (n2:float) -> (n7:float) { @@ -355,8 +355,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) { n7 = Mul[T=float](n6, n8) } )P"; - EXPECT_EQ(e0, DebugString(g)); - delete g; + EXPECT_EQ(e0, DebugString(g.get())); } TEST_F(FunctionLibraryRuntimeTest, ManySwapsNodeDef) { @@ -380,15 +379,14 @@ TEST_F(FunctionLibraryRuntimeTest, ManySwapsNodeDef) { // Return {{"o", "g:output"}}); Init({test::function::Swap(), func}); - Graph* g = GetFuncBody("ManySwapsNodeDef", {}); + std::unique_ptr<Graph> g(GetFuncBody("ManySwapsNodeDef", {})); ASSERT_TRUE(g != nullptr); OptimizeGraph(lib_, &g); const char* e0 = R"P( (n3:float, n2:float) -> (n3:float) { } )P"; - EXPECT_EQ(e0, DebugString(g)); - delete g; + EXPECT_EQ(e0, DebugString(g.get())); } TEST_F(FunctionLibraryRuntimeTest, ControlDeps) { @@ -414,7 +412,7 @@ TEST_F(FunctionLibraryRuntimeTest, ControlDeps) { {{"o"}, "Add", {"x2:z:0", "y2:z:0"}, {{"T", DT_FLOAT}}}}, {{"o", "o:z:0"}}); Init({test::function::Swap(), func}); - Graph* g = GetFuncBody("ManySwapsFirst", {}); + std::unique_ptr<Graph> g(GetFuncBody("ManySwapsFirst", {})); ASSERT_TRUE(g != nullptr); OptimizeGraph(lib_, &g); @@ -431,8 +429,7 @@ TEST_F(FunctionLibraryRuntimeTest, ControlDeps) { n6 = Add[T=float](n4, n5) } )P"; - EXPECT_EQ(e0, DebugString(g)); - delete g; + EXPECT_EQ(e0, DebugString(g.get())); } TEST_F(FunctionLibraryRuntimeTest, Error_NotFound) { @@ -489,7 +486,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) { )P"; EXPECT_EQ(e0, DebugString(f)); delete f; - auto g = GetGradBody("XTimesTwo", {{"T", DT_FLOAT}}); + std::unique_ptr<Graph> g(GetGradBody("XTimesTwo", {{"T", DT_FLOAT}})); const char* e1 = R"P( (n4:float, n6:float) -> (n7:float) { n2 = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]() @@ -498,7 +495,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) { n7 = SymbolicGradient[Tin={float, float, float}, Tout={float, float}, f=Mul[T=float]](n4, n3, n6) } )P"; - EXPECT_EQ(e1, DebugString(g)); + EXPECT_EQ(e1, DebugString(g.get())); OptimizeGraph(lib_, &g); const char* e2 = R"P( @@ -512,9 +509,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) { n9 = Reshape[T=float, Tshape=int32](n8, n6) } )P"; - EXPECT_EQ(e2, DebugString(g)); - - delete g; + EXPECT_EQ(e2, DebugString(g.get())); } TEST_F(FunctionLibraryRuntimeTest, Gradient_Add) { @@ -591,7 +586,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { Init({test, grad}); - Graph* g = GetFuncBody("TestGrad", {}); + std::unique_ptr<Graph> g(GetFuncBody("TestGrad", {})); ASSERT_TRUE(g != nullptr); const char* e0 = R"P( (n4:float, n3:float) -> (n8:float, n6:float) { @@ -601,9 +596,9 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { n8 = Identity[T=float](n5) } )P"; - EXPECT_EQ(e0, DebugString(g)); + EXPECT_EQ(e0, DebugString(g.get())); - ExpandInlineFunctions(lib_, g); + ExpandInlineFunctions(lib_, g.get()); const char* e1 = R"P( (n4:float, n3:float) -> (n8:float, n6:float) { n10 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 1>]() @@ -625,7 +620,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { n8 = Identity[T=float](n27) } )P"; - EXPECT_EQ(e1, DebugString(g)); + EXPECT_EQ(e1, DebugString(g.get())); OptimizeGraph(lib_, &g); const char* e2 = R"P( @@ -652,8 +647,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { n23 = Reshape[T=float, Tshape=int32](n22, n19) } )P"; - EXPECT_EQ(e2, DebugString(g)); - delete g; + EXPECT_EQ(e2, DebugString(g.get())); } namespace { diff --git a/tensorflow/core/common_runtime/graph_optimizer.cc b/tensorflow/core/common_runtime/graph_optimizer.cc index cbf16fa513a..cd4bf579c9c 100644 --- a/tensorflow/core/common_runtime/graph_optimizer.cc +++ b/tensorflow/core/common_runtime/graph_optimizer.cc @@ -21,6 +21,128 @@ limitations under the License. #include "tensorflow/core/graph/optimizer_cse.h" namespace tensorflow { +namespace { + +// Replaces occurrences of parallel_concat with the implementation based on +// unsafe ops. Sets removed_any to true if any parallel_concats were removed; +// leaves it untouched otherwise. +// TODO(apassos) Use NodeBuilder. +Status RemoveParallelConcat(bool* removed_any, Graph* g) { + gtl::InlinedVector<Node*, 2> matches; + for (Node* n : g->nodes()) { + if (n->type_string() == "ParallelConcat") { + matches.push_back(n); + } + } + for (Node* n : matches) { + AttrSlice n_attrs(n->def()); + auto make_node = [n, g, &n_attrs](string op) { + NodeDef node; + node.set_op(op); + node.set_name(g->NewName(n->name())); + node.set_device(n->def().device()); + string colo; + if (GetNodeAttr(n_attrs, "_class", &colo).ok()) { + AddNodeAttr("_class", colo, &node); + } + return node; + }; + DataType dtype; + TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "T", &dtype)); + TensorShapeProto shape; + TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "shape", &shape)); + // Add the constant shape input to the start node. + NodeDef shape_node_def = make_node("Const"); + AddNodeAttr("dtype", DT_INT32, &shape_node_def); + TensorProto shape_tensor; + shape_tensor.set_dtype(DT_INT32); + shape_tensor.mutable_tensor_shape()->add_dim()->set_size(shape.dim_size()); + for (int i = 0; i < shape.dim_size(); ++i) { + shape_tensor.add_int_val(shape.dim(i).size()); + } + AddNodeAttr("value", shape_tensor, &shape_node_def); + Status status = Status::OK(); + Node* shape_node = g->AddNode(shape_node_def, &status); + if (!status.ok()) return status; + + // Add the start node + NodeDef start_def = make_node("_ParallelConcatStart"); + AddNodeAttr("dtype", dtype, &start_def); + AddNodeAttr("Tshape", DT_INT32, &start_def); + AddNodeAttr("init", false, &start_def); + start_def.add_input(shape_node_def.name()); + Node* start = g->AddNode(start_def, &status); + if (!status.ok()) return status; + // TODO(apassos): make the shape an attr of _ParallelStackBegin. + g->AddEdge(shape_node, 0, start, 0); + + // Add all the inplace_updates. + std::vector<string> control_dependencies; + std::vector<Node*> control_nodes; + int i = 0; + for (const Edge* input_edge : n->in_edges()) { + if (input_edge->IsControlEdge()) { + g->AddControlEdge(input_edge->src(), start); + continue; + } + // Constant index for the update node. + // TODO(apassos): make _ParallelStackUpdate take this as an attr. + NodeDef update_idx_def = make_node("Const"); + AddNodeAttr("dtype", DT_INT64, &update_idx_def); + TensorProto index_tensor; + index_tensor.set_dtype(DT_INT64); + index_tensor.mutable_tensor_shape()->add_dim()->set_size(1); + index_tensor.add_int64_val(i); + AddNodeAttr("value", index_tensor, &update_idx_def); + Node* index = g->AddNode(update_idx_def, &status); + if (!status.ok()) return status; + + NodeDef update_def = make_node("_ParallelConcatUpdate"); + control_dependencies.push_back(update_def.name()); + AddNodeAttr("T", dtype, &update_def); + AddNodeAttr("Tshape", DT_INT64, &update_def); + update_def.add_input(start_def.name()); + update_def.add_input(update_idx_def.name()); + update_def.add_input(strings::StrCat(input_edge->src()->name(), ":", + input_edge->src_output())); + Node* update = g->AddNode(update_def, &status); + if (!status.ok()) return status; + g->AddEdge(start, 0, update, 0); + g->AddEdge(index, 0, update, 1); + g->AddEdge(input_edge->src(), input_edge->src_output(), update, 2); + control_nodes.push_back(update); + + ++i; + } + + // Add the final identity. + NodeDef identity_def = make_node("Identity"); + AddNodeAttr("T", dtype, &identity_def); + identity_def.add_input(start_def.name()); + for (const string& s : control_dependencies) { + identity_def.add_input(strings::StrCat("^", s)); + } + Node* identity_node = g->AddNode(identity_def, &status); + if (!status.ok()) return status; + g->AddEdge(start, 0, identity_node, 0); + for (Node* inp : control_nodes) { + g->AddControlEdge(inp, identity_node); + } + + // Remove the node and redirect edges. + for (auto* e : n->out_edges()) { + if (e->IsControlEdge()) { + g->AddControlEdge(identity_node, e->dst()); + } else { + g->AddEdge(identity_node, 0, e->dst(), e->dst_input()); + } + } + g->RemoveNode(n); + *removed_any = true; + } + return Status::OK(); +} +} GraphOptimizer::GraphOptimizer(const OptimizerOptions& opts) : opts_(opts) { if (opts_.opt_level() >= OptimizerOptions::L1) { @@ -32,8 +154,8 @@ GraphOptimizer::GraphOptimizer(const OptimizerOptions& opts) : opts_(opts) { GraphOptimizer::~GraphOptimizer() {} void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env, - Device* device, Graph** graph) { - Graph* g = *graph; + Device* device, std::unique_ptr<Graph>* graph) { + Graph* g = graph->get(); DumpGraph("Initial", g); bool changed = true; @@ -44,6 +166,11 @@ void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env, DumpGraph("RemoveListArrayConverter", g); changed = true; } + auto s = RemoveParallelConcat(&changed, g); + if (!s.ok()) { + // TODO(apassos): figure out how to halt here. + LOG(WARNING) << s; + } if (opts_.do_function_inlining() && RemoveDeadNodes(g)) { DumpGraph("RemoveDeadNodes", g); changed = true; @@ -78,11 +205,11 @@ void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env, if (!changed) break; } - Graph* copy = new Graph(g->op_registry()); - CopyGraph(*g, copy); - delete g; - *graph = copy; - DumpGraph("ReCopy", *graph); + std::unique_ptr<Graph> copy(new Graph(g->op_registry())); + CopyGraph(*g, copy.get()); + graph->swap(copy); + + DumpGraph("ReCopy", graph->get()); } } // end namespace tensorflow diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h index 04618b229d8..a6b10356ce1 100644 --- a/tensorflow/core/common_runtime/graph_optimizer.h +++ b/tensorflow/core/common_runtime/graph_optimizer.h @@ -35,7 +35,7 @@ class GraphOptimizer { // optimizers so that they can respect constraints if any, that should be // respected. void Optimize(FunctionLibraryRuntime* runtime, Env* env, Device* device, - Graph** graph); + std::unique_ptr<Graph>* graph); private: OptimizerOptions opts_; diff --git a/tensorflow/core/common_runtime/optimization_registry.h b/tensorflow/core/common_runtime/optimization_registry.h index 45c571f847a..adfa17ae9d7 100644 --- a/tensorflow/core/common_runtime/optimization_registry.h +++ b/tensorflow/core/common_runtime/optimization_registry.h @@ -39,8 +39,17 @@ struct GraphOptimizationPassOptions { const CostModel* cost_model = nullptr; FunctionLibraryDefinition* flib_def = nullptr; // Not owned. + + // The graph to optimize, for optimization passes that run before + // partitioning. Null for post-partitioning passes. // An optimization pass may replace *graph with a new graph object. - std::unique_ptr<Graph>* graph; + std::unique_ptr<Graph>* graph = nullptr; + + // Graphs for each partition, if running post-partitioning. Optimization + // passes may alter the graphs, but must not add or remove partitions. + // Null for pre-partitioning passes. + std::unordered_map<string, std::unique_ptr<Graph>>* partition_graphs = + nullptr; }; // Optimization passes are implemented by inheriting from @@ -64,6 +73,7 @@ class OptimizationPassRegistry { PRE_PLACEMENT, // after cost model assignment, before placement. POST_PLACEMENT, // after placement. POST_REWRITE_FOR_EXEC, // after re-write using feed/fetch endpoints. + POST_PARTITIONING, // after partitioning }; // Add an optimization pass to the registry. diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index d53f31d7fab..13ca471f464 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/memory_types.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" @@ -138,37 +139,50 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, TF_RETURN_IF_ERROR(AddControlEdges(popts, &partitions)); } + std::unordered_map<string, std::unique_ptr<Graph>> partition_graphs; + for (const auto& partition : partitions) { + std::unique_ptr<Graph> device_graph(new Graph(item->lib_def)); + GraphConstructorOptions device_opts; + // There are internal operations (e.g., send/recv) that we now allow. + device_opts.allow_internal_ops = true; + device_opts.expect_device_spec = true; + TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second, + device_graph.get())); + partition_graphs.emplace(partition.first, std::move(device_graph)); + } + + GraphOptimizationPassOptions optimization_options; + optimization_options.flib_def = item->lib_def; + optimization_options.partition_graphs = &partition_graphs; + TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( + OptimizationPassRegistry::POST_PARTITIONING, optimization_options)); + LocalExecutorParams params; - Status s; item->units.reserve(partitions.size()); item->graph_mgr = this; const auto& optimizer_opts = graph_options.optimizer_options(); GraphOptimizer optimizer(optimizer_opts); - for (auto&& p : partitions) { + for (auto& p : partition_graphs) { const string& device_name = p.first; - GraphDef* def = &p.second; + std::unique_ptr<Graph>& subgraph = p.second; item->units.resize(item->units.size() + 1); ExecutionUnit* unit = &(item->units.back()); // Find the device. - s = worker_env_->device_mgr->LookupDevice(device_name, &unit->device); + Status s = + worker_env_->device_mgr->LookupDevice(device_name, &unit->device); if (!s.ok()) { // Remove the empty unit from the item as the item destructor wants all // units to have valid devices. item->units.pop_back(); - break; + return s; } - // Construct the subgraph. - Graph* subgraph = new Graph(item->lib_def); // Give the device an opportunity to rewrite its subgraph. - unit->device->MaybeRewriteGraph(gdef.library(), def); - s = ConvertGraphDefToGraph(opts, *def, subgraph); - if (!s.ok()) { - delete subgraph; - break; - } + TF_RETURN_IF_ERROR( + unit->device->MaybeRewriteGraph(gdef.library(), &subgraph)); + // Top-level nodes in the graph uses the op segment to cache // kernels. Therefore, as long as the executor is alive, we need // to ensure the kernels cached for the session are alive. @@ -178,7 +192,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, // Function library runtime. unit->lib = NewFunctionLibraryRuntime( worker_env_->device_mgr, worker_env_->env, unit->device, - def->versions().producer(), item->lib_def, + subgraph->versions().producer(), item->lib_def, graph_options.optimizer_options()); // Construct the root executor for the subgraph. @@ -207,23 +221,18 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, }; optimizer.Optimize(lib, worker_env_->env, params.device, &subgraph); - s = EnsureMemoryTypes(DeviceType(unit->device->device_type()), - unit->device->name(), subgraph); - if (!s.ok()) { - delete subgraph; - break; - } - s = NewLocalExecutor(params, subgraph, &unit->root); - if (!s.ok()) { - break; - } - unit->graph = subgraph; + TF_RETURN_IF_ERROR( + EnsureMemoryTypes(DeviceType(unit->device->device_type()), + unit->device->name(), subgraph.get())); + unit->graph = subgraph.get(); unit->build_cost_model = graph_options.build_cost_model(); if (unit->build_cost_model > 0) { skip_cost_models_ = false; } + TF_RETURN_IF_ERROR( + NewLocalExecutor(params, subgraph.release(), &unit->root)); } - return s; + return Status::OK(); } Status GraphMgr::Register(const string& session, const GraphDef& gdef, diff --git a/tensorflow/core/distributed_runtime/local_master.cc b/tensorflow/core/distributed_runtime/local_master.cc index 684f2652a94..61ead9f31da 100644 --- a/tensorflow/core/distributed_runtime/local_master.cc +++ b/tensorflow/core/distributed_runtime/local_master.cc @@ -85,7 +85,7 @@ Status LocalMaster::PartialRunSetup(CallOptions* call_options, Status LocalMaster::RunStep(CallOptions* call_options, RunStepRequestWrapper* request, - RunStepResponse* response) { + MutableRunStepResponseWrapper* response) { Notification n; Status ret; master_impl_->RunStep(call_options, request, response, @@ -101,6 +101,10 @@ MutableRunStepRequestWrapper* LocalMaster::CreateRunStepRequest() { return new InMemoryRunStepRequest; } +MutableRunStepResponseWrapper* LocalMaster::CreateRunStepResponse() { + return new InMemoryRunStepResponse; +} + Status LocalMaster::CloseSession(CallOptions* call_options, const CloseSessionRequest* request, CloseSessionResponse* response) { diff --git a/tensorflow/core/distributed_runtime/local_master.h b/tensorflow/core/distributed_runtime/local_master.h index 9c80bfdfed2..fe9cd9381c9 100644 --- a/tensorflow/core/distributed_runtime/local_master.h +++ b/tensorflow/core/distributed_runtime/local_master.h @@ -53,10 +53,12 @@ class LocalMaster : public MasterInterface { PartialRunSetupResponse* response) override; Status RunStep(CallOptions* call_options, RunStepRequestWrapper* request, - RunStepResponse* response) override; + MutableRunStepResponseWrapper* response) override; MutableRunStepRequestWrapper* CreateRunStepRequest() override; + MutableRunStepResponseWrapper* CreateRunStepResponse() override; + Status CloseSession(CallOptions* call_options, const CloseSessionRequest* request, CloseSessionResponse* response) override; diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc index 02d25937bc6..23fe908d43f 100644 --- a/tensorflow/core/distributed_runtime/master.cc +++ b/tensorflow/core/distributed_runtime/master.cc @@ -356,7 +356,7 @@ void Master::PartialRunSetup(const PartialRunSetupRequest* req, } void Master::RunStep(CallOptions* opts, const RunStepRequestWrapper* req, - RunStepResponse* resp, MyClosure done) { + MutableRunStepResponseWrapper* resp, MyClosure done) { mu_.lock(); uint64 start_time = env_->env->NowMicros(); MasterSession* session = gtl::FindPtrOrNull(sessions_, req->session_handle()); diff --git a/tensorflow/core/distributed_runtime/master.h b/tensorflow/core/distributed_runtime/master.h index 742e8a2a79e..2bfebc1bfa6 100644 --- a/tensorflow/core/distributed_runtime/master.h +++ b/tensorflow/core/distributed_runtime/master.h @@ -50,7 +50,7 @@ class Master { PartialRunSetupResponse* resp, MyClosure done); void RunStep(CallOptions* opts, const RunStepRequestWrapper* req, - RunStepResponse* resp, MyClosure done); + MutableRunStepResponseWrapper* resp, MyClosure done); void CloseSession(const CloseSessionRequest* req, CloseSessionResponse* resp, MyClosure done); diff --git a/tensorflow/core/distributed_runtime/master_interface.h b/tensorflow/core/distributed_runtime/master_interface.h index 5ddedc09543..bf6a2db3e27 100644 --- a/tensorflow/core/distributed_runtime/master_interface.h +++ b/tensorflow/core/distributed_runtime/master_interface.h @@ -24,10 +24,11 @@ limitations under the License. namespace tensorflow { -// Pure virtual interface for communicating with the TensorFlow Master service. +// Abstract interface for communicating with the TensorFlow Master service. // -// This interface is intended to support in-process master -// implementations that do not require an RPC roundtrip. +// This interface supports both RPC-based master implementations, and +// in-process master implementations that do not require an RPC +// roundtrip. class MasterInterface { public: virtual ~MasterInterface() {} @@ -47,20 +48,36 @@ class MasterInterface { virtual Status RunStep(CallOptions* call_options, RunStepRequestWrapper* request, - RunStepResponse* response) = 0; + MutableRunStepResponseWrapper* response) = 0; virtual Status RunStep(CallOptions* call_options, const RunStepRequest* request, RunStepResponse* response) { std::unique_ptr<RunStepRequestWrapper> wrapped_request( new ProtoRunStepRequest(request)); - return RunStep(call_options, wrapped_request.get(), response); + std::unique_ptr<MutableRunStepResponseWrapper> wrapped_response( + new NonOwnedProtoRunStepResponse(response)); + return RunStep(call_options, wrapped_request.get(), wrapped_response.get()); } + // Returns a request object for use in calls to + // `RunStep()`. Ownership is transferred to the caller. + // + // The message returned from this method must only be used in a + // `RunStep()` call on the same `MasterInterface` instance. virtual MutableRunStepRequestWrapper* CreateRunStepRequest() { return new MutableProtoRunStepRequest; } + // Returns a response object for use in calls to + // `RunStep()`. Ownership is transferred to the caller. + // + // The message returned from this method must only be used in a + // `RunStep()` call on the same `MasterInterface` instance. + virtual MutableRunStepResponseWrapper* CreateRunStepResponse() { + return new OwnedProtoRunStepResponse; + } + virtual Status CloseSession(CallOptions* call_options, const CloseSessionRequest* request, CloseSessionResponse* response) = 0; @@ -71,6 +88,15 @@ class MasterInterface { virtual Status Reset(CallOptions* call_options, const ResetRequest* request, ResetResponse* response) = 0; + + protected: + // NOTE: This should only be called by implementations of this + // interface whose CreateRunStepResponse() method returns a + // proto-based wrappers for the RunStepResponse message. + RunStepResponse* get_proto_from_wrapper( + MutableRunStepResponseWrapper* wrapper) { + return wrapper->get_proto(); + } }; } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 58f97fefcec..52dce0557bf 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -182,7 +182,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { int64 execution_count, SimpleGraphExecutionState* execution_state, PerStepState* pss, CallOptions* opts, - const RunStepRequestWrapper& req, RunStepResponse* resp, + const RunStepRequestWrapper& req, + MutableRunStepResponseWrapper* resp, CancellationManager* cm, const bool is_last_partial_run); // Calls workers to cleanup states for the step "step_id". Calls @@ -193,7 +194,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { void ProcessStats(int64 step_id, PerStepState* pss, SimpleGraphExecutionState* execution_state, ProfileHandler* ph, const RunOptions& options, - RunStepResponse* resp); + RunMetadata* resp); void ProcessDeviceStats(ProfileHandler* ph, const SimpleGraphExecutionState* execution_state, const DeviceStepStats& ds, bool is_rpc); @@ -464,7 +465,7 @@ class RunManyGraphs { struct Call { CallOptions opts; std::unique_ptr<MutableRunGraphRequestWrapper> req; - RunGraphResponse resp; + std::unique_ptr<MutableRunGraphResponseWrapper> resp; }; Call* get(int index) { return &calls_[index]; } @@ -513,7 +514,7 @@ Status MasterSession::ReffedClientGraph::RunPartitions( const MasterEnv* env, int64 step_id, int64 execution_count, SimpleGraphExecutionState* execution_state, PerStepState* pss, CallOptions* call_opts, const RunStepRequestWrapper& req, - RunStepResponse* resp, CancellationManager* cm, + MutableRunStepResponseWrapper* resp, CancellationManager* cm, const bool is_last_partial_run) { VLOG(2) << "RunPartitions step_id " << step_id << " execution_count " << execution_count; @@ -550,6 +551,7 @@ Status MasterSession::ReffedClientGraph::RunPartitions( const Part& part = partitions_[i]; RunManyGraphs::Call* c = calls.get(i); c->req.reset(part.worker->CreateRunGraphRequest()); + c->resp.reset(part.worker->CreateRunGraphResponse()); if (is_partial_) { c->req->set_is_partial(is_partial_); c->req->set_is_last_partial_run(is_last_partial_run); @@ -614,7 +616,7 @@ Status MasterSession::ReffedClientGraph::RunPartitions( RunManyGraphs::Call* call = calls.get(i); TRACEPRINTF("Partition %d %s", i, part.name.c_str()); part.worker->RunGraphAsync( - &call->opts, call->req.get(), &call->resp, + &call->opts, call->req.get(), call->resp.get(), std::bind(&RunManyGraphs::WhenDone, &calls, i, std::placeholders::_1)); } @@ -639,29 +641,28 @@ Status MasterSession::ReffedClientGraph::RunPartitions( if (status.ok()) { for (int i = 0; i < num; ++i) { const Part& part = partitions_[i]; - for (auto& recv : *(calls.get(i)->resp.mutable_recv())) { - auto* ret = resp->add_tensor(); - auto iter = part.key_fetch.find(recv.name()); + for (size_t j = 0; j < calls.get(i)->resp->num_recvs(); ++j) { + auto iter = part.key_fetch.find(calls.get(i)->resp->recv_key(j)); if (iter == part.key_fetch.end()) { - status.Update( - errors::Internal("Unexpected fetch key: ", recv.name())); + status.Update(errors::Internal("Unexpected fetch key: ", + calls.get(i)->resp->recv_key(j))); break; } const string& fetch = iter->second; - ret->set_name(fetch); - if (!CopyIfNeeded(recv.mutable_tensor(), ret->mutable_tensor())) { - status.Update( - errors::Internal("Unexpected unparseable tensor: ", recv.name())); + status.Update(resp->AddTensorFromRunGraphResponse( + fetch, calls.get(i)->resp.get(), j)); + if (!status.ok()) { break; } } - if (pss->collect_timeline && calls.get(i)->resp.has_step_stats()) { - pss->step_stats[i].Swap(calls.get(i)->resp.mutable_step_stats()); + if (pss->collect_timeline) { + pss->step_stats[i].Swap(calls.get(i)->resp->mutable_step_stats()); } - if (pss->collect_costs && calls.get(i)->resp.has_cost_graph()) { - for (int j = 0; j < calls.get(i)->resp.cost_graph().node_size(); ++j) { + if (pss->collect_costs) { + CostGraphDef* cost_graph = calls.get(i)->resp->mutable_cost_graph(); + for (int j = 0; j < cost_graph->node_size(); ++j) { resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap( - calls.get(i)->resp.mutable_cost_graph()->mutable_node(j)); + cost_graph->mutable_node(j)); } } } @@ -739,7 +740,7 @@ void MasterSession::ReffedClientGraph::CleanupPartitionsAsync( void MasterSession::ReffedClientGraph::ProcessStats( int64 step_id, PerStepState* pss, SimpleGraphExecutionState* execution_state, ProfileHandler* ph, - const RunOptions& options, RunStepResponse* resp) { + const RunOptions& options, RunMetadata* resp) { if (!pss->collect_costs && !pss->collect_timeline) return; // Out-of-band logging data is collected now, during post-processing. @@ -775,7 +776,7 @@ void MasterSession::ReffedClientGraph::ProcessStats( // Copy the stats back, but only for on-demand profiling to avoid slowing // down calls that trigger the automatic profiling. if (options.trace_level() == RunOptions::FULL_TRACE) { - resp->mutable_metadata()->mutable_step_stats()->Swap(&step_stats_proto); + resp->mutable_step_stats()->Swap(&step_stats_proto); } } } @@ -1171,7 +1172,7 @@ Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req, } Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req, - RunStepResponse* resp) { + MutableRunStepResponseWrapper* resp) { UpdateLastAccessTime(); { mutex_lock l(mu_); @@ -1239,7 +1240,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) { Status MasterSession::DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req, - RunStepResponse* resp) { + MutableRunStepResponseWrapper* resp) { const string& prun_handle = req.partial_run_handle(); RunState* run_state = nullptr; { @@ -1328,7 +1329,7 @@ Status MasterSession::DoPartialRun(CallOptions* opts, rcg->Ref(); rcg->ProcessStats(run_state->step_id, &run_state->pss, execution_state_.get(), run_state->ph.get(), - req.options(), resp); + req.options(), resp->mutable_metadata()); rcg->CleanupPartitionsAsync( run_state->step_id, [this, rcg, prun_handle](const Status& s) { if (!s.ok()) { @@ -1342,9 +1343,9 @@ Status MasterSession::DoPartialRun(CallOptions* opts, return s; } -Status MasterSession::DoRunWithLocalExecution(CallOptions* opts, - const RunStepRequestWrapper& req, - RunStepResponse* resp) { +Status MasterSession::DoRunWithLocalExecution( + CallOptions* opts, const RunStepRequestWrapper& req, + MutableRunStepResponseWrapper* resp) { VLOG(2) << "DoRunWithLocalExecution " << "req: " << req.DebugString(); PerStepState pss; @@ -1395,7 +1396,7 @@ Status MasterSession::DoRunWithLocalExecution(CallOptions* opts, // Schedule post-processing and cleanup to be done asynchronously. rcg->Ref(); rcg->ProcessStats(step_id, &pss, execution_state_.get(), ph.get(), - req.options(), resp); + req.options(), resp->mutable_metadata()); rcg->CleanupPartitionsAsync(step_id, [rcg](const Status& s) { if (!s.ok()) { LOG(ERROR) << "Cleanup partition error: " << s; diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h index 8ab46539f35..4e78e08559f 100644 --- a/tensorflow/core/distributed_runtime/master_session.h +++ b/tensorflow/core/distributed_runtime/master_session.h @@ -80,7 +80,7 @@ class MasterSession : public core::RefCounted { // Run one step. Status Run(CallOptions* opts, const RunStepRequestWrapper& req, - RunStepResponse* resp); + MutableRunStepResponseWrapper* resp); // Close this session and delete "*this". Returns OK if all known // states are cleanup successfully. @@ -177,9 +177,9 @@ class MasterSession : public core::RefCounted { RCGMap* rcg_map) EXCLUSIVE_LOCKS_REQUIRED(mu_); Status DoRunWithLocalExecution(CallOptions* opts, const RunStepRequestWrapper& req, - RunStepResponse* resp); + MutableRunStepResponseWrapper* resp); Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req, - RunStepResponse* resp); + MutableRunStepResponseWrapper* resp); void UpdateLastAccessTime(); Status BuildAndRegisterPartitions(ReffedClientGraph* rcg); diff --git a/tensorflow/core/distributed_runtime/message_wrappers.cc b/tensorflow/core/distributed_runtime/message_wrappers.cc index 423d0057561..7b58feb93cc 100644 --- a/tensorflow/core/distributed_runtime/message_wrappers.cc +++ b/tensorflow/core/distributed_runtime/message_wrappers.cc @@ -17,6 +17,22 @@ limitations under the License. namespace tensorflow { +namespace { + +bool ParseTensorProtoToTensor(const TensorProto& tensor_proto, + Tensor* out_tensor) { + if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) { + Tensor parsed(tensor_proto.dtype()); + if (parsed.FromProto(cpu_allocator(), tensor_proto)) { + *out_tensor = parsed; + return true; + } + } + return false; +} + +} // namespace + const string& InMemoryRunStepRequest::session_handle() const { return session_handle_; } @@ -38,13 +54,14 @@ const string& InMemoryRunStepRequest::feed_name(size_t i) const { return feeds_[i].first; } -Status InMemoryRunStepRequest::FeedValue(size_t i, Tensor* tensor) const { - *tensor = feeds_[i].second; +Status InMemoryRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const { + *out_tensor = feeds_[i].second; return Status::OK(); } -Status InMemoryRunStepRequest::FeedValue(size_t i, TensorProto* tensor) const { - feeds_[i].second.AsProtoTensorContent(tensor); +Status InMemoryRunStepRequest::FeedValue(size_t i, + TensorProto* out_tensor) const { + feeds_[i].second.AsProtoTensorContent(out_tensor); return Status::OK(); } @@ -117,21 +134,18 @@ size_t MutableProtoRunStepRequest::num_feeds() const { const string& MutableProtoRunStepRequest::feed_name(size_t i) const { return request_.feed(i).name(); } -Status MutableProtoRunStepRequest::FeedValue(size_t i, Tensor* tensor) const { - const TensorProto& tensor_proto = request_.feed(i).tensor(); - if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) { - Tensor parsed(tensor_proto.dtype()); - if (parsed.FromProto(cpu_allocator(), tensor_proto)) { - *tensor = parsed; - return Status::OK(); - } +Status MutableProtoRunStepRequest::FeedValue(size_t i, + Tensor* out_tensor) const { + if (!ParseTensorProtoToTensor(request_.feed(i).tensor(), out_tensor)) { + return errors::InvalidArgument("Invalid TensorProto for feed value ", i); + } else { + return Status::OK(); } - return errors::InvalidArgument("Invalid TensorProto for feed value ", i); } Status MutableProtoRunStepRequest::FeedValue(size_t i, - TensorProto* tensor) const { - *tensor = request_.feed(i).tensor(); + TensorProto* out_tensor) const { + *out_tensor = request_.feed(i).tensor(); return Status::OK(); } @@ -199,20 +213,16 @@ const string& ProtoRunStepRequest::feed_name(size_t i) const { return request_->feed(i).name(); } -Status ProtoRunStepRequest::FeedValue(size_t i, Tensor* tensor) const { - const TensorProto& tensor_proto = request_->feed(i).tensor(); - if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) { - Tensor parsed(tensor_proto.dtype()); - if (parsed.FromProto(cpu_allocator(), tensor_proto)) { - *tensor = parsed; - return Status::OK(); - } +Status ProtoRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const { + if (!ParseTensorProtoToTensor(request_->feed(i).tensor(), out_tensor)) { + return errors::InvalidArgument("Invalid TensorProto for feed value ", i); + } else { + return Status::OK(); } - return errors::InvalidArgument("Invalid TensorProto for feed value ", i); } -Status ProtoRunStepRequest::FeedValue(size_t i, TensorProto* tensor) const { - *tensor = request_->feed(i).tensor(); +Status ProtoRunStepRequest::FeedValue(size_t i, TensorProto* out_tensor) const { + *out_tensor = request_->feed(i).tensor(); return Status::OK(); } @@ -361,15 +371,11 @@ const string& MutableProtoRunGraphRequest::send_key(size_t i) const { Status MutableProtoRunGraphRequest::SendValue(size_t i, Tensor* out_tensor) const { - const TensorProto& tensor_proto = request_.send(i).tensor(); - if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) { - Tensor parsed(tensor_proto.dtype()); - if (parsed.FromProto(cpu_allocator(), tensor_proto)) { - *out_tensor = parsed; - return Status::OK(); - } + if (!ParseTensorProtoToTensor(request_.send(i).tensor(), out_tensor)) { + return errors::InvalidArgument("Invalid TensorProto for feed value ", i); + } else { + return Status::OK(); } - return errors::InvalidArgument("Invalid TensorProto for feed value ", i); } Status MutableProtoRunGraphRequest::AddSendFromRunStepRequest( @@ -434,15 +440,11 @@ const string& ProtoRunGraphRequest::send_key(size_t i) const { } Status ProtoRunGraphRequest::SendValue(size_t i, Tensor* out_tensor) const { - const TensorProto& tensor_proto = request_->send(i).tensor(); - if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) { - Tensor parsed(tensor_proto.dtype()); - if (parsed.FromProto(cpu_allocator(), tensor_proto)) { - *out_tensor = parsed; - return Status::OK(); - } + if (!ParseTensorProtoToTensor(request_->send(i).tensor(), out_tensor)) { + return errors::InvalidArgument("Invalid TensorProto for feed value ", i); + } else { + return Status::OK(); } - return errors::InvalidArgument("Invalid TensorProto for feed value ", i); } size_t ProtoRunGraphRequest::num_recvs() const { @@ -463,4 +465,228 @@ const RunGraphRequest& ProtoRunGraphRequest::ToProto() const { return *request_; } +size_t InMemoryRunGraphResponse::num_recvs() const { return recvs_.size(); } + +const string& InMemoryRunGraphResponse::recv_key(size_t i) const { + return recvs_[i].first; +} + +Status InMemoryRunGraphResponse::RecvValue(size_t i, TensorProto* out_tensor) { + recvs_[i].second.AsProtoTensorContent(out_tensor); + return Status::OK(); +} + +Status InMemoryRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) { + *out_tensor = recvs_[i].second; + return Status::OK(); +} + +void InMemoryRunGraphResponse::AddRecv(const string& key, const Tensor& value) { + recvs_.emplace_back(key, value); +} + +StepStats* InMemoryRunGraphResponse::mutable_step_stats() { + return &step_stats_; +} + +CostGraphDef* InMemoryRunGraphResponse::mutable_cost_graph() { + return &cost_graph_; +} + +RunGraphResponse* InMemoryRunGraphResponse::get_proto() { + LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunGraphResponse"; +} + +size_t OwnedProtoRunGraphResponse::num_recvs() const { + return response_.recv_size(); +} + +const string& OwnedProtoRunGraphResponse::recv_key(size_t i) const { + return response_.recv(i).name(); +} + +Status OwnedProtoRunGraphResponse::RecvValue(size_t i, + TensorProto* out_tensor) { + out_tensor->Swap(response_.mutable_recv(i)->mutable_tensor()); + return Status::OK(); +} + +Status OwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) { + if (!ParseTensorProtoToTensor(response_.recv(i).tensor(), out_tensor)) { + return errors::InvalidArgument("Invalid TensorProto for recv value ", i); + } else { + return Status::OK(); + } +} + +void OwnedProtoRunGraphResponse::AddRecv(const string& key, + const Tensor& value) { + NamedTensorProto* recv = response_.add_recv(); + recv->set_name(key); + TensorProto* value_proto = recv->mutable_tensor(); + value.AsProtoTensorContent(value_proto); +} + +StepStats* OwnedProtoRunGraphResponse::mutable_step_stats() { + return response_.mutable_step_stats(); +} + +CostGraphDef* OwnedProtoRunGraphResponse::mutable_cost_graph() { + return response_.mutable_cost_graph(); +} + +RunGraphResponse* OwnedProtoRunGraphResponse::get_proto() { return &response_; } + +NonOwnedProtoRunGraphResponse::NonOwnedProtoRunGraphResponse( + RunGraphResponse* response) + : response_(response) {} + +size_t NonOwnedProtoRunGraphResponse::num_recvs() const { + return response_->recv_size(); +} + +const string& NonOwnedProtoRunGraphResponse::recv_key(size_t i) const { + return response_->recv(i).name(); +} + +Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i, + TensorProto* out_tensor) { + out_tensor->Swap(response_->mutable_recv(i)->mutable_tensor()); + return Status::OK(); +} + +Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) { + if (!ParseTensorProtoToTensor(response_->recv(i).tensor(), out_tensor)) { + return errors::InvalidArgument("Invalid TensorProto for recv value ", i); + } else { + return Status::OK(); + } +} + +void NonOwnedProtoRunGraphResponse::AddRecv(const string& key, + const Tensor& value) { + NamedTensorProto* recv = response_->add_recv(); + recv->set_name(key); + TensorProto* value_proto = recv->mutable_tensor(); + value.AsProtoTensorContent(value_proto); +} + +StepStats* NonOwnedProtoRunGraphResponse::mutable_step_stats() { + return response_->mutable_step_stats(); +} + +CostGraphDef* NonOwnedProtoRunGraphResponse::mutable_cost_graph() { + return response_->mutable_cost_graph(); +} + +RunGraphResponse* NonOwnedProtoRunGraphResponse::get_proto() { + return response_; +} + +MutableRunStepResponseWrapper::~MutableRunStepResponseWrapper() {} + +size_t InMemoryRunStepResponse::num_tensors() const { return tensors_.size(); } + +const string& InMemoryRunStepResponse::tensor_name(size_t i) const { + return tensors_[i].first; +} + +Status InMemoryRunStepResponse::TensorValue(size_t i, + Tensor* out_tensor) const { + *out_tensor = tensors_[i].second; + return Status::OK(); +} + +const RunMetadata& InMemoryRunStepResponse::metadata() const { + return metadata_; +} + +Status InMemoryRunStepResponse::AddTensorFromRunGraphResponse( + const string& name, MutableRunGraphResponseWrapper* wrapper, size_t i) { + Tensor tensor; + TF_RETURN_IF_ERROR(wrapper->RecvValue(i, &tensor)); + tensors_.emplace_back(name, tensor); + return Status::OK(); +} + +RunMetadata* InMemoryRunStepResponse::mutable_metadata() { return &metadata_; } + +RunStepResponse* InMemoryRunStepResponse::get_proto() { + LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunStepResponse"; +} + +size_t OwnedProtoRunStepResponse::num_tensors() const { + return response_.tensor_size(); +} + +const string& OwnedProtoRunStepResponse::tensor_name(size_t i) const { + return response_.tensor(i).name(); +} + +Status OwnedProtoRunStepResponse::TensorValue(size_t i, + Tensor* out_tensor) const { + if (!ParseTensorProtoToTensor(response_.tensor(i).tensor(), out_tensor)) { + return errors::InvalidArgument("Invalid TensorProto for fetch value ", i); + } else { + return Status::OK(); + } +} + +const RunMetadata& OwnedProtoRunStepResponse::metadata() const { + return response_.metadata(); +} + +Status OwnedProtoRunStepResponse::AddTensorFromRunGraphResponse( + const string& name, MutableRunGraphResponseWrapper* run_graph_response, + size_t i) { + NamedTensorProto* response_tensor = response_.add_tensor(); + response_tensor->set_name(name); + return run_graph_response->RecvValue(i, response_tensor->mutable_tensor()); +} + +RunMetadata* OwnedProtoRunStepResponse::mutable_metadata() { + return response_.mutable_metadata(); +} + +RunStepResponse* OwnedProtoRunStepResponse::get_proto() { return &response_; } + +NonOwnedProtoRunStepResponse::NonOwnedProtoRunStepResponse( + RunStepResponse* response) + : response_(response) {} + +size_t NonOwnedProtoRunStepResponse::num_tensors() const { + return response_->tensor_size(); +} + +const string& NonOwnedProtoRunStepResponse::tensor_name(size_t i) const { + return response_->tensor(i).name(); +} + +Status NonOwnedProtoRunStepResponse::TensorValue(size_t i, + Tensor* out_tensor) const { + if (!ParseTensorProtoToTensor(response_->tensor(i).tensor(), out_tensor)) { + return errors::InvalidArgument("Invalid TensorProto for fetch value ", i); + } else { + return Status::OK(); + } +} + +const RunMetadata& NonOwnedProtoRunStepResponse::metadata() const { + return response_->metadata(); +} + +Status NonOwnedProtoRunStepResponse::AddTensorFromRunGraphResponse( + const string& name, MutableRunGraphResponseWrapper* run_graph_response, + size_t i) { + NamedTensorProto* response_tensor = response_->add_tensor(); + response_tensor->set_name(name); + return run_graph_response->RecvValue(i, response_tensor->mutable_tensor()); +} + +RunMetadata* NonOwnedProtoRunStepResponse::mutable_metadata() { + return response_->mutable_metadata(); +} + +RunStepResponse* NonOwnedProtoRunStepResponse::get_proto() { return response_; } + } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/message_wrappers.h b/tensorflow/core/distributed_runtime/message_wrappers.h index 5e435591e6e..02516eabb4a 100644 --- a/tensorflow/core/distributed_runtime/message_wrappers.h +++ b/tensorflow/core/distributed_runtime/message_wrappers.h @@ -58,8 +58,8 @@ class RunStepRequestWrapper { virtual const string& feed_name(size_t i) const = 0; // Stores the content of the feed value at index `i` in `tensor`. - virtual Status FeedValue(size_t i, Tensor* tensor) const = 0; - virtual Status FeedValue(size_t i, TensorProto* tensor) const = 0; + virtual Status FeedValue(size_t i, Tensor* out_tensor) const = 0; + virtual Status FeedValue(size_t i, TensorProto* out_tensor) const = 0; // Fetches. A list of tensor names. The caller expects a tensor to // be returned for each fetch[i] (see RunStepResponse.tensor). The @@ -104,8 +104,8 @@ class InMemoryRunStepRequest : public MutableRunStepRequestWrapper { const string& partial_run_handle() const override; size_t num_feeds() const override; const string& feed_name(size_t i) const override; - Status FeedValue(size_t i, Tensor* tensor) const override; - Status FeedValue(size_t i, TensorProto* tensor) const override; + Status FeedValue(size_t i, Tensor* out_tensor) const override; + Status FeedValue(size_t i, TensorProto* out_tensor) const override; size_t num_fetches() const override; const string& fetch_name(size_t i) const override; size_t num_targets() const override; @@ -151,8 +151,8 @@ class MutableProtoRunStepRequest : public MutableRunStepRequestWrapper { const string& partial_run_handle() const override; size_t num_feeds() const override; const string& feed_name(size_t i) const override; - Status FeedValue(size_t i, Tensor* tensor) const override; - Status FeedValue(size_t i, TensorProto* tensor) const override; + Status FeedValue(size_t i, Tensor* out_tensor) const override; + Status FeedValue(size_t i, TensorProto* out_tensor) const override; size_t num_fetches() const override; const string& fetch_name(size_t i) const override; size_t num_targets() const override; @@ -188,8 +188,8 @@ class ProtoRunStepRequest : public RunStepRequestWrapper { const string& partial_run_handle() const override; size_t num_feeds() const override; const string& feed_name(size_t i) const override; - Status FeedValue(size_t i, Tensor* tensor) const override; - Status FeedValue(size_t i, TensorProto* tensor) const override; + Status FeedValue(size_t i, Tensor* out_tensor) const override; + Status FeedValue(size_t i, TensorProto* out_tensor) const override; size_t num_fetches() const override; const string& fetch_name(size_t i) const override; size_t num_targets() const override; @@ -373,6 +373,242 @@ class ProtoRunGraphRequest : public RunGraphRequestWrapper { const RunGraphRequest* const request_; // Not owned. }; +//////////////////////////////////////////////////////////////////////////////// +// +// Wrapper classes for the `WorkerService.RunGraph` response message. +// +// The `RunGraphResponse` message can contain potentially large tensor +// data as part of its `recv` submessages. Here we provide specialized +// wrappers that avoid copying the tensor data wherever possible. +// +// See `RunGraphResponse` in tensorflow/core/protobuf/worker.proto for the +// protocol buffer definition. +// +//////////////////////////////////////////////////////////////////////////////// + +// Abstract interface for a mutable RunGraphResponse message. +// +// Note that there is no corresponding (immutable) +// RunGraphResponseWrapper class, because the RunGraphResponse object +// is always used as a mutable pointer. +class MutableRunGraphResponseWrapper { + public: + virtual ~MutableRunGraphResponseWrapper() {} + + // A list of tensors corresponding to those requested by + // `RunGraphRequest.recv_key`. + virtual size_t num_recvs() const = 0; + virtual const string& recv_key(size_t i) const = 0; + // NOTE: The following methods may perform a destructive read, for + // efficiency. + virtual Status RecvValue(size_t i, TensorProto* out_tensor) = 0; + virtual Status RecvValue(size_t i, Tensor* out_tensor) = 0; + virtual void AddRecv(const string& key, const Tensor& value) = 0; + + // Submessages that store performance statistics about the subgraph + // execution, if necessary. + virtual StepStats* mutable_step_stats() = 0; + virtual CostGraphDef* mutable_cost_graph() = 0; + + protected: + // Returns a mutable protobuf message that represents the contents of + // this wrapper, for passing to an RPC subsystem that will populate + // the message. + // + // NOTE: Only `WorkerInterface` subclasses may call this method. The + // `InMemoryRunGraphResponse` subclass does not implement this + // method, and attempts to call it will fail with a fatal + // error. However, as long as callers always call + // `WorkerInterface::RunGraphAsync()` with a wrapper object returned + // from `WorkerInterface::CreateRunGraphResponse()` called on the + // *same* WorkerInterface object, this error will never trigger. + virtual RunGraphResponse* get_proto() = 0; + friend class WorkerInterface; +}; + +class InMemoryRunGraphResponse : public MutableRunGraphResponseWrapper { + public: + // MutableRunGraphResponseWrapper methods. + size_t num_recvs() const override; + const string& recv_key(size_t i) const override; + Status RecvValue(size_t i, TensorProto* out_tensor) override; + Status RecvValue(size_t i, Tensor* out_tensor) override; + void AddRecv(const string& key, const Tensor& value) override; + StepStats* mutable_step_stats() override; + CostGraphDef* mutable_cost_graph() override; + + protected: + // NOTE: This method is not implemented. See + // MutableRunGraphResponseWrapper for an explanation. + RunGraphResponse* get_proto() override; + + private: + gtl::InlinedVector<std::pair<string, Tensor>, 4> recvs_; + StepStats step_stats_; + CostGraphDef cost_graph_; +}; + +// Proto-based message wrapper for use on the client side of the RunGraph RPC. +class OwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper { + public: + // MutableRunGraphResponseWrapper methods. + size_t num_recvs() const override; + const string& recv_key(size_t i) const override; + Status RecvValue(size_t i, TensorProto* out_tensor) override; + Status RecvValue(size_t i, Tensor* out_tensor) override; + void AddRecv(const string& key, const Tensor& value) override; + StepStats* mutable_step_stats() override; + CostGraphDef* mutable_cost_graph() override; + + protected: + RunGraphResponse* get_proto() override; + + private: + RunGraphResponse response_; +}; + +// Proto-based message wrapper for use on the server side of the RunGraph RPC. +class NonOwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper { + public: + NonOwnedProtoRunGraphResponse(RunGraphResponse* response); + + // MutableRunGraphResponseWrapper methods. + size_t num_recvs() const override; + const string& recv_key(size_t i) const override; + Status RecvValue(size_t i, TensorProto* out_tensor) override; + Status RecvValue(size_t i, Tensor* out_tensor) override; + void AddRecv(const string& key, const Tensor& value) override; + StepStats* mutable_step_stats() override; + CostGraphDef* mutable_cost_graph() override; + + protected: + RunGraphResponse* get_proto() override; + + private: + RunGraphResponse* const response_; +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Wrapper classes for the `MasterService.RunStep` response message. +// +// The `RunStepResponse` message can contain potentially large tensor +// data as part of its `tensor` submessages. Here we provide specialized +// wrappers that avoid copying the tensor data wherever possible. +// +// See `RunStepResponse` in tensorflow/core/protobuf/master.proto for the +// protocol buffer definition. +// +//////////////////////////////////////////////////////////////////////////////// + +// Abstract interface for a mutable RunStepResponse message. +// +// Note that there is no corresponding (immutable) +// RunStepResponseWrapper class, because the RunStepResponse object is +// always used as a mutable pointer. +class MutableRunStepResponseWrapper { + public: + virtual ~MutableRunStepResponseWrapper(); + + // The values of the tensors whose fetching was requested in the + // RunStep call. + // + // NOTE: The order of the returned tensors may or may not match + // the fetch order specified in RunStepRequest. + virtual size_t num_tensors() const = 0; + virtual const string& tensor_name(size_t i) const = 0; + virtual Status TensorValue(size_t i, Tensor* out_tensor) const = 0; + + // Stores the i^{th} recv value in `run_graph_response` in this + // response with the given `name`. + virtual Status AddTensorFromRunGraphResponse( + const string& name, MutableRunGraphResponseWrapper* run_graph_response, + size_t i) = 0; + + // Returned metadata if requested in the options. + virtual const RunMetadata& metadata() const = 0; + virtual RunMetadata* mutable_metadata() = 0; + + protected: + // Returns a mutable protobuf message that represents the contents of + // this wrapper, for passing to an RPC subsystem that will populate + // the message. + // + // NOTE: Only `MasterInterface` subclasses may call this method. The + // `InMemoryRunStepResponse` subclass does not implement this + // method, and attempts to call it will fail with a fatal + // error. However, as long as callers always call + // `MasterInterface::RunStep()` with a wrapper object returned + // from `MasterInterface::CreateRunStepResponse()` called on the + // *same* MasterInterface object, this error will never trigger. + virtual RunStepResponse* get_proto() = 0; + friend class MasterInterface; +}; + +class InMemoryRunStepResponse : public MutableRunStepResponseWrapper { + public: + // MutableRunStepResponseWrapper methods. + size_t num_tensors() const override; + const string& tensor_name(size_t i) const override; + Status TensorValue(size_t i, Tensor* out_tensor) const override; + Status AddTensorFromRunGraphResponse( + const string& name, MutableRunGraphResponseWrapper* run_graph_response, + size_t i) override; + const RunMetadata& metadata() const override; + RunMetadata* mutable_metadata() override; + + protected: + // NOTE: This method is not implemented. See + // MutableRunGraphResponseWrapper for an explanation. + RunStepResponse* get_proto() override; + + private: + gtl::InlinedVector<std::pair<string, Tensor>, 4> tensors_; + RunMetadata metadata_; +}; + +// Proto-based message wrapper for use on the client side of the RunStep RPC. +class OwnedProtoRunStepResponse : public MutableRunStepResponseWrapper { + public: + // MutableRunStepResponseWrapper methods. + size_t num_tensors() const override; + const string& tensor_name(size_t i) const override; + Status TensorValue(size_t i, Tensor* out_tensor) const override; + Status AddTensorFromRunGraphResponse( + const string& name, MutableRunGraphResponseWrapper* run_graph_response, + size_t i) override; + const RunMetadata& metadata() const override; + RunMetadata* mutable_metadata() override; + + protected: + RunStepResponse* get_proto() override; + + private: + RunStepResponse response_; +}; + +// Proto-based message wrapper for use on the server side of the RunStep RPC. +class NonOwnedProtoRunStepResponse : public MutableRunStepResponseWrapper { + public: + NonOwnedProtoRunStepResponse(RunStepResponse* response); + + // MutableRunStepResponseWrapper methods. + size_t num_tensors() const override; + const string& tensor_name(size_t i) const override; + Status TensorValue(size_t i, Tensor* out_tensor) const override; + Status AddTensorFromRunGraphResponse( + const string& name, MutableRunGraphResponseWrapper* run_graph_response, + size_t i) override; + const RunMetadata& metadata() const override; + RunMetadata* mutable_metadata() override; + + protected: + RunStepResponse* get_proto() override; + + private: + RunStepResponse* response_; // Not owned. +}; + } // namespace tensorflow #endif // THIRD_PARTY_TENSORFLOW diff --git a/tensorflow/core/distributed_runtime/message_wrappers_test.cc b/tensorflow/core/distributed_runtime/message_wrappers_test.cc index 6b3f2f81352..00ccec4fdf0 100644 --- a/tensorflow/core/distributed_runtime/message_wrappers_test.cc +++ b/tensorflow/core/distributed_runtime/message_wrappers_test.cc @@ -92,6 +92,55 @@ static void CheckRunGraphRequest(const RunGraphRequestWrapper& request) { EXPECT_FALSE(request.is_last_partial_run()); } +static void BuildRunGraphResponse( + MutableRunGraphResponseWrapper* run_graph_response) { + run_graph_response->AddRecv("recv_2", TensorA()); + run_graph_response->AddRecv("recv_3", TensorB()); + run_graph_response->mutable_step_stats()->add_dev_stats()->set_device( + "/cpu:0"); + run_graph_response->mutable_cost_graph()->add_node()->set_name("cost_node"); +} + +static void CheckRunGraphResponse(MutableRunGraphResponseWrapper* response) { + EXPECT_EQ(2, response->num_recvs()); + EXPECT_EQ("recv_2", response->recv_key(0)); + EXPECT_EQ("recv_3", response->recv_key(1)); + Tensor val; + response->RecvValue(0, &val); + test::ExpectTensorEqual<int32>(TensorA(), val); + response->RecvValue(1, &val); + test::ExpectTensorEqual<int32>(TensorB(), val); + EXPECT_EQ(1, response->mutable_step_stats()->dev_stats_size()); + EXPECT_EQ("/cpu:0", response->mutable_step_stats()->dev_stats(0).device()); + EXPECT_EQ(1, response->mutable_cost_graph()->node_size()); + EXPECT_EQ("cost_node", response->mutable_cost_graph()->node(0).name()); +} + +static void BuildRunStepResponse( + MutableRunGraphResponseWrapper* run_graph_response, + MutableRunStepResponseWrapper* run_step_response) { + run_step_response->AddTensorFromRunGraphResponse("fetch_x:0", + run_graph_response, 0); + run_step_response->AddTensorFromRunGraphResponse("fetch_y:0", + run_graph_response, 1); + *run_step_response->mutable_metadata()->mutable_step_stats() = + *run_graph_response->mutable_step_stats(); +} + +static void CheckRunStepResponse( + const MutableRunStepResponseWrapper& response) { + EXPECT_EQ(2, response.num_tensors()); + EXPECT_EQ("fetch_x:0", response.tensor_name(0)); + EXPECT_EQ("fetch_y:0", response.tensor_name(1)); + Tensor val; + response.TensorValue(0, &val); + test::ExpectTensorEqual<int32>(TensorA(), val); + response.TensorValue(1, &val); + test::ExpectTensorEqual<int32>(TensorB(), val); + EXPECT_EQ(1, response.metadata().step_stats().dev_stats_size()); + EXPECT_EQ("/cpu:0", response.metadata().step_stats().dev_stats(0).device()); +} + TEST(MessageWrappers, RunStepRequest_Basic) { InMemoryRunStepRequest in_memory_request; BuildRunStepRequest(&in_memory_request); @@ -164,4 +213,108 @@ TEST(MessageWrappers, RunGraphRequest_Basic) { } } +TEST(MessageWrappers, RunGraphResponse_Basic) { + InMemoryRunGraphResponse in_memory_response; + BuildRunGraphResponse(&in_memory_response); + CheckRunGraphResponse(&in_memory_response); + + OwnedProtoRunGraphResponse owned_proto_response; + BuildRunGraphResponse(&owned_proto_response); + CheckRunGraphResponse(&owned_proto_response); + + RunGraphResponse response_proto; + NonOwnedProtoRunGraphResponse non_owned_proto_response(&response_proto); + BuildRunGraphResponse(&non_owned_proto_response); + CheckRunGraphResponse(&non_owned_proto_response); +} + +TEST(MessageWrappers, RunStepResponse_Basic) { + { + // Worker -(in memory)-> Master -(in memory)-> Client. + InMemoryRunGraphResponse run_graph_response; + BuildRunGraphResponse(&run_graph_response); + InMemoryRunStepResponse response; + BuildRunStepResponse(&run_graph_response, &response); + CheckRunStepResponse(response); + } + + { + // Worker -(in memory)-> Master -(owned proto)-> Client. + InMemoryRunGraphResponse run_graph_response; + BuildRunGraphResponse(&run_graph_response); + OwnedProtoRunStepResponse response; + BuildRunStepResponse(&run_graph_response, &response); + CheckRunStepResponse(response); + } + + { + // Worker -(in memory)-> Master -(non-owned proto)-> Client. + InMemoryRunGraphResponse run_graph_response; + BuildRunGraphResponse(&run_graph_response); + RunStepResponse response_proto; + NonOwnedProtoRunStepResponse response(&response_proto); + BuildRunStepResponse(&run_graph_response, &response); + CheckRunStepResponse(response); + } + + { + // Worker -(owned proto)-> Master -(in memory)-> Client. + OwnedProtoRunGraphResponse run_graph_response; + BuildRunGraphResponse(&run_graph_response); + InMemoryRunStepResponse response; + BuildRunStepResponse(&run_graph_response, &response); + CheckRunStepResponse(response); + } + + { + // Worker -(owned proto)-> Master -(owned proto)-> Client. + OwnedProtoRunGraphResponse run_graph_response; + BuildRunGraphResponse(&run_graph_response); + OwnedProtoRunStepResponse response; + BuildRunStepResponse(&run_graph_response, &response); + CheckRunStepResponse(response); + } + + { + // Worker -(owned proto)-> Master -(non-owned proto)-> Client. + OwnedProtoRunGraphResponse run_graph_response; + BuildRunGraphResponse(&run_graph_response); + RunStepResponse response_proto; + NonOwnedProtoRunStepResponse response(&response_proto); + BuildRunStepResponse(&run_graph_response, &response); + CheckRunStepResponse(response); + } + + { + // Worker -(non-owned proto)-> Master -(in memory)-> Client. + RunGraphResponse run_graph_response_proto; + NonOwnedProtoRunGraphResponse run_graph_response(&run_graph_response_proto); + BuildRunGraphResponse(&run_graph_response); + InMemoryRunStepResponse response; + BuildRunStepResponse(&run_graph_response, &response); + CheckRunStepResponse(response); + } + + { + // Worker -(non-owned proto)-> Master -(owned proto)-> Client. + RunGraphResponse run_graph_response_proto; + NonOwnedProtoRunGraphResponse run_graph_response(&run_graph_response_proto); + BuildRunGraphResponse(&run_graph_response); + OwnedProtoRunStepResponse response; + BuildRunStepResponse(&run_graph_response, &response); + CheckRunStepResponse(response); + } + + { + // Worker -(non-owned proto)-> Master -(non-owned proto)-> Client. + RunGraphResponse run_graph_response_proto; + NonOwnedProtoRunGraphResponse run_graph_response(&run_graph_response_proto); + BuildRunGraphResponse(&run_graph_response); + RunStepResponse response_proto; + NonOwnedProtoRunStepResponse response(&response_proto); + BuildRunStepResponse(&run_graph_response, &response); + CheckRunStepResponse(response); + } +} + } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc index ff7fc71d5fa..096ed3c0d2e 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc @@ -174,15 +174,17 @@ class GrpcMasterService : public AsyncServiceInterface { CallOptions* call_opts = new CallOptions; RunStepRequestWrapper* wrapped_request = new ProtoRunStepRequest(&call->request); + MutableRunStepResponseWrapper* wrapped_response = + new NonOwnedProtoRunStepResponse(&call->response); call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); - master_impl_->RunStep( - call_opts, wrapped_request, &call->response, - [call, call_opts, wrapped_request](const Status& status) { - call->ClearCancelCallback(); - delete call_opts; - delete wrapped_request; - call->SendResponse(ToGrpcStatus(status)); - }); + master_impl_->RunStep(call_opts, wrapped_request, wrapped_response, + [call, call_opts, wrapped_request, + wrapped_response](const Status& status) { + call->ClearCancelCallback(); + delete call_opts; + delete wrapped_request; + call->SendResponse(ToGrpcStatus(status)); + }); ENQUEUE_REQUEST(RunStep, true); } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc index 5b991043fb9..c3b76ed31bc 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc @@ -62,11 +62,12 @@ class GrpcRemoteMaster : public MasterInterface { } Status RunStep(CallOptions* call_options, RunStepRequestWrapper* request, - RunStepResponse* response) override { + MutableRunStepResponseWrapper* response) override { ::grpc::ClientContext ctx; ctx.set_fail_fast(false); SetDeadline(&ctx, call_options->GetTimeout()); - return FromGrpcStatus(stub_->RunStep(&ctx, request->ToProto(), response)); + return FromGrpcStatus(stub_->RunStep(&ctx, request->ToProto(), + get_proto_from_wrapper(response))); } Status CloseSession(CallOptions* call_options, diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc index eeb026016b6..7e83301841e 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc @@ -75,9 +75,10 @@ class GrpcRemoteWorker : public WorkerInterface { IssueRequest(request, response, rungraph_, std::move(done), call_opts); } void RunGraphAsync(CallOptions* call_opts, RunGraphRequestWrapper* request, - RunGraphResponse* response, StatusCallback done) override { - IssueRequest(&request->ToProto(), response, rungraph_, std::move(done), - call_opts); + MutableRunGraphResponseWrapper* response, + StatusCallback done) override { + IssueRequest(&request->ToProto(), get_proto_from_wrapper(response), + rungraph_, std::move(done), call_opts); } void CleanupGraphAsync(const CleanupGraphRequest* request, diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc index 61b6475fedc..4e2f5de2139 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc @@ -171,7 +171,8 @@ Status GrpcSession::RunHelper( // Convert to proto std::unique_ptr<MutableRunStepRequestWrapper> req( master_->CreateRunStepRequest()); - RunStepResponse resp; + std::unique_ptr<MutableRunStepResponseWrapper> resp( + master_->CreateRunStepResponse()); *req->mutable_options() = run_options; @@ -196,31 +197,27 @@ Status GrpcSession::RunHelper( CallOptions call_options; call_options.SetTimeout(run_options.timeout_in_ms()); - TF_RETURN_IF_ERROR(RunProto(&call_options, req.get(), &resp)); + TF_RETURN_IF_ERROR(RunProto(&call_options, req.get(), resp.get())); if (!output_tensor_names.empty()) { outputs->resize(output_tensor_names.size()); } // Convert response back to Tensors in the correct order. - for (const NamedTensorProto& tensor : resp.tensor()) { - auto fetch_it = output_name_to_offset.find(tensor.name()); + for (size_t i = 0; i < resp->num_tensors(); ++i) { + auto fetch_it = output_name_to_offset.find(resp->tensor_name(i)); if (fetch_it == output_name_to_offset.end()) { return errors::Internal("Received response for unrequested fetch: ", - tensor.name()); + resp->tensor_name(i)); } Tensor output; - if (!output.FromProto(tensor.tensor())) { - return errors::InvalidArgument("Could not parse returned proto for ", - tensor.name()); - } - + TF_RETURN_IF_ERROR(resp->TensorValue(i, &output)); (*outputs)[fetch_it->second] = output; } if (run_metadata) { - run_metadata->Swap(resp.mutable_metadata()); + run_metadata->Swap(resp->mutable_metadata()); } return Status::OK(); @@ -248,7 +245,7 @@ Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs, Status GrpcSession::RunProto(CallOptions* call_options, MutableRunStepRequestWrapper* req, - RunStepResponse* resp) { + MutableRunStepResponseWrapper* resp) { { mutex_lock l(mu_); if (handle_.empty()) { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.h b/tensorflow/core/distributed_runtime/rpc/grpc_session.h index bc1bab737d1..8fd17a053be 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.h @@ -119,7 +119,7 @@ class GrpcSession : public Session { const string& prun_handle); Status RunProto(CallOptions* call_options, MutableRunStepRequestWrapper* req, - RunStepResponse* resp); + MutableRunStepResponseWrapper* resp); // Implementations for all the public interfaces. Status CreateImpl(CallOptions* call_options, const GraphDef& graph); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index 2df9cf5d70b..202b35a9d7b 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -215,15 +215,18 @@ class GrpcWorkerService : public AsyncServiceInterface { CallOptions* call_opts = new CallOptions; ProtoRunGraphRequest* wrapped_request = new ProtoRunGraphRequest(&call->request); + NonOwnedProtoRunGraphResponse* wrapped_response = + new NonOwnedProtoRunGraphResponse(&call->response); call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); - worker_->RunGraphAsync( - call_opts, wrapped_request, &call->response, - [call, call_opts, wrapped_request](const Status& s) { - call->ClearCancelCallback(); - delete call_opts; - delete wrapped_request; - call->SendResponse(ToGrpcStatus(s)); - }); + worker_->RunGraphAsync(call_opts, wrapped_request, wrapped_response, + [call, call_opts, wrapped_request, + wrapped_response](const Status& s) { + call->ClearCancelCallback(); + delete call_opts; + delete wrapped_request; + delete wrapped_response; + call->SendResponse(ToGrpcStatus(s)); + }); }); ENQUEUE_REQUEST(RunGraph, true); } diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc index a19ad87acc5..e59e880af37 100644 --- a/tensorflow/core/distributed_runtime/worker.cc +++ b/tensorflow/core/distributed_runtime/worker.cc @@ -110,7 +110,8 @@ Status Worker::PrepareRunGraph(RunGraphRequestWrapper* req, } void Worker::RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request, - RunGraphResponse* response, StatusCallback done) { + MutableRunGraphResponseWrapper* response, + StatusCallback done) { if (request->is_partial()) { DoPartialRunGraph(opts, request, response, std::move(done)); } else { @@ -122,8 +123,13 @@ MutableRunGraphRequestWrapper* Worker::CreateRunGraphRequest() { return new InMemoryRunGraphRequest; } +MutableRunGraphResponseWrapper* Worker::CreateRunGraphResponse() { + return new InMemoryRunGraphResponse; +} + void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request, - RunGraphResponse* response, StatusCallback done) { + MutableRunGraphResponseWrapper* response, + StatusCallback done) { const int64 step_id = request->step_id(); TRACEPRINTF("RunGraph: %lld", step_id); GraphMgr::NamedTensors in; @@ -179,11 +185,7 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request, for (const auto& p : *out) { const string& key = p.first; const Tensor& val = p.second; - auto* recv = response->add_recv(); - recv->set_name(key); - // TODO(zhifengc): Deal with gpu -> cpu copy. - TensorProto* proto = recv->mutable_tensor(); - val.AsProtoTensorContent(proto); + response->AddRecv(key, val); } } delete collector; @@ -195,7 +197,7 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request, // TODO(suharshs): Add stats collection support to partial run. void Worker::DoPartialRunGraph(CallOptions* opts, RunGraphRequestWrapper* request, - RunGraphResponse* response, + MutableRunGraphResponseWrapper* response, StatusCallback done) { const int64 step_id = request->step_id(); const string& graph_handle = request->graph_handle(); @@ -276,11 +278,7 @@ void Worker::DoPartialRunGraph(CallOptions* opts, for (const auto& p : *out) { const string& key = p.first; const Tensor& val = p.second; - auto* recv = response->add_recv(); - recv->set_name(key); - // TODO(zhifengc): Deal with gpu -> cpu copy. - TensorProto* proto = recv->mutable_tensor(); - val.AsProtoTensorContent(proto); + response->AddRecv(key, val); } // If this is the last partial run request we must also wait for the entire diff --git a/tensorflow/core/distributed_runtime/worker.h b/tensorflow/core/distributed_runtime/worker.h index a9df911f5b9..b52a809a0ea 100644 --- a/tensorflow/core/distributed_runtime/worker.h +++ b/tensorflow/core/distributed_runtime/worker.h @@ -54,10 +54,13 @@ class Worker : public WorkerInterface { StatusCallback done) override; void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request, - RunGraphResponse* response, StatusCallback done) override; + MutableRunGraphResponseWrapper* response, + StatusCallback done) override; MutableRunGraphRequestWrapper* CreateRunGraphRequest() override; + MutableRunGraphResponseWrapper* CreateRunGraphResponse() override; + void CleanupGraphAsync(const CleanupGraphRequest* request, CleanupGraphResponse* response, StatusCallback done) override; @@ -117,10 +120,12 @@ class Worker : public WorkerInterface { GraphMgr::NamedTensors* out); void DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request, - RunGraphResponse* response, StatusCallback done); + MutableRunGraphResponseWrapper* response, + StatusCallback done); void DoPartialRunGraph(CallOptions* opts, RunGraphRequestWrapper* request, - RunGraphResponse* response, StatusCallback done); + MutableRunGraphResponseWrapper* response, + StatusCallback done); TF_DISALLOW_COPY_AND_ASSIGN(Worker); }; diff --git a/tensorflow/core/distributed_runtime/worker_interface.h b/tensorflow/core/distributed_runtime/worker_interface.h index 577ecf25edd..6de432ea0d4 100644 --- a/tensorflow/core/distributed_runtime/worker_interface.h +++ b/tensorflow/core/distributed_runtime/worker_interface.h @@ -49,7 +49,7 @@ class WorkerInterface { StatusCallback done) = 0; virtual void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request, - RunGraphResponse* repsonse, + MutableRunGraphResponseWrapper* repsonse, StatusCallback done) = 0; virtual void RunGraphAsync(CallOptions* opts, const RunGraphRequest* request, @@ -57,17 +57,34 @@ class WorkerInterface { // TODO(mrry): Convert this to std::bind/std::move if the overhead // of std::function copying becomes too much. RunGraphRequestWrapper* wrapped_request = new ProtoRunGraphRequest(request); - RunGraphAsync(opts, wrapped_request, response, - [wrapped_request, done](const Status& s) { + MutableRunGraphResponseWrapper* wrapped_response = + new NonOwnedProtoRunGraphResponse(response); + RunGraphAsync(opts, wrapped_request, wrapped_response, + [wrapped_request, wrapped_response, done](const Status& s) { done(s); delete wrapped_request; + delete wrapped_response; }); } + // Returns a request object for use in calls to + // `RunGraphAsync()`. Ownership is transferred to the caller. + // + // The message returned from this method must only be used in a + // `RunGraph()` call on the same `WorkerInterface` instance. virtual MutableRunGraphRequestWrapper* CreateRunGraphRequest() { return new MutableProtoRunGraphRequest; } + // Returns a response object for use in calls to + // `RunGraphAsync()`. Ownership is transferred to the caller. + // + // The message returned from this method must only be used in a + // `RunGraph()` call on the same `WorkerInterface` instance. + virtual MutableRunGraphResponseWrapper* CreateRunGraphResponse() { + return new OwnedProtoRunGraphResponse; + } + virtual void CleanupGraphAsync(const CleanupGraphRequest* request, CleanupGraphResponse* response, StatusCallback done) = 0; @@ -126,6 +143,14 @@ class WorkerInterface { virtual ~WorkerInterface() {} friend class WorkerCacheInterface; + // NOTE: This should only be called by implementations of this + // interface whose CreateRunGraphResponse() method returns a + // proto-based wrappers for the RunGraphResponse message. + RunGraphResponse* get_proto_from_wrapper( + MutableRunGraphResponseWrapper* wrapper) { + return wrapper->get_proto(); + } + private: typedef WorkerInterface ME; diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index f6475e07366..02fdc16e882 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/node_def.pb_text.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/core/errors.h" @@ -80,8 +81,7 @@ InferenceContext::InferenceContext( PostInputInit(input_handle_shapes, input_handle_dtypes); } -InferenceContext::~InferenceContext() { -} +InferenceContext::~InferenceContext() {} Status InferenceContext::set_output(StringPiece output_name, const std::vector<ShapeHandle>& shapes) { @@ -231,6 +231,11 @@ string InferenceContext::DebugString(DimensionHandle d) { return ValueKnown(d) ? strings::StrCat(Value(d)) : "?"; } +string InferenceContext::DebugString() const { + return strings::StrCat("InferenceContext for node: ", + ProtoDebugString(node_def_)); +} + Status InferenceContext::WithRank(ShapeHandle shape, int32 rank, ShapeHandle* out) { const int32 existing = Rank(shape); diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index d91775152c2..704fe2e848e 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -259,6 +259,9 @@ class InferenceContext { string DebugString(ShapeHandle s); string DebugString(DimensionHandle d); + // Describes the whole context, for debugging purposes. + string DebugString() const; + // If <shape> has rank <rank>, or its rank is unknown, return OK and return // the shape with asserted rank in <*out>. Otherwise return an error. // diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 7652f076482..d1550512737 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -66,6 +66,7 @@ class GraphConstructor { ? in.prefix : in.prefix + "/"), input_map(in.input_map), + control_dependencies(in.control_dependencies), importing(true) {} bool allow_internal_ops; @@ -73,6 +74,7 @@ class GraphConstructor { string prefix; std::map<TensorId, TensorId> input_map; + std::vector<string> control_dependencies; // TODO(ashankar): This bool exists to separate out functionality required // to make ImportGraphDef a close equivalent of Python's import_graph_def @@ -107,7 +109,7 @@ class GraphConstructor { Status TryImport() { TF_RETURN_IF_ERROR(EnsureNoNameCollisions()); - TF_RETURN_IF_ERROR(ValidateInputMap()); + TF_RETURN_IF_ERROR(ValidateInputMapAndControlDependencies()); TF_RETURN_IF_ERROR(BuildNodeIndex()); TF_RETURN_IF_ERROR(InitFromEdges()); TF_RETURN_IF_ERROR(Convert()); @@ -118,7 +120,7 @@ class GraphConstructor { } Status EnsureNoNameCollisions(); - Status ValidateInputMap(); + Status ValidateInputMapAndControlDependencies(); Status BuildNodeIndex(); Status InitFromEdges(); Status Convert(); @@ -132,11 +134,18 @@ class GraphConstructor { Status MakeEdge(Node* src, int output_index, Node* dst, int input_index); Status ValidateShape(Node* node); Status ModifyNodeDefForImport(NodeDef* node_def); - // Modifies node_def's inputs according to opts_.input_map. input_remapped is - // a pre-initialized vector of length node_def->input_size() indicating - // whether each input has been remapped. - void RemapNodeDefInputs(NodeDef* node_def, std::vector<bool>* input_remapped); - void AddPrefixToNodeDef(const std::vector<bool>& input_remapped, + // Modifies node_def's inputs according to opts_.input_map. + // input_already_exists is a pre-initialized vector of length + // node_def->input_size(). This function will mark inputs that are remapped to + // true. + void RemapNodeDefInputs(NodeDef* node_def, + std::vector<bool>* input_already_exists); + // input_already_exists is a pre-initialized vector of length + // node_def->input_size(). This function will add and mark control inputs as + // true. + void AddControlDependencies(NodeDef* node_def, + std::vector<bool>* input_already_exists); + void AddPrefixToNodeDef(const std::vector<bool>& input_already_exists, NodeDef* node_def); // From constructor @@ -208,14 +217,27 @@ bool NodeNameInValues(const std::map<TensorId, TensorId>& input_map, return false; } +bool NodeNameInValues(const std::vector<string>& control_dependencies, + const StringPiece& node_name) { + return std::find(control_dependencies.begin(), control_dependencies.end(), + node_name) != control_dependencies.end(); +} + Status GraphConstructor::EnsureNoNameCollisions() { existing_nodes_.reserve(g_->num_nodes()); for (Node* n : g_->nodes()) { bool already_exists = !existing_nodes_.insert({n->name(), n}).second; - if (already_exists && NodeNameInValues(opts_.input_map, n->name())) { - return errors::InvalidArgument( - "cannot resolve input_map because multiple nodes exist with name '", - n->name(), "'"); + if (already_exists) { + if (NodeNameInValues(opts_.input_map, n->name())) { + return errors::InvalidArgument( + "cannot resolve input_map because multiple nodes exist with name '", + n->name(), "'"); + } + if (NodeNameInValues(opts_.control_dependencies, n->name())) { + return errors::InvalidArgument( + "cannot resolve control_dependencies because multiple nodes exist " + "with name '", n->name(), "'"); + } } } if (opts_.prefix.empty() && opts_.importing) { @@ -248,7 +270,7 @@ Status GraphConstructor::EnsureNoNameCollisions() { return Status::OK(); } -Status GraphConstructor::ValidateInputMap() { +Status GraphConstructor::ValidateInputMapAndControlDependencies() { for (const auto& mapping : opts_.input_map) { TensorId src = mapping.first; TensorId dst = mapping.second; @@ -264,6 +286,13 @@ Status GraphConstructor::ValidateInputMap() { "control edge and non-control edge"); } } + for (const string& node : opts_.control_dependencies) { + if (existing_nodes_.count(node) == 0) { + return errors::InvalidArgument( + "node '", node, "' in control_dependencies does not exist in " + "graph"); + } + } return Status::OK(); } @@ -466,9 +495,9 @@ void RemoveInputs(NodeDef* node_def, const std::vector<int>& inputs_to_remove) { } } -void GraphConstructor::RemapNodeDefInputs(NodeDef* node_def, - std::vector<bool>* input_remapped) { - DCHECK_EQ(input_remapped->size(), node_def->input_size()); +void GraphConstructor::RemapNodeDefInputs( + NodeDef* node_def, std::vector<bool>* input_already_exists) { + DCHECK_EQ(input_already_exists->size(), node_def->input_size()); std::set<TensorId> control_inputs; std::vector<int> inputs_to_remove; @@ -487,13 +516,52 @@ void GraphConstructor::RemapNodeDefInputs(NodeDef* node_def, control_inputs.insert(new_input); } node_def->set_input(i, new_input.ToString()); - (*input_remapped)[i] = true; + (*input_already_exists)[i] = true; } if (!inputs_to_remove.empty()) RemoveInputs(node_def, inputs_to_remove); } +void GraphConstructor::AddControlDependencies( + NodeDef* node_def, std::vector<bool>* input_already_exists) { + // To avoid adding redundant control dependencies to every imported node, skip + // nodes that will inherit the dependencies from another imported node. + bool inherits_deps = false; + for (int i = 0; i < node_def->input_size(); ++i) { + // Assume we won't inherit dependencies from remapped inputs that already + // exist in the graph. Even if we're wrong, we'll only add redundant + // dependencies. + if ((*input_already_exists)[i]) continue; + + // If this input is a backedge, assume we won't inherit the dependencies. + // TODO(skyewm): we have many redundant ParseTensorName calls. It could be + // worth optimizing these. + TensorId id(ParseTensorName(node_def->input(i))); + auto iter = gdef_nodes_.find(id.first); + DCHECK(iter != gdef_nodes_.end()) << id.first; + if (iter->second.node == nullptr) { + // Input hasn't been created yet, indicating it's a backedge. + continue; + } + inherits_deps = true; + } + if (inherits_deps) return; + + // node_def either has no inputs or all remapped inputs, add the control + // dependencies + for (const string& control_dep : opts_.control_dependencies) { + string input = TensorId(control_dep, Graph::kControlSlot).ToString(); + const protobuf::RepeatedPtrField<string>& inputs = node_def->input(); + if (std::find(inputs.begin(), inputs.end(), input) != inputs.end()) { + // Control dependency already exists + continue; + } + node_def->add_input(input); + input_already_exists->push_back(true); + } +} + void GraphConstructor::AddPrefixToNodeDef( - const std::vector<bool>& input_remapped, NodeDef* node_def) { + const std::vector<bool>& input_already_exists, NodeDef* node_def) { const string& prefix = opts_.prefix; if (prefix.empty()) return; node_def->set_name(strings::StrCat(prefix, node_def->name())); @@ -502,7 +570,7 @@ void GraphConstructor::AddPrefixToNodeDef( StringPiece input(node_def->input(i)); // Skip remapped inputs (which already exist in g_ and are not being // imported) - if (input_remapped[i]) continue; + if (input_already_exists[i]) continue; if (input.Consume("^")) { node_def->set_input(i, strings::StrCat("^", prefix, input)); } else { @@ -540,7 +608,12 @@ Status GraphConstructor::Convert() { NodeDef imported_node_def; const NodeDef* node_def; - std::vector<bool> input_remapped(original_node_def.input_size(), false); + // input_already_exists[i] is true iff the i-th input of the node we're + // importing refers to a preexisting node in g_ (i.e. input[i] existed prior + // to importing gdef_). Conversely, input_already_exists[i] is false iff + // the input refers to a node in gdef_. + std::vector<bool> input_already_exists(original_node_def.input_size(), + false); if (opts_.importing) { // TODO(ashankar): The line below means an additional copy of the NodeDef, @@ -549,7 +622,11 @@ Status GraphConstructor::Convert() { // GraphDef* and avoid the copying. imported_node_def = original_node_def; if (!opts_.input_map.empty()) { - RemapNodeDefInputs(&imported_node_def, &input_remapped); + RemapNodeDefInputs(&imported_node_def, &input_already_exists); + } + if (!opts_.control_dependencies.empty()) { + // Note that input_already_exists can grow here + AddControlDependencies(&imported_node_def, &input_already_exists); } node_def = &imported_node_def; } else { @@ -562,7 +639,7 @@ Status GraphConstructor::Convert() { Node* src_node; int src_index; - if (!input_remapped[i]) { + if (!input_already_exists[i]) { // Locate input in newly-imported nodes auto iter = gdef_nodes_.find(id.first); DCHECK(iter != gdef_nodes_.end()) << id.first; @@ -570,7 +647,7 @@ Status GraphConstructor::Convert() { src_index = id.second; if (src_node == nullptr) has_data_back_edge = true; } else { - // Input was remapped according to input_map + // Input refers to preexistng node in graph auto iter = existing_nodes_.find(id.first); DCHECK(iter != existing_nodes_.end()) << id.first; src_node = iter->second; @@ -595,7 +672,7 @@ Status GraphConstructor::Convert() { Node* node; if (opts_.importing) { - AddPrefixToNodeDef(input_remapped, &imported_node_def); + AddPrefixToNodeDef(input_already_exists, &imported_node_def); TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&imported_node_def)); } TF_RETURN_IF_ERROR(MakeNode(*node_def, &node)); diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h index 0cbd6fa9244..61704913c3d 100644 --- a/tensorflow/core/graph/graph_constructor.h +++ b/tensorflow/core/graph/graph_constructor.h @@ -89,6 +89,14 @@ struct ImportGraphDefOptions { // TODO(skyewm): add functionality to retrieve unused `input_map` keys std::map<TensorId, TensorId> input_map; + // The names of existing nodes in `g` that the imported graph should have + // control dependencies on. + // + // Note that to avoid creating many redundant control edges, ImportGraphDef() + // won't add control edges to nodes that will inherit the dependencies from + // other nodes in `gdef`. + std::vector<string> control_dependencies; + // TODO(ashankar): Enable handling of GraphDefs produced by newer binaries // with ops that are not defined in the binary calling ImportGraphDef. // Similar to the producer_op_list argument to import_graph_def in the diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index 2d16ddb66f1..a173d3a6275 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -117,8 +117,9 @@ class GraphConstructorTest : public ::testing::Test { bool HasEdge(const string& src, int src_out, const string& dst, int dst_in) { for (const Edge* e : graph_.edges()) { if (e->src()->name() == src && e->src_output() == src_out && - e->dst()->name() == dst && e->dst_input() == dst_in) + e->dst()->name() == dst && e->dst_input() == dst_in) { return true; + } } return false; } @@ -1198,6 +1199,133 @@ versions { EXPECT_EQ(Status::OK(), s) << s; } +TEST_F(GraphConstructorTest, ImportGraphDef_ControlDeps) { + ShapeRefiner refiner(graph_.op_registry()); + + // Populate graph with nodes we'll use in control deps and input map + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }", + ImportGraphDefOptions(), &refiner); + + ImportGraphDefOptions opts; + opts.control_dependencies = {"W1", "W2"}; + opts.prefix = "import"; + opts.input_map[TensorId("W1", -1)] = TensorId("W1", -1); + ExpectOK( + R"EOF( + node { name: 'W1' op: 'TestParams' } + node { name: 'input' op: 'TestInput' } + node { name: 'input2' op: 'TestInput' input: [ '^W1' ] } + node { name: 't1' op: 'TestMul' input: [ 'input:0', 'input:1' ] } + )EOF", + opts, &refiner); + + // Sanity checks + EXPECT_TRUE(HasNode("import/W1")); + EXPECT_TRUE(HasNode("import/input")); + EXPECT_TRUE(HasNode("import/input2")); + EXPECT_TRUE(HasNode("import/t1")); + + EXPECT_TRUE(HasControlEdge("W1", "import/W1")); + EXPECT_TRUE(HasControlEdge("W2", "import/W1")); + + EXPECT_TRUE(HasControlEdge("W1", "import/input")); + EXPECT_TRUE(HasControlEdge("W2", "import/input")); + + // Test that t1 doesn't have redundant control edges + EXPECT_FALSE(HasControlEdge("W1", "import/t1")); + EXPECT_FALSE(HasControlEdge("W2", "import/t1")); + EXPECT_TRUE(HasEdge("import/input", 0, "import/t1", 0)); + EXPECT_TRUE(HasEdge("import/input", 1, "import/t1", 1)); + + // Test that input2 has control edges since its only input was remapped + EXPECT_TRUE(HasControlEdge("W1", "import/input2")); + EXPECT_TRUE(HasControlEdge("W2", "import/input2")); + EXPECT_FALSE(HasControlEdge("import/W1", "import/input2")); + + // Test that node defs are consistent with graph + Node* w1 = FindNode("import/W1"); + ASSERT_EQ(w1->def().input_size(), 2); + EXPECT_EQ(w1->def().input(0), "^W1"); + EXPECT_EQ(w1->def().input(1), "^W2"); + + Node* input = FindNode("import/input"); + ASSERT_EQ(input->def().input_size(), 2); + EXPECT_EQ(input->def().input(0), "^W1"); + EXPECT_EQ(input->def().input(1), "^W2"); + + Node* input2 = FindNode("import/input2"); + ASSERT_EQ(input2->def().input_size(), 2); + EXPECT_EQ(input2->def().input(0), "^W1"); + EXPECT_EQ(input2->def().input(1), "^W2"); + + Node* t1 = FindNode("import/t1"); + ASSERT_EQ(t1->def().input_size(), 2); + EXPECT_EQ(t1->def().input(0), "import/input:0"); + EXPECT_EQ(t1->def().input(1), "import/input:1"); +} + +TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsWithCycle) { + ShapeRefiner refiner(graph_.op_registry()); + + // Populate graph with nodes we'll use in control deps and input map + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }", + ImportGraphDefOptions(), &refiner); + + ImportGraphDefOptions opts; + opts.control_dependencies.push_back("W1"); + // Use input_map to ensure the cycle doesn't inherit the control deps from + // new_input + opts.input_map[TensorId("new_input", 0)] = TensorId("input", 0); + + // ImportGraphDef only allows backedges into merge nodes (since backedges are + // only expected in while loops) + ExpectOK( + R"EOF( + node { name: 'new_input' op: 'TestInput' } + node { name: 'merge' op: 'Merge' input: [ 'new_input:0', 't1:0' ] + attr { key: "N" value: { i: 2 } } + attr { key: "T" value: { type: DT_FLOAT } } } + node { name: 't1' op: 'TestMul' input: [ 'merge:0', 'merge:0' ] } + )EOF", + opts, &refiner); + + EXPECT_TRUE(HasNode("new_input")); + EXPECT_TRUE(HasNode("merge")); + EXPECT_TRUE(HasNode("t1")); + + // Sanity check we created cycle + EXPECT_TRUE(HasEdge("merge", 0, "t1", 0)); + EXPECT_TRUE(HasEdge("t1", 0, "merge", 1)); + + // Test that control dep was added to exactly one node of cycle + EXPECT_TRUE(HasControlEdge("W1", "merge")); + EXPECT_FALSE(HasControlEdge("W1", "t1")); + + // Test that node defs are consistent with graph + Node* merge = FindNode("merge"); + ASSERT_EQ(merge->def().input_size(), 3); + EXPECT_EQ(merge->def().input(0), "input:0"); + EXPECT_EQ(merge->def().input(1), "t1:0"); + EXPECT_EQ(merge->def().input(2), "^W1"); + + Node* t1 = FindNode("t1"); + ASSERT_EQ(t1->def().input_size(), 2); + EXPECT_EQ(t1->def().input(0), "merge:0"); + EXPECT_EQ(t1->def().input(1), "merge:0"); +} + +TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsErrors) { + // Control dep that isn't in graph def + ImportGraphDefOptions opts; + opts.control_dependencies.push_back("W1"); + ExpectError("node { name: 'W1' op: 'TestParams' }", opts, + {"node 'W1' in control_dependencies does not exist in graph"}); +} + TEST_F(GraphConstructorTest, ImportGraphDef_ErrorsDoNoChangeTheGraph) { GraphDef def; NodeDefBuilder("scope/A", "TestParams").Finalize(def.add_node()); diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 014a40fef18..f754e70c7de 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -3721,6 +3721,8 @@ filegroup( # See b/29213790 "scatter_nd_op*", "sparse_matmul_op.*", + # Lib CURL is not supported on Android. + "bigquery*", ], ), visibility = ["//visibility:public"], diff --git a/tensorflow/core/kernels/cloud/BUILD b/tensorflow/core/kernels/cloud/BUILD index dfb4772b97e..710cb5aa14b 100644 --- a/tensorflow/core/kernels/cloud/BUILD +++ b/tensorflow/core/kernels/cloud/BUILD @@ -9,6 +9,7 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow:tensorflow.bzl", + "tf_kernel_library", "tf_cc_test", ) @@ -30,6 +31,24 @@ filegroup( visibility = ["//tensorflow:__subpackages__"], ) +tf_kernel_library( + name = "bigquery_reader_ops", + srcs = [ + "bigquery_reader_ops.cc", + ], + visibility = ["//visibility:public"], + deps = [ + ":bigquery_table_accessor", + ":bigquery_table_partition_proto_cc", + "//tensorflow/core:cloud_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/kernels:reader_base", + ], +) + cc_library( name = "bigquery_table_accessor", srcs = [ diff --git a/tensorflow/core/kernels/cloud/bigquery_reader_ops.cc b/tensorflow/core/kernels/cloud/bigquery_reader_ops.cc new file mode 100644 index 00000000000..a3b026e2a15 --- /dev/null +++ b/tensorflow/core/kernels/cloud/bigquery_reader_ops.cc @@ -0,0 +1,193 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <map> +#include <memory> +#include <set> + +#include "tensorflow/core/example/example.pb.h" +#include "tensorflow/core/framework/reader_op_kernel.h" +#include "tensorflow/core/kernels/cloud/bigquery_table_accessor.h" +#include "tensorflow/core/kernels/cloud/bigquery_table_partition.pb.h" +#include "tensorflow/core/kernels/reader_base.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/lib/strings/numbers.h" + +namespace tensorflow { +namespace { + +constexpr int64 kDefaultRowBufferSize = 1000; // Number of rows to buffer. + +// This is a helper function for reading table attributes from context. +Status GetTableAttrs(OpKernelConstruction* context, string* project_id, + string* dataset_id, string* table_id, + int64* timestamp_millis, std::vector<string>* columns, + string* test_end_point) { + TF_RETURN_IF_ERROR(context->GetAttr("project_id", project_id)); + TF_RETURN_IF_ERROR(context->GetAttr("dataset_id", dataset_id)); + TF_RETURN_IF_ERROR(context->GetAttr("table_id", table_id)); + TF_RETURN_IF_ERROR(context->GetAttr("timestamp_millis", timestamp_millis)); + TF_RETURN_IF_ERROR(context->GetAttr("columns", columns)); + TF_RETURN_IF_ERROR(context->GetAttr("test_end_point", test_end_point)); + return Status::OK(); +} + +} // namespace + +// Note that overriden methods with names ending in "Locked" are called by +// ReaderBase while a mutex is held. +// See comments for ReaderBase. +class BigQueryReader : public ReaderBase { + public: + explicit BigQueryReader(BigQueryTableAccessor* bigquery_table_accessor, + const string& node_name) + : ReaderBase(strings::StrCat("BigQueryReader '", node_name, "'")), + bigquery_table_accessor_(CHECK_NOTNULL(bigquery_table_accessor)) {} + + Status OnWorkStartedLocked() override { + BigQueryTablePartition partition; + if (!partition.ParseFromString(current_work())) { + return errors::InvalidArgument( + "Could not parse work as as valid partition."); + } + TF_RETURN_IF_ERROR(bigquery_table_accessor_->SetPartition(partition)); + return Status::OK(); + } + + Status ReadLocked(string* key, string* value, bool* produced, + bool* at_end) override { + *at_end = false; + *produced = false; + if (bigquery_table_accessor_->Done()) { + *at_end = true; + return Status::OK(); + } + + Example example; + int64 row_id; + TF_RETURN_IF_ERROR(bigquery_table_accessor_->ReadRow(&row_id, &example)); + + *key = std::to_string(row_id); + *value = example.SerializeAsString(); + *produced = true; + return Status::OK(); + } + + private: + // Not owned. + BigQueryTableAccessor* bigquery_table_accessor_; +}; + +class BigQueryReaderOp : public ReaderOpKernel { + public: + explicit BigQueryReaderOp(OpKernelConstruction* context) + : ReaderOpKernel(context) { + string table_id; + string project_id; + string dataset_id; + int64 timestamp_millis; + std::vector<string> columns; + string test_end_point; + + OP_REQUIRES_OK(context, + GetTableAttrs(context, &project_id, &dataset_id, &table_id, + ×tamp_millis, &columns, &test_end_point)); + OP_REQUIRES_OK(context, + BigQueryTableAccessor::New( + project_id, dataset_id, table_id, timestamp_millis, + kDefaultRowBufferSize, test_end_point, columns, + BigQueryTablePartition(), &bigquery_table_accessor_)); + + SetReaderFactory([this]() { + return new BigQueryReader(bigquery_table_accessor_.get(), name()); + }); + } + + private: + std::unique_ptr<BigQueryTableAccessor> bigquery_table_accessor_; +}; + +REGISTER_KERNEL_BUILDER(Name("BigQueryReader").Device(DEVICE_CPU), + BigQueryReaderOp); + +class GenerateBigQueryReaderPartitionsOp : public OpKernel { + public: + explicit GenerateBigQueryReaderPartitionsOp(OpKernelConstruction* context) + : OpKernel(context) { + string project_id; + string dataset_id; + string table_id; + int64 timestamp_millis; + std::vector<string> columns; + string test_end_point; + + OP_REQUIRES_OK(context, + GetTableAttrs(context, &project_id, &dataset_id, &table_id, + ×tamp_millis, &columns, &test_end_point)); + OP_REQUIRES_OK(context, + BigQueryTableAccessor::New( + project_id, dataset_id, table_id, timestamp_millis, + kDefaultRowBufferSize, test_end_point, columns, + BigQueryTablePartition(), &bigquery_table_accessor_)); + OP_REQUIRES_OK(context, InitializeNumberOfPartitions(context)); + OP_REQUIRES_OK(context, InitializeTotalNumberOfRows()); + } + + void Compute(OpKernelContext* context) override { + const int64 partition_size = tensorflow::MathUtil::CeilOfRatio<int64>( + total_num_rows_, num_partitions_); + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({num_partitions_}), + &output_tensor)); + + auto output = output_tensor->template flat<string>(); + for (int64 i = 0; i < num_partitions_; ++i) { + BigQueryTablePartition partition; + partition.set_start_index(i * partition_size); + partition.set_end_index( + std::min(total_num_rows_, (i + 1) * partition_size) - 1); + output(i) = partition.SerializeAsString(); + } + } + + private: + Status InitializeTotalNumberOfRows() { + total_num_rows_ = bigquery_table_accessor_->total_num_rows(); + if (total_num_rows_ <= 0) { + return errors::FailedPrecondition("Invalid total number of rows."); + } + return Status::OK(); + } + + Status InitializeNumberOfPartitions(OpKernelConstruction* context) { + TF_RETURN_IF_ERROR(context->GetAttr("num_partitions", &num_partitions_)); + if (num_partitions_ <= 0) { + return errors::FailedPrecondition("Invalid number of partitions."); + } + return Status::OK(); + } + + int64 num_partitions_; + int64 total_num_rows_; + std::unique_ptr<BigQueryTableAccessor> bigquery_table_accessor_; +}; + +REGISTER_KERNEL_BUILDER( + Name("GenerateBigQueryReaderPartitions").Device(DEVICE_CPU), + GenerateBigQueryReaderPartitionsOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cloud/bigquery_table_accessor.cc b/tensorflow/core/kernels/cloud/bigquery_table_accessor.cc index 293d47d9755..3e9adfa3727 100644 --- a/tensorflow/core/kernels/cloud/bigquery_table_accessor.cc +++ b/tensorflow/core/kernels/cloud/bigquery_table_accessor.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #include "tensorflow/core/kernels/cloud/bigquery_table_accessor.h" #include "tensorflow/core/example/feature.pb.h" @@ -23,6 +22,15 @@ namespace tensorflow { namespace { constexpr size_t kBufferSize = 1024 * 1024; // In bytes. +const string kBigQueryEndPoint = "https://www.googleapis.com/bigquery/v2"; + +bool IsPartitionEmpty(const BigQueryTablePartition& partition) { + if (partition.end_index() != -1 && + partition.end_index() < partition.start_index()) { + return true; + } + return false; +} Status ParseJson(StringPiece json, Json::Value* result) { Json::Reader reader; @@ -92,17 +100,18 @@ Status ParseColumnType(const string& type, Status BigQueryTableAccessor::New( const string& project_id, const string& dataset_id, const string& table_id, - int64 timestamp_millis, int64 row_buffer_size, - const std::set<string>& columns, const BigQueryTablePartition& partition, + int64 timestamp_millis, int64 row_buffer_size, const string& end_point, + const std::vector<string>& columns, const BigQueryTablePartition& partition, std::unique_ptr<BigQueryTableAccessor>* accessor) { return New(project_id, dataset_id, table_id, timestamp_millis, - row_buffer_size, columns, partition, nullptr, nullptr, accessor); + row_buffer_size, end_point, columns, partition, nullptr, nullptr, + accessor); } Status BigQueryTableAccessor::New( const string& project_id, const string& dataset_id, const string& table_id, - int64 timestamp_millis, int64 row_buffer_size, - const std::set<string>& columns, const BigQueryTablePartition& partition, + int64 timestamp_millis, int64 row_buffer_size, const string& end_point, + const std::vector<string>& columns, const BigQueryTablePartition& partition, std::unique_ptr<AuthProvider> auth_provider, std::unique_ptr<HttpRequest::Factory> http_request_factory, std::unique_ptr<BigQueryTableAccessor>* accessor) { @@ -110,14 +119,16 @@ Status BigQueryTableAccessor::New( return errors::InvalidArgument( "Cannot use zero or negative timestamp to query a table."); } + const string& big_query_end_point = + end_point.empty() ? kBigQueryEndPoint : end_point; if (auth_provider == nullptr && http_request_factory == nullptr) { - accessor->reset(new BigQueryTableAccessor(project_id, dataset_id, table_id, - timestamp_millis, row_buffer_size, - columns, partition)); + accessor->reset(new BigQueryTableAccessor( + project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, + big_query_end_point, columns, partition)); } else { accessor->reset(new BigQueryTableAccessor( project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, - columns, partition, std::move(auth_provider), + big_query_end_point, columns, partition, std::move(auth_provider), std::move(http_request_factory))); } return (*accessor)->ReadSchema(); @@ -125,11 +136,11 @@ Status BigQueryTableAccessor::New( BigQueryTableAccessor::BigQueryTableAccessor( const string& project_id, const string& dataset_id, const string& table_id, - int64 timestamp_millis, int64 row_buffer_size, - const std::set<string>& columns, const BigQueryTablePartition& partition) + int64 timestamp_millis, int64 row_buffer_size, const string& end_point, + const std::vector<string>& columns, const BigQueryTablePartition& partition) : BigQueryTableAccessor( project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, - columns, partition, + end_point, columns, partition, std::unique_ptr<AuthProvider>(new GoogleAuthProvider()), std::unique_ptr<HttpRequest::Factory>(new HttpRequest::Factory())) { row_buffer_.resize(row_buffer_size); @@ -137,15 +148,16 @@ BigQueryTableAccessor::BigQueryTableAccessor( BigQueryTableAccessor::BigQueryTableAccessor( const string& project_id, const string& dataset_id, const string& table_id, - int64 timestamp_millis, int64 row_buffer_size, - const std::set<string>& columns, const BigQueryTablePartition& partition, + int64 timestamp_millis, int64 row_buffer_size, const string& end_point, + const std::vector<string>& columns, const BigQueryTablePartition& partition, std::unique_ptr<AuthProvider> auth_provider, std::unique_ptr<HttpRequest::Factory> http_request_factory) : project_id_(project_id), dataset_id_(dataset_id), table_id_(table_id), timestamp_millis_(timestamp_millis), - columns_(columns), + columns_(columns.begin(), columns.end()), + bigquery_end_point_(end_point), partition_(partition), auth_provider_(std::move(auth_provider)), http_request_factory_(std::move(http_request_factory)) { @@ -153,10 +165,14 @@ BigQueryTableAccessor::BigQueryTableAccessor( Reset(); } -void BigQueryTableAccessor::SetPartition( +Status BigQueryTableAccessor::SetPartition( const BigQueryTablePartition& partition) { + if (partition.start_index() < 0) { + return errors::InvalidArgument("Start index cannot be negative."); + } partition_ = partition; Reset(); + return Status::OK(); } void BigQueryTableAccessor::Reset() { @@ -172,7 +188,8 @@ Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) { // If the next row is already fetched and cached, return the row from the // buffer. Otherwise, fill up the row buffer from BigQuery and return a row. - if (next_row_in_buffer_ != -1 && next_row_in_buffer_ < row_buffer_.size()) { + if (next_row_in_buffer_ != -1 && + next_row_in_buffer_ < ComputeMaxResultsArg()) { *row_id = first_buffered_row_index_ + next_row_in_buffer_; *example = row_buffer_[next_row_in_buffer_]; next_row_in_buffer_++; @@ -190,12 +207,12 @@ Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) { // we use the page token (which returns rows faster). if (!next_page_token_.empty()) { TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat( - BigQueryUriPrefix(), "data?maxResults=", row_buffer_.size(), + BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(), "&pageToken=", request->EscapeString(next_page_token_)))); first_buffered_row_index_ += row_buffer_.size(); } else { TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat( - BigQueryUriPrefix(), "data?maxResults=", row_buffer_.size(), + BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(), "&startIndex=", first_buffered_row_index_))); } TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token)); @@ -222,6 +239,18 @@ Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) { return Status::OK(); } +int64 BigQueryTableAccessor::ComputeMaxResultsArg() { + if (partition_.end_index() == -1) { + return row_buffer_.size(); + } + if (IsPartitionEmpty(partition_)) { + return 0; + } + return std::min(static_cast<int64>(row_buffer_.size()), + static_cast<int64>(partition_.end_index() - + partition_.start_index() + 1)); +} + Status BigQueryTableAccessor::ParseColumnValues( const Json::Value& value, const SchemaNode& root_schema_node, Example* example) { @@ -364,21 +393,17 @@ Status BigQueryTableAccessor::AppendValueToExample( string BigQueryTableAccessor::BigQueryTableAccessor::BigQueryUriPrefix() { HttpRequest request; - return strings::StrCat("https://www.googleapis.com/bigquery/v2/projects/", + return strings::StrCat(bigquery_end_point_, "/projects/", request.EscapeString(project_id_), "/datasets/", request.EscapeString(dataset_id_), "/tables/", request.EscapeString(table_id_), "/"); } -string BigQueryTableAccessor::FullTableName() { - return strings::StrCat(project_id_, ":", dataset_id_, ".", table_id_, "@", - timestamp_millis_); -} - bool BigQueryTableAccessor::Done() { return (total_num_rows_ <= first_buffered_row_index_ + next_row_in_buffer_) || + IsPartitionEmpty(partition_) || (partition_.end_index() != -1 && - partition_.end_index() <= + partition_.end_index() < first_buffered_row_index_ + next_row_in_buffer_); } diff --git a/tensorflow/core/kernels/cloud/bigquery_table_accessor.h b/tensorflow/core/kernels/cloud/bigquery_table_accessor.h index fafda9cdd6f..33d1905b8a9 100644 --- a/tensorflow/core/kernels/cloud/bigquery_table_accessor.h +++ b/tensorflow/core/kernels/cloud/bigquery_table_accessor.h @@ -55,16 +55,23 @@ class BigQueryTableAccessor { }; /// \brief Creates a new BigQueryTableAccessor object. + // + // We do not allow relative (negative or zero) snapshot times here since we + // want to have a consistent snapshot of the table for the lifetime of this + // object. + // Use end_point if you want to connect to a different end point than the + // official BigQuery end point. Otherwise send an empty string. static Status New(const string& project_id, const string& dataset_id, const string& table_id, int64 timestamp_millis, - int64 row_buffer_size, const std::set<string>& columns, + int64 row_buffer_size, const string& end_point, + const std::vector<string>& columns, const BigQueryTablePartition& partition, std::unique_ptr<BigQueryTableAccessor>* accessor); /// \brief Starts reading a new partition. - void SetPartition(const BigQueryTablePartition& partition); + Status SetPartition(const BigQueryTablePartition& partition); - /// \brief Returns false if there are more rows available in the current + /// \brief Returns true if there are more rows available in the current /// partition. bool Done(); @@ -74,9 +81,11 @@ class BigQueryTableAccessor { /// in the BigQuery service. Status ReadRow(int64* row_id, Example* example); - /// \brief Returns total number of rows. + /// \brief Returns total number of rows in the table. int64 total_num_rows() { return total_num_rows_; } + virtual ~BigQueryTableAccessor() {} + private: friend class BigQueryTableAccessorTest; @@ -95,7 +104,8 @@ class BigQueryTableAccessor { /// these two variables. static Status New(const string& project_id, const string& dataset_id, const string& table_id, int64 timestamp_millis, - int64 row_buffer_size, const std::set<string>& columns, + int64 row_buffer_size, const string& end_point, + const std::vector<string>& columns, const BigQueryTablePartition& partition, std::unique_ptr<AuthProvider> auth_provider, std::unique_ptr<HttpRequest::Factory> http_request_factory, @@ -104,14 +114,16 @@ class BigQueryTableAccessor { /// \brief Constructs an object for a given table and partition. BigQueryTableAccessor(const string& project_id, const string& dataset_id, const string& table_id, int64 timestamp_millis, - int64 row_buffer_size, const std::set<string>& columns, + int64 row_buffer_size, const string& end_point, + const std::vector<string>& columns, const BigQueryTablePartition& partition); /// Used for unit testing. BigQueryTableAccessor( const string& project_id, const string& dataset_id, const string& table_id, int64 timestamp_millis, int64 row_buffer_size, - const std::set<string>& columns, const BigQueryTablePartition& partition, + const string& end_point, const std::vector<string>& columns, + const BigQueryTablePartition& partition, std::unique_ptr<AuthProvider> auth_provider, std::unique_ptr<HttpRequest::Factory> http_request_factory); @@ -132,7 +144,7 @@ class BigQueryTableAccessor { Status AppendValueToExample(const string& column_name, const Json::Value& column_value, const BigQueryTableAccessor::ColumnType type, - Example* ex); + Example* example); /// \brief Resets internal counters for reading a partition. void Reset(); @@ -140,25 +152,28 @@ class BigQueryTableAccessor { /// \brief Helper function that returns BigQuery http endpoint prefix. string BigQueryUriPrefix(); + /// \brief Computes the maxResults arg to send to BigQuery. + int64 ComputeMaxResultsArg(); + /// \brief Returns full name of the underlying table name. - string FullTableName(); + string FullTableName() { + return strings::StrCat(project_id_, ":", dataset_id_, ".", table_id_, "@", + timestamp_millis_); + } const string project_id_; const string dataset_id_; const string table_id_; // Snapshot timestamp. - // - // Indicates a snapshot of the table in milliseconds since the epoch. - // - // We do not allow relative (negative or zero) times here since we want to - // have a consistent snapshot of the table for the lifetime of this object. - // For more details, see 'Table Decorators' in BigQuery documentation. const int64 timestamp_millis_; // Columns that should be read. Empty means all columns. const std::set<string> columns_; + // HTTP address of BigQuery end point to use. + const string bigquery_end_point_; + // Describes the portion of the table that we are currently accessing. BigQueryTablePartition partition_; diff --git a/tensorflow/core/kernels/cloud/bigquery_table_accessor_test.cc b/tensorflow/core/kernels/cloud/bigquery_table_accessor_test.cc index 306cc5a4e10..57a4b892518 100644 --- a/tensorflow/core/kernels/cloud/bigquery_table_accessor_test.cc +++ b/tensorflow/core/kernels/cloud/bigquery_table_accessor_test.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/core/platform/test.h" namespace tensorflow { - namespace { constexpr char kTestProject[] = "test-project"; @@ -69,10 +68,10 @@ class BigQueryTableAccessorTest : public ::testing::Test { Status CreateTableAccessor(const string& project_id, const string& dataset_id, const string& table_id, int64 timestamp_millis, int64 row_buffer_size, - const std::set<string>& columns, + const std::vector<string>& columns, const BigQueryTablePartition& partition) { return BigQueryTableAccessor::New( - project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, + project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, "", columns, partition, std::unique_ptr<AuthProvider>(new FakeAuthProvider), std::unique_ptr<HttpRequest::Factory>( new FakeHttpRequestFactory(&requests_)), @@ -197,7 +196,7 @@ TEST_F(BigQueryTableAccessorTest, ReadOneRowTest) { kTestRow)); BigQueryTablePartition partition; partition.set_start_index(2); - partition.set_end_index(3); + partition.set_end_index(2); TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, {}, partition)); int64 row_id; @@ -227,7 +226,7 @@ TEST_F(BigQueryTableAccessorTest, ReadOneRowPartialTest) { kTestRow)); BigQueryTablePartition partition; partition.set_start_index(2); - partition.set_end_index(3); + partition.set_end_index(2); TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, {"bool_field", "rec_field.float_field"}, partition)); @@ -258,7 +257,7 @@ TEST_F(BigQueryTableAccessorTest, ReadOneRowWithNullsTest) { kTestRowWithNulls)); BigQueryTablePartition partition; partition.set_start_index(2); - partition.set_end_index(3); + partition.set_end_index(2); TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, {}, partition)); int64 row_id; @@ -288,7 +287,7 @@ TEST_F(BigQueryTableAccessorTest, BrokenRowTest) { kBrokenTestRow)); BigQueryTablePartition partition; partition.set_start_index(2); - partition.set_end_index(3); + partition.set_end_index(2); TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, {}, partition)); int64 row_id; @@ -357,7 +356,7 @@ TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) { kSampleSchema)); requests_.emplace_back(new FakeHttpRequest( "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" - "datasets/test-dataset/tables/test-table/data?maxResults=2&startIndex=0\n" + "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=0\n" "Auth Token: fake_token\n", kTestTwoRows)); requests_.emplace_back(new FakeHttpRequest( @@ -374,7 +373,7 @@ TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) { BigQueryTablePartition partition; partition.set_start_index(0); - partition.set_end_index(1); + partition.set_end_index(0); TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 2, {}, partition)); @@ -396,7 +395,7 @@ TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) { 1234); partition.set_start_index(0); - partition.set_end_index(2); + partition.set_end_index(1); accessor_->SetPartition(partition); TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); EXPECT_EQ(0, row_id); @@ -410,4 +409,23 @@ TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) { 2222); } +TEST_F(BigQueryTableAccessorTest, EmptyPartitionTest) { + requests_.emplace_back(new FakeHttpRequest( + "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/" + "datasets/test-dataset/tables/test-table/\n" + "Auth Token: fake_token\n", + kSampleSchema)); + + BigQueryTablePartition partition; + partition.set_start_index(3); + partition.set_end_index(2); + TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1, + {}, partition)); + EXPECT_TRUE(accessor_->Done()); + + int64 row_id; + Example example; + EXPECT_TRUE(errors::IsOutOfRange(accessor_->ReadRow(&row_id, &example))); +} + } // namespace tensorflow diff --git a/tensorflow/core/kernels/inplace_ops.cc b/tensorflow/core/kernels/inplace_ops.cc index a56524f3697..5f1f5b652c1 100644 --- a/tensorflow/core/kernels/inplace_ops.cc +++ b/tensorflow/core/kernels/inplace_ops.cc @@ -24,48 +24,8 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" namespace tensorflow { - typedef Eigen::ThreadPoolDevice CPUDevice; -class InplaceOpBase : public OpKernel { - public: - explicit InplaceOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - auto value = ctx->input(0); - auto loc = ctx->input(1); - auto update = ctx->input(2); - - OP_REQUIRES(ctx, TensorShapeUtils::IsVector(loc.shape()), - errors::InvalidArgument("loc must be a vector. ", - loc.shape().DebugString())); - OP_REQUIRES( - ctx, value.dims() == update.dims(), - errors::InvalidArgument("value and update shape doesn't match: ", - value.shape().DebugString(), " vs. ", - update.shape().DebugString())); - for (int i = 1; i < value.dims(); ++i) { - OP_REQUIRES( - ctx, value.dim_size(i) == update.dim_size(i), - errors::InvalidArgument("value and update shape doesn't match ", - value.shape().DebugString(), " vs. ", - update.shape().DebugString())); - } - OP_REQUIRES(ctx, loc.dim_size(0) == update.dim_size(0), - errors::InvalidArgument("loc and update shape doesn't match: ", - loc.shape().DebugString(), " vs. ", - update.shape().DebugString())); - - Tensor output = value; // This creates an alias intentionally. - OP_REQUIRES_OK(ctx, DoCompute(ctx, update, loc, &output)); - ctx->set_output(0, output); - } - - protected: - virtual Status DoCompute(OpKernelContext* ctx, const Tensor& value, - const Tensor& loc, Tensor* output) = 0; -}; - namespace functor { template <typename T> @@ -111,6 +71,48 @@ Status DoInplace(const CPUDevice& d, InplaceOpType op, const Tensor& value, } // end namespace functor +namespace { + +// TODO(apassos): validate the shapes better. +class InplaceOpBase : public OpKernel { + public: + explicit InplaceOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + auto value = ctx->input(0); + auto loc = ctx->input(1); + auto update = ctx->input(2); + + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(loc.shape()), + errors::InvalidArgument("loc must be a vector. ", + loc.shape().DebugString())); + OP_REQUIRES( + ctx, value.dims() == update.dims(), + errors::InvalidArgument("value and update shape doesn't match: ", + value.shape().DebugString(), " vs. ", + update.shape().DebugString())); + for (int i = 1; i < value.dims(); ++i) { + OP_REQUIRES( + ctx, value.dim_size(i) == update.dim_size(i), + errors::InvalidArgument("value and update shape doesn't match ", + value.shape().DebugString(), " vs. ", + update.shape().DebugString())); + } + OP_REQUIRES(ctx, loc.dim_size(0) == update.dim_size(0), + errors::InvalidArgument("loc and update shape doesn't match: ", + loc.shape().DebugString(), " vs. ", + update.shape().DebugString())); + + Tensor output = value; // This creates an alias intentionally. + OP_REQUIRES_OK(ctx, DoCompute(ctx, update, loc, &output)); + ctx->set_output(0, output); + } + + protected: + virtual Status DoCompute(OpKernelContext* ctx, const Tensor& value, + const Tensor& loc, Tensor* output) = 0; +}; + template <typename Device, functor::InplaceOpType op> class InplaceOp : public InplaceOpBase { public: @@ -159,21 +161,27 @@ class EmptyOp : public OpKernel { bool init_; }; -#define REGISTER(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("InplaceUpdate").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ - InplaceOp<CPUDevice, functor::I_UPDATE>); \ - REGISTER_KERNEL_BUILDER( \ - Name("InplaceAdd").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ - InplaceOp<CPUDevice, functor::I_ADD>); \ - REGISTER_KERNEL_BUILDER( \ - Name("InplaceSubtract").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ - InplaceOp<CPUDevice, functor::I_SUB>); +class FailureKernel : public OpKernel { + public: + explicit FailureKernel(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, + errors::Internal("Found instance of parallel_stack which " + "could not be properly replaced.")); + } + + void Compute(OpKernelContext*) {} +}; + +#define REGISTER(type) \ + REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T"), \ + InplaceOp<CPUDevice, functor::I_UPDATE>); TF_CALL_NUMBER_TYPES(REGISTER) #undef REGISTER #define REGISTER_EMPTY(type) \ - REGISTER_KERNEL_BUILDER(Name("Empty") \ + REGISTER_KERNEL_BUILDER(Name("_ParallelConcatStart") \ .Device(DEVICE_CPU) \ .HostMemory("shape") \ .TypeConstraint<type>("dtype"), \ @@ -182,12 +190,19 @@ TF_CALL_NUMBER_TYPES(REGISTER) TF_CALL_POD_STRING_TYPES(REGISTER_EMPTY) #undef REGISTER_EMPTY +#define REGISTER_PARALLEL_CONCAT(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("ParallelConcat").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + FailureKernel); +TF_CALL_POD_STRING_TYPES(REGISTER_PARALLEL_CONCAT); +#undef REGISTER_PARALLEL_CONCAT + #if GOOGLE_CUDA typedef Eigen::GpuDevice GPUDevice; #define REGISTER_EMPTY(type) \ - REGISTER_KERNEL_BUILDER(Name("Empty") \ + REGISTER_KERNEL_BUILDER(Name("_ParallelConcatStart") \ .Device(DEVICE_GPU) \ .HostMemory("shape") \ .TypeConstraint<type>("dtype"), \ @@ -195,23 +210,25 @@ typedef Eigen::GpuDevice GPUDevice; TF_CALL_GPU_NUMBER_TYPES(REGISTER_EMPTY) #undef REGISTER_EMPTY -#define REGISTER(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("InplaceUpdate").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ - InplaceOp<GPUDevice, functor::I_UPDATE>); \ - REGISTER_KERNEL_BUILDER( \ - Name("InplaceAdd").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ - InplaceOp<GPUDevice, functor::I_ADD>); \ - REGISTER_KERNEL_BUILDER( \ - Name("InplaceSubtract").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ - InplaceOp<GPUDevice, functor::I_SUB>); +#define REGISTER_PARALLEL_CONCAT(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("ParallelConcat").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + FailureKernel); +TF_CALL_GPU_NUMBER_TYPES(REGISTER_PARALLEL_CONCAT); +#undef REGISTER_PARALLEL_CONCAT + +#define REGISTER(type) \ + REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("T"), \ + InplaceOp<GPUDevice, functor::I_UPDATE>); TF_CALL_GPU_NUMBER_TYPES(REGISTER) #undef REGISTER // Register versions that operate on int32 data on the CPU even though the op // has been placed on the GPU -REGISTER_KERNEL_BUILDER(Name("InplaceUpdate") +REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate") .Device(DEVICE_GPU) .HostMemory("value") .HostMemory("loc") @@ -219,24 +236,7 @@ REGISTER_KERNEL_BUILDER(Name("InplaceUpdate") .HostMemory("output") .TypeConstraint<int32>("T"), InplaceOp<CPUDevice, functor::I_UPDATE>); - -REGISTER_KERNEL_BUILDER(Name("InplaceAdd") - .Device(DEVICE_GPU) - .HostMemory("value") - .HostMemory("loc") - .HostMemory("update") - .HostMemory("output") - .TypeConstraint<int32>("T"), - InplaceOp<CPUDevice, functor::I_ADD>); - -REGISTER_KERNEL_BUILDER(Name("InplaceSubtract") - .Device(DEVICE_GPU) - .HostMemory("value") - .HostMemory("loc") - .HostMemory("update") - .HostMemory("output") - .TypeConstraint<int32>("T"), - InplaceOp<CPUDevice, functor::I_SUB>); #endif +} // end namespace } // end namespace tensorflow diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index ad35209f359..7ce667675d5 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -97,10 +97,9 @@ Status PadShapeFn(InferenceContext* c) { return Status::OK(); } - // tensor value was provided for paddings_t; doublecheck n_dim value is the - // same. - const auto num_dims = c->Value(n_dim); - DCHECK_EQ(num_dims, paddings_t->shape().dim_size(0)); + const int64 num_dims = paddings_t->shape().dim_size(0); + TF_RETURN_IF_ERROR(c->WithRank(input, num_dims, &input)); + TF_RETURN_IF_ERROR(c->WithValue(n_dim, num_dims, &n_dim)); if (paddings_t->dtype() == DT_INT32) { return PadKnown<int32>(c, input, paddings_t, num_dims); @@ -165,6 +164,71 @@ Status SetOutputShapeForReshape(InferenceContext* c) { } // namespace +REGISTER_OP("ParallelConcat") + .Input("values: N * T") + .Output("output: T") + .Attr("N: int >= 1") + .Attr("T: type") + .Attr("shape: shape") + .SetShapeFn([](InferenceContext* c) { + // Validate that the shape attr is correct. + TensorShapeProto passed_shape_proto; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &passed_shape_proto)); + ShapeHandle passed_shape; + TF_RETURN_IF_ERROR( + c->MakeShapeFromShapeProto(passed_shape_proto, &passed_shape)); + if (!c->FullyDefined(passed_shape)) { + return errors::InvalidArgument("shape attr must be fully defined."); + } + ShapeHandle cur; + TF_RETURN_IF_ERROR(c->ReplaceDim( + passed_shape, 0, c->MakeDim(shape_inference::DimensionOrConstant(1)), + &cur)); + for (int i = 0; i < c->num_inputs(); ++i) { + if (!c->FullyDefined(c->input(i))) { + return errors::InvalidArgument( + "All input shapes must be fully defined."); + } + DimensionHandle unused; + if (!c->WithValue(c->Dim(c->input(i), 0), 1, &unused).ok()) { + return errors::InvalidArgument("Size of first dimension must be 1."); + } + TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur), + "From merging shape ", i, + " with other shapes."); + } + + c->set_output(0, passed_shape); + + return Status::OK(); + }) + .Doc(R"doc( +Concatenates a list of `N` tensors along the first dimension. + +The input tensors are all required to have size 1 in the first dimension. + +For example: + +```prettyprint +# 'x' is [[1, 4]] +# 'y' is [[2, 5]] +# 'z' is [[3, 6]] +parallel_concat([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim. +``` + +The difference between concat and parallel_concat is that concat requires all +of the inputs be computed before the operation will begin but doesn't require +that the input shapes be known during graph construction. Parallel concat +will copy pieces of the input into the output as they become available, in +some situations this can provide a performance benefit. + +values: Tensors to be concatenated. All must have size 1 in the first dimension + and same shape. +output: The concatenated tensor. +shape: the final shape of the result; should be equal to the shapes of any input + but with the number of input values in the first dimension. +)doc"); + REGISTER_OP("Pack") .Input("values: N * T") .Output("output: T") @@ -440,7 +504,8 @@ REGISTER_OP("SplitV") c->set_output(i, output_shape); } } else { - // Determine the output shape if split dimension and split sizes are known + // Determine the output shape if split dimension and split sizes are + // known. int64 split_dim = c->Value(split_dimension); TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input)); std::vector<int64> data; @@ -451,12 +516,12 @@ REGISTER_OP("SplitV") } if (num_outputs != data.size()) { return errors::InvalidArgument( - "Length of size_splits should be equal to num_outputs"); + "Length of size_splits should be equal to num_outputs"); } for (int i = 0; i < num_outputs; ++i) { output_shape = c->UnknownShapeOfRank(rank); - TF_RETURN_IF_ERROR( - c->ReplaceDim(input, split_dim, c->MakeDim(data[i]), &output_shape)); + TF_RETURN_IF_ERROR(c->ReplaceDim(input, split_dim, + c->MakeDim(data[i]), &output_shape)); c->set_output(i, output_shape); } } @@ -1160,7 +1225,7 @@ Equivalent to np.full )doc"); // -------------------------------------------------------------------------- -REGISTER_OP("Empty") +REGISTER_OP("_ParallelConcatStart") .Input("shape: Tshape") .Output("output: dtype") .Attr("dtype: type") @@ -1186,7 +1251,7 @@ output: An empty Tensor of the specified type. )doc"); // -------------------------------------------------------------------------- -REGISTER_OP("InplaceUpdate") +REGISTER_OP("_ParallelConcatUpdate") .Input("value: T") .Input("loc: Tshape") .Input("update: T") @@ -1225,86 +1290,6 @@ update: A `Tensor` of rank one less than `value` if `loc` is a scalar, output: `value` that has been updated accordingly. )doc"); -// -------------------------------------------------------------------------- -REGISTER_OP("InplaceAdd") - .Input("value: T") - .Input("loc: Tshape") - .Input("update: T") - .Output("output: T") - .Attr("T: type") - .Attr("Tshape: {int32, int64} = DT_INT32") - .SetShapeFn(shape_inference::UnchangedShape) - .Doc(R"doc( -Updates input `value` at `loc` by adding `update` elementwise. - -If `loc` is None, `value` and `update` must be the same size. -``` -value += update -``` - -If `loc` is a scalar, `value` has rank 1 higher than `update` -``` -value[i, :] += update -``` - -If `loc` is a vector, `value` has the same rank as `update` -``` -value[loc, :] += update -``` - -If you use this function you will almost certainly want to add -a control dependency as done in the implementation of parallel_stack to -avoid race conditions. - -value: A `Tensor` object that will be updated in-place. -loc: A scalar or 1-D `Tensor` indicating the indices of the first dimension - such that value[loc, :] is updated. -update: A `Tensor` of rank one less than `value` if `loc` is a scalar, - otherwise of rank equal to `value` that contains the new values - that will be added to `value`. -output: `value` where `update` has been added as appropriate. -)doc"); - -// -------------------------------------------------------------------------- -REGISTER_OP("InplaceSubtract") - .Input("value: T") - .Input("loc: Tshape") - .Input("update: T") - .Output("output: T") - .Attr("T: type") - .Attr("Tshape: {int32, int64} = DT_INT32") - .SetShapeFn(shape_inference::UnchangedShape) - .Doc(R"doc( -Updates input `value` at `loc` by subtracting `update` elementwise. - -If `loc` is None, `value` and `update` must be the same size. -``` -value -= update -``` - -If `loc` is a scalar, `value` has rank 1 higher than `update` -``` -value[i, :] -= update -``` - -If `loc` is a vector, `value` has the same rank as `update` -``` -value[loc, :] -= update -``` - -If you use this function you will almost certainly want to add -a control dependency as done in the implementation of parallel_stack to -avoid race conditions. - -value: A `Tensor` object that will be updated in-place. -loc: A scalar or 1-D `Tensor` indicating the indices of the first dimension - such that value[loc, :] is updated. -update: A `Tensor` of rank one less than `value` if `loc` is a scalar, - otherwise of rank equal to `value` that contains the new values - that will be subtracted from `value`. -output: `value` where `update` has been subtracted as appropriate. -)doc"); - // -------------------------------------------------------------------------- REGISTER_OP("Gather") .Input("params: Tparams") @@ -1370,8 +1355,8 @@ REGISTER_OP("GatherNd") if (c->Value(r_dim) > c->Rank(params)) { return errors::InvalidArgument( "indices.shape[-1] must be <= params.rank, but saw indices shape: ", - c->DebugString(indices), " and params shape: ", - c->DebugString(params)); + c->DebugString(indices), + " and params shape: ", c->DebugString(params)); } // Remove r_dim from indices to get output. @@ -1906,12 +1891,12 @@ REGISTER_OP("ReverseSequence") // Validate batch_dim and seq_dim against input. const int32 input_rank = c->Rank(input); if (batch_dim >= input_rank) { - return errors::InvalidArgument("batch_dim must be < input rank: ", - batch_dim, " vs. ", input_rank); + return errors::InvalidArgument( + "batch_dim must be < input rank: ", batch_dim, " vs. ", input_rank); } if (seq_dim >= input_rank) { - return errors::InvalidArgument("seq_dim must be < input rank: ", - seq_dim, " vs. ", input_rank); + return errors::InvalidArgument( + "seq_dim must be < input rank: ", seq_dim, " vs. ", input_rank); } DimensionHandle batch_dim_dim = c->Dim(input, batch_dim); @@ -3790,8 +3775,9 @@ REGISTER_OP("SpaceToDepth") TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input, 3), block_size * block_size, &output_depth)); - c->set_output(0, c->MakeShape({c->Dim(input, 0), output_height, - output_width, output_depth})); + c->set_output(0, + c->MakeShape({c->Dim(input, 0), output_height, output_width, + output_depth})); return Status::OK(); }) .Doc(R"doc( @@ -3895,8 +3881,9 @@ REGISTER_OP("DepthToSpace") TF_RETURN_IF_ERROR(c->Divide(c->Dim(input, 3), block_size * block_size, true /* evenly_divisible */, &output_depth)); - c->set_output(0, c->MakeShape({c->Dim(input, 0), output_height, - output_width, output_depth})); + c->set_output(0, + c->MakeShape({c->Dim(input, 0), output_height, output_width, + output_depth})); return Status::OK(); }) .Doc(R"doc( @@ -4772,8 +4759,9 @@ Status ScatterNdShape(InferenceContext* c) { Status s = c->Merge(prefix_indices, prefix_updates, &unused); if (!s.ok()) { return errors::InvalidArgument( - "The outer ", outer_dims, " dimensions of indices.shape=", - c->DebugString(indices_shape), " must match the outer ", outer_dims, + "The outer ", outer_dims, + " dimensions of indices.shape=", c->DebugString(indices_shape), + " must match the outer ", outer_dims, " dimensions of updates.shape=", c->DebugString(updates_shape), ": ", s.error_message()); } diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc index 03cdc7dc65a..04f268accdc 100644 --- a/tensorflow/core/ops/array_ops_test.cc +++ b/tensorflow/core/ops/array_ops_test.cc @@ -331,6 +331,7 @@ TEST(ArrayOpsTest, PadD_ShapeFn) { INFER_OK(op, "[100,200,300];[3,2]", "[111,222,333]"); INFER_OK(op, "[100,?,300];[3,2]", "[111,?,333]"); INFER_OK(op, "?;[3,2]", "[?,?,?]"); + INFER_OK(op, "?;?", "[?,?,?]"); } } diff --git a/tensorflow/core/ops/cloud_ops.cc b/tensorflow/core/ops/cloud_ops.cc new file mode 100644 index 00000000000..89f31a46abe --- /dev/null +++ b/tensorflow/core/ops/cloud_ops.cc @@ -0,0 +1,88 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/* This file registers all cloud ops. */ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +namespace tensorflow { + +using shape_inference::InferenceContext; + +REGISTER_OP("BigQueryReader") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("project_id: string") + .Attr("dataset_id: string") + .Attr("table_id: string") + .Attr("columns: list(string)") + .Attr("timestamp_millis: int") + .Attr("test_end_point: string = ''") + .Output("reader_handle: Ref(string)") + .SetIsStateful() + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->Vector(2)); + return Status::OK(); + }) + .Doc(R"doc( +A Reader that outputs rows from a BigQuery table as tensorflow Examples. + +container: If non-empty, this reader is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this reader is named in the given bucket + with this shared_name. Otherwise, the node name is used instead. +project_id: GCP project ID. +dataset_id: BigQuery Dataset ID. +table_id: Table to read. +columns: List of columns to read. Leave empty to read all columns. +timestamp_millis: Table snapshot timestamp in millis since epoch. Relative +(negative or zero) snapshot times are not allowed. For more details, see +'Table Decorators' in BigQuery docs. +test_end_point: Do not use. For testing purposes only. +reader_handle: The handle to reference the Reader. +)doc"); + +REGISTER_OP("GenerateBigQueryReaderPartitions") + .Attr("project_id: string") + .Attr("dataset_id: string") + .Attr("table_id: string") + .Attr("columns: list(string)") + .Attr("timestamp_millis: int") + .Attr("num_partitions: int") + .Attr("test_end_point: string = ''") + .Output("partitions: string") + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); + return Status::OK(); + }) + .Doc(R"doc( +Generates serialized partition messages suitable for batch reads. + +This op should not be used directly by clients. Instead, the +bigquery_reader_ops.py file defines a clean interface to the reader. + +project_id: GCP project ID. +dataset_id: BigQuery Dataset ID. +table_id: Table to read. +columns: List of columns to read. Leave empty to read all columns. +timestamp_millis: Table snapshot timestamp in millis since epoch. Relative +(negative or zero) snapshot times are not allowed. For more details, see +'Table Decorators' in BigQuery docs. +num_partitions: Number of partitions to split the table into. +test_end_point: Do not use. For testing purposes only. +partitions: Serialized table partitions. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/compat/ops_history.v0.pbtxt b/tensorflow/core/ops/compat/ops_history.v0.pbtxt index b9589d1c6fa..cfb75046640 100644 --- a/tensorflow/core/ops/compat/ops_history.v0.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v0.pbtxt @@ -11058,77 +11058,6 @@ op { } } } -op { - name: "Empty" - input_arg { - name: "shape" - type_attr: "Tshape" - } - output_arg { - name: "output" - type_attr: "dtype" - } - attr { - name: "dtype" - type: "type" - } - attr { - name: "Tshape" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - attr { - name: "init" - type: "bool" - default_value { - b: false - } - } -} -op { - name: "Empty" - input_arg { - name: "shape" - type_attr: "Tshape" - } - output_arg { - name: "output" - type_attr: "dtype" - } - attr { - name: "dtype" - type: "type" - } - attr { - name: "Tshape" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - attr { - name: "init" - type: "bool" - default_value { - b: false - } - } - is_stateful: true -} op { name: "EncodeBase64" input_arg { @@ -14419,114 +14348,6 @@ op { } } } -op { - name: "InplaceAdd" - input_arg { - name: "value" - type_attr: "T" - } - input_arg { - name: "loc" - type_attr: "Tshape" - } - input_arg { - name: "update" - type_attr: "T" - } - output_arg { - name: "output" - type_attr: "T" - } - attr { - name: "T" - type: "type" - } - attr { - name: "Tshape" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } -} -op { - name: "InplaceSubtract" - input_arg { - name: "value" - type_attr: "T" - } - input_arg { - name: "loc" - type_attr: "Tshape" - } - input_arg { - name: "update" - type_attr: "T" - } - output_arg { - name: "output" - type_attr: "T" - } - attr { - name: "T" - type: "type" - } - attr { - name: "Tshape" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } -} -op { - name: "InplaceUpdate" - input_arg { - name: "value" - type_attr: "T" - } - input_arg { - name: "loc" - type_attr: "Tshape" - } - input_arg { - name: "update" - type_attr: "T" - } - output_arg { - name: "output" - type_attr: "T" - } - attr { - name: "T" - type: "type" - } - attr { - name: "Tshape" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } -} op { name: "Inv" input_arg { @@ -19650,6 +19471,32 @@ op { } is_stateful: true } +op { + name: "ParallelConcat" + input_arg { + name: "values" + type_attr: "T" + number_attr: "N" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "T" + type: "type" + } + attr { + name: "shape" + type: "shape" + } +} op { name: "ParameterizedTruncatedNormal" input_arg { diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index 88ffe5e0667..cb216e70073 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -43,6 +43,14 @@ Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim, width = c->UnknownDim(); height = c->UnknownDim(); } else { + // TODO(petewarden) - Remove once we have constant evaluation in C++ only. + if (size_tensor->dtype() != DT_INT32) { + return errors::InvalidArgument( + "Bad size input type for SetOutputToSizedImage: Expected DT_INT32 " + "but got ", + DataTypeString(size_tensor->dtype()), " for input #", size_input_idx, + " in ", c->DebugString()); + } auto vec = size_tensor->vec<int32>(); height = c->MakeDim(vec(0)); width = c->MakeDim(vec(1)); @@ -74,8 +82,9 @@ Status DecodeImageShapeFn(InferenceContext* c) { channels_dim = c->MakeDim(channels); } - c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim, - InferenceContext::kUnknownDim, channels_dim})); + c->set_output(0, + c->MakeShape({InferenceContext::kUnknownDim, + InferenceContext::kUnknownDim, channels_dim})); return Status::OK(); } @@ -555,9 +564,10 @@ REGISTER_OP("DecodeGif") .SetShapeFn([](InferenceContext* c) { ShapeHandle unused; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); - c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim, - InferenceContext::kUnknownDim, - InferenceContext::kUnknownDim, 3})); + c->set_output(0, + c->MakeShape({InferenceContext::kUnknownDim, + InferenceContext::kUnknownDim, + InferenceContext::kUnknownDim, 3})); return Status::OK(); }) .Doc(R"doc( diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 65ad47e7b73..e631c289c6d 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -6371,48 +6371,6 @@ op { } summary: "Computes gradients for the exponential linear (Elu) operation." } -op { - name: "Empty" - input_arg { - name: "shape" - description: "1-D `Tensor` indicating the shape of the output." - type_attr: "Tshape" - } - output_arg { - name: "output" - description: "An empty Tensor of the specified type." - type_attr: "dtype" - } - attr { - name: "dtype" - type: "type" - description: "The element type of the returned tensor." - } - attr { - name: "Tshape" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - attr { - name: "init" - type: "bool" - default_value { - b: false - } - description: "`bool` indicating whether or not to zero the allocated memory." - } - summary: "Creates an empty Tensor with shape `shape` and type `dtype`." - description: "The memory can optionally be initialized. This is usually useful in\nconjunction with inplace operations." - is_stateful: true -} op { name: "EncodeBase64" input_arg { @@ -8894,132 +8852,6 @@ op { summary: "Initializes a table from a text file." description: "It inserts one key-value pair into the table for each line of the file.\nThe key and value is extracted from the whole line content, elements from the\nsplit line based on `delimiter` or the line number (starting from zero).\nWhere to extract the key and value from a line is specified by `key_index` and\n`value_index`.\n\n- A value of -1 means use the line number(starting from zero), expects `int64`.\n- A value of -2 means use the whole line content, expects `string`.\n- A value >= 0 means use the index (starting at zero) of the split line based\n on `delimiter`." } -op { - name: "InplaceAdd" - input_arg { - name: "value" - description: "A `Tensor` object that will be updated in-place." - type_attr: "T" - } - input_arg { - name: "loc" - description: "A scalar or 1-D `Tensor` indicating the indices of the first dimension\nsuch that value[loc, :] is updated." - type_attr: "Tshape" - } - input_arg { - name: "update" - description: "A `Tensor` of rank one less than `value` if `loc` is a scalar,\notherwise of rank equal to `value` that contains the new values\nthat will be added to `value`." - type_attr: "T" - } - output_arg { - name: "output" - description: "`value` where `update` has been added as appropriate." - type_attr: "T" - } - attr { - name: "T" - type: "type" - } - attr { - name: "Tshape" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - summary: "Updates input `value` at `loc` by adding `update` elementwise." - description: "If `loc` is None, `value` and `update` must be the same size.\n```\nvalue += update\n```\n\nIf `loc` is a scalar, `value` has rank 1 higher than `update`\n```\nvalue[i, :] += update\n```\n\nIf `loc` is a vector, `value` has the same rank as `update`\n```\nvalue[loc, :] += update\n```\n\nIf you use this function you will almost certainly want to add\na control dependency as done in the implementation of parallel_stack to\navoid race conditions." -} -op { - name: "InplaceSubtract" - input_arg { - name: "value" - description: "A `Tensor` object that will be updated in-place." - type_attr: "T" - } - input_arg { - name: "loc" - description: "A scalar or 1-D `Tensor` indicating the indices of the first dimension\nsuch that value[loc, :] is updated." - type_attr: "Tshape" - } - input_arg { - name: "update" - description: "A `Tensor` of rank one less than `value` if `loc` is a scalar,\notherwise of rank equal to `value` that contains the new values\nthat will be subtracted from `value`." - type_attr: "T" - } - output_arg { - name: "output" - description: "`value` where `update` has been subtracted as appropriate." - type_attr: "T" - } - attr { - name: "T" - type: "type" - } - attr { - name: "Tshape" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - summary: "Updates input `value` at `loc` by subtracting `update` elementwise." - description: "If `loc` is None, `value` and `update` must be the same size.\n```\nvalue -= update\n```\n\nIf `loc` is a scalar, `value` has rank 1 higher than `update`\n```\nvalue[i, :] -= update\n```\n\nIf `loc` is a vector, `value` has the same rank as `update`\n```\nvalue[loc, :] -= update\n```\n\nIf you use this function you will almost certainly want to add\na control dependency as done in the implementation of parallel_stack to\navoid race conditions." -} -op { - name: "InplaceUpdate" - input_arg { - name: "value" - description: "A `Tensor` object that will be updated in-place." - type_attr: "T" - } - input_arg { - name: "loc" - description: "A scalar or 1-D `Tensor` indicating the indices of the first dimension\nsuch that value[loc, :] is updated." - type_attr: "Tshape" - } - input_arg { - name: "update" - description: "A `Tensor` of rank one less than `value` if `loc` is a scalar,\notherwise of rank equal to `value` that contains the new values\nfor `value`." - type_attr: "T" - } - output_arg { - name: "output" - description: "`value` that has been updated accordingly." - type_attr: "T" - } - attr { - name: "T" - type: "type" - } - attr { - name: "Tshape" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - summary: "Updates input `value` at `loc` with `update`." - description: "If `loc` is None, `value` and `update` must be the same size.\n```\nvalue = update\n```\n\nIf `loc` is a scalar, `value` has rank 1 higher than `update`\n```\nvalue[i, :] = update\n```\n\nIf `loc` is a vector, `value` has the same rank as `update`\n```\nvalue[loc, :] = update\n```\n\nIf you use this function you will almost certainly want to add\na control dependency as done in the implementation of parallel_stack to\navoid race conditions." -} op { name: "Inv" input_arg { @@ -11830,6 +11662,37 @@ op { description: "Variable-size shapes are allowed by setting the corresponding shape dimensions\nto 0 in the shape attr. In this case DequeueMany will pad up to the maximum\nsize of any given element in the minibatch. See below for details." is_stateful: true } +op { + name: "ParallelConcat" + input_arg { + name: "values" + description: "Tensors to be concatenated. All must have size 1 in the first dimension\nand same shape." + type_attr: "T" + number_attr: "N" + } + output_arg { + name: "output" + description: "The concatenated tensor." + type_attr: "T" + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "T" + type: "type" + } + attr { + name: "shape" + type: "shape" + description: "the final shape of the result; should be equal to the shapes of any input\nbut with the number of input values in the first dimension." + } + summary: "Concatenates a list of `N` tensors along the first dimension." + description: "The input tensors are all required to have size 1 in the first dimension.\n\nFor example:\n\n```prettyprint\n# \'x\' is [[1, 4]]\n# \'y\' is [[2, 5]]\n# \'z\' is [[3, 6]]\nparallel_concat([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim.\n```\n\nThe difference between concat and parallel_concat is that concat requires all\nof the inputs be computed before the operation will begin but doesn\'t require\nthat the input shapes be known during graph construction. Parallel concat\nwill copy pieces of the input into the output as they become available, in\nsome situations this can provide a performance benefit." +} op { name: "ParameterizedTruncatedNormal" input_arg { diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 168f9df2e84..09bbda63874 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -2,6 +2,7 @@ load("@protobuf//:protobuf.bzl", "cc_proto_library") load("@protobuf//:protobuf.bzl", "py_proto_library") +load("//tensorflow:tensorflow.bzl", "if_not_mobile") # configure may change the following lines WITH_GCP_SUPPORT = False @@ -207,6 +208,24 @@ def tf_additional_core_deps(): deps.append("//tensorflow/core/platform/hadoop:hadoop_file_system") return deps +# TODO(jart, jhseu): Delete when GCP is default on. +def tf_additional_cloud_op_deps(): + deps = [] + # TODO(hormati): Remove the comments below to enable BigQuery op. The op is + # not linked for now because it is under perf testing. + #if WITH_GCP_SUPPORT: + # deps = if_not_mobile(["//tensorflow/core/kernels/cloud:bigquery_reader_ops"]) + return deps + +# TODO(jart, jhseu): Delete when GCP is default on. +def tf_additional_cloud_kernel_deps(): + deps = [] + # TODO(hormati): Remove the comments below to enable BigQuery op. The op is + # not linked for now because it is under perf testing. + #if WITH_GCP_SUPPORT: + # deps = if_not_mobile(["//tensorflow/core:cloud_ops_op_lib"]) + return deps + def tf_additional_plugin_deps(): deps = [] if WITH_XLA_SUPPORT: diff --git a/tensorflow/g3doc/api_docs/python/array_ops.md b/tensorflow/g3doc/api_docs/python/array_ops.md index a6d950dc6ed..2dcf6bcca6f 100644 --- a/tensorflow/g3doc/api_docs/python/array_ops.md +++ b/tensorflow/g3doc/api_docs/python/array_ops.md @@ -2044,7 +2044,7 @@ The attr `block_size` indicates the input block size and how the data is moved. * Chunks of data of size `block_size * block_size` from depth are rearranged into non-overlapping blocks of size `block_size x block_size` - * The width the output tensor is `input_width * block_size`, whereas the + * The width the output tensor is `input_depth * block_size`, whereas the height is `input_height * block_size`. * The depth of the input tensor must be divisible by `block_size * block_size`. diff --git a/tensorflow/g3doc/api_docs/python/contrib.distributions.bijector.md b/tensorflow/g3doc/api_docs/python/contrib.distributions.bijector.md index ee1595d8692..64487d5fe02 100644 --- a/tensorflow/g3doc/api_docs/python/contrib.distributions.bijector.md +++ b/tensorflow/g3doc/api_docs/python/contrib.distributions.bijector.md @@ -122,6 +122,7 @@ specified then `scale += IdentityMatrix`. Otherwise specifying a `scale_diag` has shape [N1, N2, ... k, k], which represents a k x k lower triangular matrix. When `None` no `scale_tril` term is added to `scale`. + The upper triangular elements above the diagonal are ignored. * <b>`scale_perturb_factor`</b>: Numeric `Tensor` representing factor matrix with last two dimensions of shape `(k, r)`. When `None`, no rank-r update is added to `scale`. diff --git a/tensorflow/g3doc/api_docs/python/contrib.distributions.md b/tensorflow/g3doc/api_docs/python/contrib.distributions.md index ce42750b985..e639cf3defb 100644 --- a/tensorflow/g3doc/api_docs/python/contrib.distributions.md +++ b/tensorflow/g3doc/api_docs/python/contrib.distributions.md @@ -23429,3 +23429,29 @@ Initialize the KL registrar. + +## Utilities + +- - - + +### `tf.contrib.distributions.softplus_inverse(x, name=None)` {#softplus_inverse} + +Computes the inverse softplus, i.e., x = softplus_inverse(softplus(x)). + +Mathematically this op is equivalent to: + +```none +softplus_inverse = log(exp(x) - 1.) +``` + +##### Args: + + +* <b>`x`</b>: `Tensor`. Non-negative (not enforced), floating-point. +* <b>`name`</b>: A name for the operation (optional). + +##### Returns: + + `Tensor`. Has the same type/shape as input `x`. + + diff --git a/tensorflow/g3doc/api_docs/python/contrib.layers.md b/tensorflow/g3doc/api_docs/python/contrib.layers.md index c9bbabdd4fd..d2751e8febc 100644 --- a/tensorflow/g3doc/api_docs/python/contrib.layers.md +++ b/tensorflow/g3doc/api_docs/python/contrib.layers.md @@ -83,7 +83,8 @@ can have speed penalty, specially in distributed settings. Lower `decay` value (recommend trying `decay`=0.9) if model experiences reasonably good training performance but poor validation and/or test performance. Try zero_debias_moving_mean=True for improved stability. -* <b>`center`</b>: If True, subtract `beta`. If False, `beta` is ignored. +* <b>`center`</b>: If True, add offset of `beta` to normalized tensor. If False, `beta` + is ignored. * <b>`scale`</b>: If True, multiply by `gamma`. If False, `gamma` is not used. When the next layer is linear (also e.g. `nn.relu`), this can be disabled since the scaling can be done by the next layer. @@ -411,7 +412,8 @@ Can be used as a normalizer function for conv2d and fully_connected. * <b>`inputs`</b>: a tensor with 2 or more dimensions. The normalization occurs over all but the first dimension. -* <b>`center`</b>: If True, subtract `beta`. If False, `beta` is ignored. +* <b>`center`</b>: If True, add offset of `beta` to normalized tensor. If False, `beta` + is ignored. * <b>`scale`</b>: If True, multiply by `gamma`. If False, `gamma` is not used. When the next layer is linear (also e.g. `nn.relu`), this can be disabled since the scaling can be done by the next layer. diff --git a/tensorflow/g3doc/api_docs/python/contrib.linalg.md b/tensorflow/g3doc/api_docs/python/contrib.linalg.md index 72c11268866..cbbffb1e783 100644 --- a/tensorflow/g3doc/api_docs/python/contrib.linalg.md +++ b/tensorflow/g3doc/api_docs/python/contrib.linalg.md @@ -1522,6 +1522,504 @@ Return a dense (batch) matrix representing this operator. +- - - + +### `class tf.contrib.linalg.LinearOperatorScaledIdentity` {#LinearOperatorScaledIdentity} + +`LinearOperator` acting like a scaled [batch] identity matrix `A = c I`. + +This operator acts like a scaled [batch] identity matrix `A` with shape +`[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a +batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is +a scaled version of the `N x N` identity matrix. + +`LinearOperatorIdentity` is initialized with `num_rows`, and a `multiplier` +(a `Tensor`) of shape `[B1,...,Bb]`. `N` is set to `num_rows`, and the +`multiplier` determines the scale for each batch member. + +```python +# Create a 2 x 2 scaled identity matrix. +operator = LinearOperatorIdentity(num_rows=2, multiplier=3.) + +operator.to_dense() +==> [[3., 0.] + [0., 3.]] + +operator.shape +==> [2, 2] + +operator.log_determinant() +==> 2 * Log[3] + +x = ... Shape [2, 4] Tensor +operator.apply(x) +==> 3 * x + +y = tf.random_normal(shape=[3, 2, 4]) +# Note that y.shape is compatible with operator.shape because operator.shape +# is broadcast to [3, 2, 2]. +x = operator.solve(y) +==> 3 * x + +# Create a 2-batch of 2x2 identity matrices +operator = LinearOperatorIdentity(num_rows=2, multiplier=5.) +operator.to_dense() +==> [[[5., 0.] + [0., 5.]], + [[5., 0.] + [0., 5.]]] + +x = ... Shape [2, 2, 3] +operator.apply(x) +==> 5 * x + +# Here the operator and x have different batch_shape, and are broadcast. +x = ... Shape [1, 2, 3] +operator.apply(x) +==> 5 * x +``` + +### Shape compatibility + +This operator acts on [batch] matrix with compatible shape. +`x` is a batch matrix with compatible shape for `apply` and `solve` if + +``` +operator.shape = [B1,...,Bb] + [N, N], with b >= 0 +x.shape = [C1,...,Cc] + [N, R], +and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd] +``` + +### Performance + +* `operator.apply(x)` is `O(D1*...*Dd*N*R)` +* `operator.solve(x)` is `O(D1*...*Dd*N*R)` +* `operator.determinant()` is `O(D1*...*Dd)` + +#### Matrix property hints + +This `LinearOperator` is initialized with boolean flags of the form `is_X`, +for `X = non_singular, self_adjoint, positive_definite`. +These have the following meaning +* If `is_X == True`, callers should expect the operator to have the + property `X`. This is a promise that should be fulfilled, but is *not* a + runtime assert. For example, finite floating point precision may result + in these promises being violated. +* If `is_X == False`, callers should expect the operator to not have `X`. +* If `is_X == None` (the default), callers should have no expectation either + way. +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.__init__(num_rows, multiplier, is_non_singular=None, is_self_adjoint=None, is_positive_definite=None, assert_proper_shapes=False, name='LinearOperatorScaledIdentity')` {#LinearOperatorScaledIdentity.__init__} + +Initialize a `LinearOperatorScaledIdentity`. + +The `LinearOperatorScaledIdentity` is initialized with `num_rows`, which +determines the size of each identity matrix, and a `multiplier`, +which defines `dtype`, batch shape, and scale of each matrix. + +This operator is able to broadcast the leading (batch) dimensions. + +##### Args: + + +* <b>`num_rows`</b>: Scalar non-negative integer `Tensor`. Number of rows in the + corresponding identity matrix. +* <b>`multiplier`</b>: `Tensor` of shape `[B1,...,Bb]`, or `[]` (a scalar). +* <b>`is_non_singular`</b>: Expect that this operator is non-singular. +* <b>`is_self_adjoint`</b>: Expect that this operator is equal to its hermitian + transpose. +* <b>`is_positive_definite`</b>: Expect that this operator is positive definite. +* <b>`assert_proper_shapes`</b>: Python `bool`. If `False`, only perform static + checks that initialization and method arguments have proper shape. + If `True`, and static checks are inconclusive, add asserts to the graph. +* <b>`name`</b>: A name for this `LinearOperator` + +##### Raises: + + +* <b>`ValueError`</b>: If `num_rows` is determined statically to be non-scalar, or + negative. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.add_to_tensor(mat, name='add_to_tensor')` {#LinearOperatorScaledIdentity.add_to_tensor} + +Add matrix represented by this operator to `mat`. Equiv to `I + mat`. + +##### Args: + + +* <b>`mat`</b>: `Tensor` with same `dtype` and shape broadcastable to `self`. +* <b>`name`</b>: A name to give this `Op`. + +##### Returns: + + A `Tensor` with broadcast shape and same `dtype` as `self`. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.apply(x, adjoint=False, name='apply')` {#LinearOperatorScaledIdentity.apply} + +Transform `x` with left multiplication: `x --> Ax`. + +##### Args: + + +* <b>`x`</b>: `Tensor` with compatible shape and same `dtype` as `self`. + See class docstring for definition of compatibility. +* <b>`adjoint`</b>: Python `bool`. If `True`, left multiply by the adjoint. +* <b>`name`</b>: A name for this `Op. + +##### Returns: + + A `Tensor` with shape `[..., M, R]` and same `dtype` as `self`. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.assert_non_singular(name='assert_non_singular')` {#LinearOperatorScaledIdentity.assert_non_singular} + +Returns an `Op` that asserts this operator is non singular. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.assert_positive_definite(name='assert_positive_definite')` {#LinearOperatorScaledIdentity.assert_positive_definite} + +Returns an `Op` that asserts this operator is positive definite. + +Here, positive definite means the real part of all eigenvalues is positive. +We do not require the operator to be self-adjoint. + +##### Args: + + +* <b>`name`</b>: A name to give this `Op`. + +##### Returns: + + An `Op` that asserts this operator is positive definite. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.assert_self_adjoint(name='assert_self_adjoint')` {#LinearOperatorScaledIdentity.assert_self_adjoint} + +Returns an `Op` that asserts this operator is self-adjoint. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.batch_shape` {#LinearOperatorScaledIdentity.batch_shape} + +`TensorShape` of batch dimensions of this `LinearOperator`. + +If this operator acts like the batch matrix `A` with +`A.shape = [B1,...,Bb, M, N]`, then this returns +`TensorShape([B1,...,Bb])`, equivalent to `A.get_shape()[:-2]` + +##### Returns: + + `TensorShape`, statically determined, may be undefined. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.batch_shape_dynamic(name='batch_shape_dynamic')` {#LinearOperatorScaledIdentity.batch_shape_dynamic} + +Shape of batch dimensions of this operator, determined at runtime. + +If this operator acts like the batch matrix `A` with +`A.shape = [B1,...,Bb, M, N]`, then this returns a `Tensor` holding +`[B1,...,Bb]`. + +##### Args: + + +* <b>`name`</b>: A name for this `Op. + +##### Returns: + + `int32` `Tensor` + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.determinant(name='det')` {#LinearOperatorScaledIdentity.determinant} + +Determinant for every batch member. + +##### Args: + + +* <b>`name`</b>: A name for this `Op. + +##### Returns: + + `Tensor` with shape `self.batch_shape` and same `dtype` as `self`. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.domain_dimension` {#LinearOperatorScaledIdentity.domain_dimension} + +Dimension (in the sense of vector spaces) of the domain of this operator. + +If this operator acts like the batch matrix `A` with +`A.shape = [B1,...,Bb, M, N]`, then this returns `N`. + +##### Returns: + + `Dimension` object. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.domain_dimension_dynamic(name='domain_dimension_dynamic')` {#LinearOperatorScaledIdentity.domain_dimension_dynamic} + +Dimension (in the sense of vector spaces) of the domain of this operator. + +Determined at runtime. + +If this operator acts like the batch matrix `A` with +`A.shape = [B1,...,Bb, M, N]`, then this returns `N`. + +##### Args: + + +* <b>`name`</b>: A name for this `Op`. + +##### Returns: + + `int32` `Tensor` + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.dtype` {#LinearOperatorScaledIdentity.dtype} + +The `DType` of `Tensor`s handled by this `LinearOperator`. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.graph_parents` {#LinearOperatorScaledIdentity.graph_parents} + +List of graph dependencies of this `LinearOperator`. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.is_non_singular` {#LinearOperatorScaledIdentity.is_non_singular} + + + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.is_positive_definite` {#LinearOperatorScaledIdentity.is_positive_definite} + + + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.is_self_adjoint` {#LinearOperatorScaledIdentity.is_self_adjoint} + + + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.log_abs_determinant(name='log_abs_det')` {#LinearOperatorScaledIdentity.log_abs_determinant} + +Log absolute value of determinant for every batch member. + +##### Args: + + +* <b>`name`</b>: A name for this `Op. + +##### Returns: + + `Tensor` with shape `self.batch_shape` and same `dtype` as `self`. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.multiplier` {#LinearOperatorScaledIdentity.multiplier} + +The [batch] scalar `Tensor`, `c` in `cI`. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.name` {#LinearOperatorScaledIdentity.name} + +Name prepended to all ops created by this `LinearOperator`. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.range_dimension` {#LinearOperatorScaledIdentity.range_dimension} + +Dimension (in the sense of vector spaces) of the range of this operator. + +If this operator acts like the batch matrix `A` with +`A.shape = [B1,...,Bb, M, N]`, then this returns `M`. + +##### Returns: + + `Dimension` object. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.range_dimension_dynamic(name='range_dimension_dynamic')` {#LinearOperatorScaledIdentity.range_dimension_dynamic} + +Dimension (in the sense of vector spaces) of the range of this operator. + +Determined at runtime. + +If this operator acts like the batch matrix `A` with +`A.shape = [B1,...,Bb, M, N]`, then this returns `M`. + +##### Args: + + +* <b>`name`</b>: A name for this `Op`. + +##### Returns: + + `int32` `Tensor` + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.shape` {#LinearOperatorScaledIdentity.shape} + +`TensorShape` of this `LinearOperator`. + +If this operator acts like the batch matrix `A` with +`A.shape = [B1,...,Bb, M, N]`, then this returns +`TensorShape([B1,...,Bb, M, N])`, equivalent to `A.get_shape()`. + +##### Returns: + + `TensorShape`, statically determined, may be undefined. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.shape_dynamic(name='shape_dynamic')` {#LinearOperatorScaledIdentity.shape_dynamic} + +Shape of this `LinearOperator`, determined at runtime. + +If this operator acts like the batch matrix `A` with +`A.shape = [B1,...,Bb, M, N]`, then this returns a `Tensor` holding +`[B1,...,Bb, M, N]`, equivalent to `tf.shape(A)`. + +##### Args: + + +* <b>`name`</b>: A name for this `Op. + +##### Returns: + + `int32` `Tensor` + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.solve(rhs, adjoint=False, name='solve')` {#LinearOperatorScaledIdentity.solve} + +Solve `R` (batch) systems of equations exactly: `A X = rhs`. + +Examples: + +```python +# Create an operator acting like a 10 x 2 x 2 matrix. +operator = LinearOperator(...) +operator.shape # = 10 x 2 x 2 + +# Solve one linear system (R = 1) for every member of the length 10 batch. +RHS = ... # shape 10 x 2 x 1 +X = operator.solve(RHS) # shape 10 x 2 x 1 + +# Solve five linear systems (R = 5) for every member of the length 10 batch. +RHS = ... # shape 10 x 2 x 5 +X = operator.solve(RHS) +X[3, :, 2] # Solution to the linear system A[3, :, :] X = RHS[3, :, 2] +``` + +##### Args: + + +* <b>`rhs`</b>: `Tensor` with same `dtype` as this operator and compatible shape. + See class docstring for definition of compatibility. +* <b>`adjoint`</b>: Python `bool`. If `True`, solve the system involving the adjoint + of this `LinearOperator`. +* <b>`name`</b>: A name scope to use for ops added by this method. + +##### Returns: + + `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`. + +##### Raises: + + +* <b>`ValueError`</b>: If self.is_non_singular is False. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.tensor_rank` {#LinearOperatorScaledIdentity.tensor_rank} + +Rank (in the sense of tensors) of matrix corresponding to this operator. + +If this operator acts like the batch matrix `A` with +`A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`. + +##### Args: + + +* <b>`name`</b>: A name for this `Op. + +##### Returns: + + Python integer, or None if the tensor rank is undefined. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.tensor_rank_dynamic(name='tensor_rank_dynamic')` {#LinearOperatorScaledIdentity.tensor_rank_dynamic} + +Rank (in the sense of tensors) of matrix corresponding to this operator. + +If this operator acts like the batch matrix `A` with +`A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`. + +##### Args: + + +* <b>`name`</b>: A name for this `Op. + +##### Returns: + + `int32` `Tensor`, determined at runtime. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.to_dense(name='to_dense')` {#LinearOperatorScaledIdentity.to_dense} + +Return a dense (batch) matrix representing this operator. + + + - - - ### `class tf.contrib.linalg.LinearOperatorMatrix` {#LinearOperatorMatrix} diff --git a/tensorflow/g3doc/api_docs/python/contrib.training.md b/tensorflow/g3doc/api_docs/python/contrib.training.md index 89ac2b5538f..3ab824a38e5 100644 --- a/tensorflow/g3doc/api_docs/python/contrib.training.md +++ b/tensorflow/g3doc/api_docs/python/contrib.training.md @@ -986,8 +986,10 @@ operations that depend on fixed batch_size would fail. * <b>`tensors`</b>: The list or dictionary of tensors, representing a single element, to bucket. Nested lists are not supported. * <b>`which_bucket`</b>: An `int32` scalar Tensor taking a value in `[0, num_buckets)`. -* <b>`batch_size`</b>: The new batch size pulled from the queue - (python int or int32 scalar). +* <b>`batch_size`</b>: The new batch size pulled from the queue (all queues will have + the same size). If a list is passed in then each bucket will have a + different batch_size. + (python int, int32 scalar or iterable of integers of length num_buckets). * <b>`num_buckets`</b>: A python integer, the number of buckets. * <b>`num_threads`</b>: An integer. The number of threads enqueuing `tensors`. * <b>`capacity`</b>: An integer. The maximum number of minibatches in the top queue, @@ -1019,7 +1021,8 @@ operations that depend on fixed batch_size would fail. * <b>`ValueError`</b>: If the `shapes` are not specified, and cannot be - inferred from the elements of `tensors`. + inferred from the elements of `tensors` or if batch_size is a sequence + but it's length != num_buckets. - - - @@ -1039,8 +1042,10 @@ bucket the given `input_length` belongs to. See the documentation for * <b>`input_length`</b>: `int32` scalar `Tensor`, the sequence length of tensors. * <b>`tensors`</b>: The list or dictionary of tensors, representing a single element, to bucket. Nested lists are not supported. -* <b>`batch_size`</b>: The new batch size pulled from the queue - (python int or int32 scalar). +* <b>`batch_size`</b>: The new batch size pulled from the queue (all queues will have + the same size). If a list is passed in then each bucket will have a + different batch_size. + (python int, int32 scalar or iterable of integers of length num_buckets). * <b>`bucket_boundaries`</b>: int list, increasing non-negative numbers. The edges of the buckets to use when bucketing tensors. Two extra buckets are created, one for `input_length < bucket_boundaries[0]` and @@ -1075,6 +1080,7 @@ bucket the given `input_length` belongs to. See the documentation for * <b>`TypeError`</b>: if `bucket_boundaries` is not a list of python integers. * <b>`ValueError`</b>: if `bucket_boundaries` is empty or contains non-increasing - values. + values or if batch_size is a list and it's length doesn't equal the number + of buckets. diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.depth_to_space.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.depth_to_space.md index ef74b4d54a4..03dc6bb3b0d 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.depth_to_space.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.depth_to_space.md @@ -10,7 +10,7 @@ The attr `block_size` indicates the input block size and how the data is moved. * Chunks of data of size `block_size * block_size` from depth are rearranged into non-overlapping blocks of size `block_size x block_size` - * The width the output tensor is `input_width * block_size`, whereas the + * The width the output tensor is `input_depth * block_size`, whereas the height is `input_height * block_size`. * The depth of the input tensor must be divisible by `block_size * block_size`. diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.summary.TaggedRunMetadata.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.summary.TaggedRunMetadata.md index 8dc62c4c18c..788d2066ad7 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.summary.TaggedRunMetadata.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.summary.TaggedRunMetadata.md @@ -1,185 +1,4 @@ -- - - - -#### `tf.summary.TaggedRunMetadata.ByteSize()` {#TaggedRunMetadata.ByteSize} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.Clear()` {#TaggedRunMetadata.Clear} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.ClearExtension(extension_handle)` {#TaggedRunMetadata.ClearExtension} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.ClearField(field_name)` {#TaggedRunMetadata.ClearField} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.CopyFrom(other_msg)` {#TaggedRunMetadata.CopyFrom} - -Copies the content of the specified message into the current message. - -The method clears the current message and then merges the specified -message using MergeFrom. - -##### Args: - - -* <b>`other_msg`</b>: Message to copy into the current one. - - -- - - - -#### `tf.summary.TaggedRunMetadata.DiscardUnknownFields()` {#TaggedRunMetadata.DiscardUnknownFields} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.FindInitializationErrors()` {#TaggedRunMetadata.FindInitializationErrors} - -Finds required fields which are not initialized. - -##### Returns: - - A list of strings. Each string is a path to an uninitialized field from - the top-level message, e.g. "foo.bar[5].baz". - - -- - - - -#### `tf.summary.TaggedRunMetadata.FromString(s)` {#TaggedRunMetadata.FromString} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.HasExtension(extension_handle)` {#TaggedRunMetadata.HasExtension} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.HasField(field_name)` {#TaggedRunMetadata.HasField} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.IsInitialized(errors=None)` {#TaggedRunMetadata.IsInitialized} - -Checks if all required fields of a message are set. - -##### Args: - - -* <b>`errors`</b>: A list which, if provided, will be populated with the field - paths of all missing required fields. - -##### Returns: - - True iff the specified message has all required fields set. - - -- - - - -#### `tf.summary.TaggedRunMetadata.ListFields()` {#TaggedRunMetadata.ListFields} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.MergeFrom(msg)` {#TaggedRunMetadata.MergeFrom} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.MergeFromString(serialized)` {#TaggedRunMetadata.MergeFromString} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.ParseFromString(serialized)` {#TaggedRunMetadata.ParseFromString} - -Parse serialized protocol buffer data into this message. - -Like MergeFromString(), except we clear the object first and -do not return the value that MergeFromString returns. - - -- - - - -#### `tf.summary.TaggedRunMetadata.RegisterExtension(extension_handle)` {#TaggedRunMetadata.RegisterExtension} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.SerializePartialToString()` {#TaggedRunMetadata.SerializePartialToString} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.SerializeToString()` {#TaggedRunMetadata.SerializeToString} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.SetInParent()` {#TaggedRunMetadata.SetInParent} - -Sets the _cached_byte_size_dirty bit to true, -and propagates this to our listener iff this was a state change. - - -- - - - -#### `tf.summary.TaggedRunMetadata.WhichOneof(oneof_name)` {#TaggedRunMetadata.WhichOneof} - -Returns the name of the currently set field inside a oneof, or None. - - -- - - - -#### `tf.summary.TaggedRunMetadata.__deepcopy__(memo=None)` {#TaggedRunMetadata.__deepcopy__} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.__eq__(other)` {#TaggedRunMetadata.__eq__} - - - - - - - #### `tf.summary.TaggedRunMetadata.__getstate__()` {#TaggedRunMetadata.__getstate__} @@ -187,66 +6,3 @@ Returns the name of the currently set field inside a oneof, or None. Support the pickle protocol. -- - - - -#### `tf.summary.TaggedRunMetadata.__hash__()` {#TaggedRunMetadata.__hash__} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.__init__(**kwargs)` {#TaggedRunMetadata.__init__} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.__ne__(other_msg)` {#TaggedRunMetadata.__ne__} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.__repr__()` {#TaggedRunMetadata.__repr__} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.__setstate__(state)` {#TaggedRunMetadata.__setstate__} - -Support the pickle protocol. - - -- - - - -#### `tf.summary.TaggedRunMetadata.__str__()` {#TaggedRunMetadata.__str__} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.__unicode__()` {#TaggedRunMetadata.__unicode__} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.run_metadata` {#TaggedRunMetadata.run_metadata} - -Magic attribute generated for "run_metadata" proto field. - - -- - - - -#### `tf.summary.TaggedRunMetadata.tag` {#TaggedRunMetadata.tag} - -Magic attribute generated for "tag" proto field. - - diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf_debug.DumpingDebugWrapperSession.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf_debug.DumpingDebugWrapperSession.md index e655a570507..f86f63d7d9e 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf_debug.DumpingDebugWrapperSession.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf_debug.DumpingDebugWrapperSession.md @@ -32,34 +32,16 @@ Constructor of DumpingDebugWrapperSession. `session_root`. The subdirectories' names has the following pattern: run_<epoch_time_stamp>_<uuid> E.g., run_1480734393835964_ad4c953a85444900ae79fc1b652fb324 -* <b>`watch_fn`</b>: (`Callable`) A Callable of the following signature: - ``` - def watch_fn(fetches, feeds): - # Args: - # fetches: the fetches to the `Session.run()` call. - # feeds: the feeds to the `Session.run()` call. - # - # Returns: (node_name_regex_whitelist, op_type_regex_whitelist) - # debug_ops: (str or list of str) Debug op(s) to be used by the - # debugger in this run() call. - # node_name_regex_whitelist: Regular-expression whitelist for node - # name. Same as the corresponding arg to `debug_util.watch_graph`. - # op_type_regex_whiteslit: Regular-expression whitelist for op type. - # Same as the corresponding arg to `debug_util.watch_graph`. - # - # Both or either can be None. If both are set, the two whitelists - # will operate in a logical AND relation. This is consistent with - # `debug_utils.watch_graph()`. - ``` +* <b>`watch_fn`</b>: (`Callable`) A Callable that can be used to define per-run + debug ops and watched tensors. See the doc of + `NonInteractiveDebugWrapperSession.__init__()` for details. * <b>`log_usage`</b>: (`bool`) whether the usage of this class is to be logged. ##### Raises: * <b>`ValueError`</b>: If `session_root` is an existing and non-empty directory or - if - `session_root` is a file. -* <b>`TypeError`</b>: If a non-None `watch_fn` is specified and it is not callable. + if `session_root` is a file. - - - diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.merge_all_summaries.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.merge_all_summaries.md deleted file mode 100644 index bf17320a5a3..00000000000 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.merge_all_summaries.md +++ /dev/null @@ -1,17 +0,0 @@ -### `tf.merge_all_summaries(*args, **kwargs)` {#merge_all_summaries} - -Merges all summaries collected in the default graph. (deprecated) - -THIS FUNCTION IS DEPRECATED. It will be removed after 2016-11-30. -Instructions for updating: -Please switch to tf.summary.merge_all. - - Args: - key: `GraphKey` used to collect the summaries. Defaults to - `GraphKeys.SUMMARIES`. - - Returns: - If no summaries were collected, returns None. Otherwise returns a scalar - `Tensor` of type `string` containing the serialized `Summary` protocol - buffer resulting from the merging. - diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.image_summary.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.image_summary.md deleted file mode 100644 index 6220d3641bc..00000000000 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.image_summary.md +++ /dev/null @@ -1,49 +0,0 @@ -### `tf.image_summary(*args, **kwargs)` {#image_summary} - -Outputs a `Summary` protocol buffer with images. (deprecated) - -THIS FUNCTION IS DEPRECATED. It will be removed after 2016-11-30. -Instructions for updating: -Please switch to tf.summary.image. Note that tf.summary.histogram uses the node name instead of the tag. This means that TensorFlow will automatically de-duplicate summary names based on the scope they are created in. Also, the max_images argument was renamed to max_outputs. - - The summary has up to `max_images` summary values containing images. The - images are built from `tensor` which must be 4-D with shape `[batch_size, - height, width, channels]` and where `channels` can be: - - * 1: `tensor` is interpreted as Grayscale. - * 3: `tensor` is interpreted as RGB. - * 4: `tensor` is interpreted as RGBA. - - The images have the same number of channels as the input tensor. For float - input, the values are normalized one image at a time to fit in the range - `[0, 255]`. `uint8` values are unchanged. The op uses two different - normalization algorithms: - - * If the input values are all positive, they are rescaled so the largest one - is 255. - - * If any input value is negative, the values are shifted so input value 0.0 - is at 127. They are then rescaled so that either the smallest value is 0, - or the largest one is 255. - - The `tag` argument is a scalar `Tensor` of type `string`. It is used to - build the `tag` of the summary values: - - * If `max_images` is 1, the summary value tag is '*tag*/image'. - * If `max_images` is greater than 1, the summary value tags are - generated sequentially as '*tag*/image/0', '*tag*/image/1', etc. - - Args: - tag: A scalar `Tensor` of type `string`. Used to build the `tag` - of the summary values. - tensor: A 4-D `uint8` or `float32` `Tensor` of shape `[batch_size, height, - width, channels]` where `channels` is 1, 3, or 4. - max_images: Max number of batch elements to generate images for. - collections: Optional list of ops.GraphKeys. The collections to add the - summary to. Defaults to [ops.GraphKeys.SUMMARIES] - name: A name for the operation (optional). - - Returns: - A scalar `Tensor` of type `string`. The serialized `Summary` protocol - buffer. - diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.summary.SummaryDescription.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.summary.SummaryDescription.md index bce704ef4f2..19532f7cc33 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.summary.SummaryDescription.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.summary.SummaryDescription.md @@ -1,185 +1,4 @@ -- - - - -#### `tf.summary.SummaryDescription.ByteSize()` {#SummaryDescription.ByteSize} - - - - -- - - - -#### `tf.summary.SummaryDescription.Clear()` {#SummaryDescription.Clear} - - - - -- - - - -#### `tf.summary.SummaryDescription.ClearExtension(extension_handle)` {#SummaryDescription.ClearExtension} - - - - -- - - - -#### `tf.summary.SummaryDescription.ClearField(field_name)` {#SummaryDescription.ClearField} - - - - -- - - - -#### `tf.summary.SummaryDescription.CopyFrom(other_msg)` {#SummaryDescription.CopyFrom} - -Copies the content of the specified message into the current message. - -The method clears the current message and then merges the specified -message using MergeFrom. - -##### Args: - - -* <b>`other_msg`</b>: Message to copy into the current one. - - -- - - - -#### `tf.summary.SummaryDescription.DiscardUnknownFields()` {#SummaryDescription.DiscardUnknownFields} - - - - -- - - - -#### `tf.summary.SummaryDescription.FindInitializationErrors()` {#SummaryDescription.FindInitializationErrors} - -Finds required fields which are not initialized. - -##### Returns: - - A list of strings. Each string is a path to an uninitialized field from - the top-level message, e.g. "foo.bar[5].baz". - - -- - - - -#### `tf.summary.SummaryDescription.FromString(s)` {#SummaryDescription.FromString} - - - - -- - - - -#### `tf.summary.SummaryDescription.HasExtension(extension_handle)` {#SummaryDescription.HasExtension} - - - - -- - - - -#### `tf.summary.SummaryDescription.HasField(field_name)` {#SummaryDescription.HasField} - - - - -- - - - -#### `tf.summary.SummaryDescription.IsInitialized(errors=None)` {#SummaryDescription.IsInitialized} - -Checks if all required fields of a message are set. - -##### Args: - - -* <b>`errors`</b>: A list which, if provided, will be populated with the field - paths of all missing required fields. - -##### Returns: - - True iff the specified message has all required fields set. - - -- - - - -#### `tf.summary.SummaryDescription.ListFields()` {#SummaryDescription.ListFields} - - - - -- - - - -#### `tf.summary.SummaryDescription.MergeFrom(msg)` {#SummaryDescription.MergeFrom} - - - - -- - - - -#### `tf.summary.SummaryDescription.MergeFromString(serialized)` {#SummaryDescription.MergeFromString} - - - - -- - - - -#### `tf.summary.SummaryDescription.ParseFromString(serialized)` {#SummaryDescription.ParseFromString} - -Parse serialized protocol buffer data into this message. - -Like MergeFromString(), except we clear the object first and -do not return the value that MergeFromString returns. - - -- - - - -#### `tf.summary.SummaryDescription.RegisterExtension(extension_handle)` {#SummaryDescription.RegisterExtension} - - - - -- - - - -#### `tf.summary.SummaryDescription.SerializePartialToString()` {#SummaryDescription.SerializePartialToString} - - - - -- - - - -#### `tf.summary.SummaryDescription.SerializeToString()` {#SummaryDescription.SerializeToString} - - - - -- - - - -#### `tf.summary.SummaryDescription.SetInParent()` {#SummaryDescription.SetInParent} - -Sets the _cached_byte_size_dirty bit to true, -and propagates this to our listener iff this was a state change. - - -- - - - -#### `tf.summary.SummaryDescription.WhichOneof(oneof_name)` {#SummaryDescription.WhichOneof} - -Returns the name of the currently set field inside a oneof, or None. - - -- - - - -#### `tf.summary.SummaryDescription.__deepcopy__(memo=None)` {#SummaryDescription.__deepcopy__} - - - - -- - - - -#### `tf.summary.SummaryDescription.__eq__(other)` {#SummaryDescription.__eq__} - - - - - - - #### `tf.summary.SummaryDescription.__getstate__()` {#SummaryDescription.__getstate__} @@ -187,59 +6,3 @@ Returns the name of the currently set field inside a oneof, or None. Support the pickle protocol. -- - - - -#### `tf.summary.SummaryDescription.__hash__()` {#SummaryDescription.__hash__} - - - - -- - - - -#### `tf.summary.SummaryDescription.__init__(**kwargs)` {#SummaryDescription.__init__} - - - - -- - - - -#### `tf.summary.SummaryDescription.__ne__(other_msg)` {#SummaryDescription.__ne__} - - - - -- - - - -#### `tf.summary.SummaryDescription.__repr__()` {#SummaryDescription.__repr__} - - - - -- - - - -#### `tf.summary.SummaryDescription.__setstate__(state)` {#SummaryDescription.__setstate__} - -Support the pickle protocol. - - -- - - - -#### `tf.summary.SummaryDescription.__str__()` {#SummaryDescription.__str__} - - - - -- - - - -#### `tf.summary.SummaryDescription.__unicode__()` {#SummaryDescription.__unicode__} - - - - -- - - - -#### `tf.summary.SummaryDescription.type_hint` {#SummaryDescription.type_hint} - -Magic attribute generated for "type_hint" proto field. - - diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.tables_initializer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.tables_initializer.md new file mode 100644 index 00000000000..f278bd57e69 --- /dev/null +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.tables_initializer.md @@ -0,0 +1,14 @@ +### `tf.tables_initializer(name='init_all_tables')` {#tables_initializer} + +Returns an Op that initializes all tables of the default graph. + +##### Args: + + +* <b>`name`</b>: Optional name for the initialization op. + +##### Returns: + + An Op that initializes all tables. Note that if there are + not tables the returned Op is a NoOp. + diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.test.TestCase.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.test.TestCase.md index 4d4330488f6..e9e8a2684ca 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.test.TestCase.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.test.TestCase.md @@ -173,6 +173,125 @@ Checks that for all elements of farray1 and farray2 * <b>`err`</b>: a float value. +- - - + +#### `tf.test.TestCase.assertBetween(value, minv, maxv, msg=None)` {#TestCase.assertBetween} + +Asserts that value is between minv and maxv (inclusive). + + +- - - + +#### `tf.test.TestCase.assertCommandFails(command, regexes, env=None, close_fds=True, msg=None)` {#TestCase.assertCommandFails} + +Asserts a shell command fails and the error matches a regex in a list. + +##### Args: + + +* <b>`command`</b>: List or string representing the command to run. +* <b>`regexes`</b>: the list of regular expression strings. +* <b>`env`</b>: Dictionary of environment variable settings. +* <b>`close_fds`</b>: Whether or not to close all open fd's in the child after + forking. +* <b>`msg`</b>: Optional message to report on failure. + + +- - - + +#### `tf.test.TestCase.assertCommandSucceeds(command, regexes=('',), env=None, close_fds=True, msg=None)` {#TestCase.assertCommandSucceeds} + +Asserts that a shell command succeeds (i.e. exits with code 0). + +##### Args: + + +* <b>`command`</b>: List or string representing the command to run. +* <b>`regexes`</b>: List of regular expression byte strings that match success. +* <b>`env`</b>: Dictionary of environment variable settings. +* <b>`close_fds`</b>: Whether or not to close all open fd's in the child after + forking. +* <b>`msg`</b>: Optional message to report on failure. + + +- - - + +#### `tf.test.TestCase.assertContainsExactSubsequence(container, subsequence, msg=None)` {#TestCase.assertContainsExactSubsequence} + +Assert that "container" contains "subsequence" as an exact subsequence. + +Asserts that "container" contains all the elements of "subsequence", in +order, and without other elements interspersed. For example, [1, 2, 3] is an +exact subsequence of [0, 0, 1, 2, 3, 0] but not of [0, 0, 1, 2, 0, 3, 0]. + +##### Args: + + +* <b>`container`</b>: the list we're testing for subsequence inclusion. +* <b>`subsequence`</b>: the list we hope will be an exact subsequence of container. +* <b>`msg`</b>: Optional message to report on failure. + + +- - - + +#### `tf.test.TestCase.assertContainsInOrder(strings, target, msg=None)` {#TestCase.assertContainsInOrder} + +Asserts that the strings provided are found in the target in order. + +This may be useful for checking HTML output. + +##### Args: + + +* <b>`strings`</b>: A list of strings, such as [ 'fox', 'dog' ] +* <b>`target`</b>: A target string in which to look for the strings, such as + 'The quick brown fox jumped over the lazy dog'. +* <b>`msg`</b>: Optional message to report on failure. + + +- - - + +#### `tf.test.TestCase.assertContainsSubsequence(container, subsequence, msg=None)` {#TestCase.assertContainsSubsequence} + +Assert that "container" contains "subsequence" as a subsequence. + +Asserts that "container" contains all the elements of "subsequence", in +order, but possibly with other elements interspersed. For example, [1, 2, 3] +is a subsequence of [0, 0, 1, 2, 0, 3, 0] but not of [0, 0, 1, 3, 0, 2, 0]. + +##### Args: + + +* <b>`container`</b>: the list we're testing for subsequence inclusion. +* <b>`subsequence`</b>: the list we hope will be a subsequence of container. +* <b>`msg`</b>: Optional message to report on failure. + + +- - - + +#### `tf.test.TestCase.assertContainsSubset(expected_subset, actual_set, msg=None)` {#TestCase.assertContainsSubset} + +Checks whether actual iterable is a superset of expected iterable. + + +- - - + +#### `tf.test.TestCase.assertCountEqual(*args, **kwargs)` {#TestCase.assertCountEqual} + +An unordered sequence specific comparison. + +Equivalent to assertItemsEqual(). This method is a compatibility layer +for Python 3k, since 2to3 does not convert assertItemsEqual() calls into +assertCountEqual() calls. + +##### Args: + + +* <b>`expected_seq`</b>: A sequence containing elements we are expecting. +* <b>`actual_seq`</b>: The sequence that we are testing. +* <b>`msg`</b>: The message to be printed if the test fails. + + - - - #### `tf.test.TestCase.assertDeviceEqual(device1, device2)` {#TestCase.assertDeviceEqual} @@ -195,9 +314,48 @@ Checks whether actual is a superset of expected. - - - -#### `tf.test.TestCase.assertDictEqual(d1, d2, msg=None)` {#TestCase.assertDictEqual} +#### `tf.test.TestCase.assertDictEqual(a, b, msg=None)` {#TestCase.assertDictEqual} + +Raises AssertionError if a and b are not equal dictionaries. + +##### Args: +* <b>`a`</b>: A dict, the expected value. +* <b>`b`</b>: A dict, the actual value. +* <b>`msg`</b>: An optional str, the associated message. + +##### Raises: + + +* <b>`AssertionError`</b>: if the dictionaries are not equal. + + +- - - + +#### `tf.test.TestCase.assertEmpty(container, msg=None)` {#TestCase.assertEmpty} + +Assert that an object has zero length. + +##### Args: + + +* <b>`container`</b>: Anything that implements the collections.Sized interface. +* <b>`msg`</b>: Optional message to report on failure. + + +- - - + +#### `tf.test.TestCase.assertEndsWith(actual, expected_end, msg=None)` {#TestCase.assertEndsWith} + +Assert that actual.endswith(expected_end) is True. + +##### Args: + + +* <b>`actual`</b>: str +* <b>`expected_end`</b>: str +* <b>`msg`</b>: Optional message to report on failure. - - - @@ -282,10 +440,11 @@ Included for symmetry with assertIsNone. - - - -#### `tf.test.TestCase.assertItemsEqual(expected_seq, actual_seq, msg=None)` {#TestCase.assertItemsEqual} +#### `tf.test.TestCase.assertItemsEqual(*args, **kwargs)` {#TestCase.assertItemsEqual} -An unordered sequence specific comparison. It asserts that -actual_seq and expected_seq have the same element counts. +An unordered sequence specific comparison. + +It asserts that actual_seq and expected_seq have the same element counts. Equivalent to:: self.assertEqual(Counter(iter(actual_seq)), @@ -298,6 +457,30 @@ Asserts that each element has the same count in both sequences. - [0, 1, 1] and [1, 0, 1] compare equal. - [0, 0, 1] and [0, 1] compare unequal. +##### Args: + + +* <b>`expected_seq`</b>: A sequence containing elements we are expecting. +* <b>`actual_seq`</b>: The sequence that we are testing. +* <b>`msg`</b>: The message to be printed if the test fails. + + +- - - + +#### `tf.test.TestCase.assertJsonEqual(first, second, msg=None)` {#TestCase.assertJsonEqual} + +Asserts that the JSON objects defined in two strings are equal. + +A summary of the differences will be included in the failure message +using assertSameStructure. + +##### Args: + + +* <b>`first`</b>: A string contining JSON to decode and compare to second. +* <b>`second`</b>: A string contining JSON to decode and compare to first. +* <b>`msg`</b>: Additional text to include in the failure message. + - - - @@ -367,6 +550,13 @@ if not. * <b>`msg`</b>: An optional string message to append to the failure message. +- - - + +#### `tf.test.TestCase.assertNoCommonElements(expected_seq, actual_seq, msg=None)` {#TestCase.assertNoCommonElements} + +Checks whether actual iterable and expected iterable are disjoint. + + - - - #### `tf.test.TestCase.assertNotAlmostEqual(first, second, places=None, msg=None, delta=None)` {#TestCase.assertNotAlmostEqual} @@ -397,6 +587,33 @@ as significant digits (measured from the most signficant digit). Objects that are equal automatically fail. +- - - + +#### `tf.test.TestCase.assertNotEmpty(container, msg=None)` {#TestCase.assertNotEmpty} + +Assert that an object has non-zero length. + +##### Args: + + +* <b>`container`</b>: Anything that implements the collections.Sized interface. +* <b>`msg`</b>: Optional message to report on failure. + + +- - - + +#### `tf.test.TestCase.assertNotEndsWith(actual, unexpected_end, msg=None)` {#TestCase.assertNotEndsWith} + +Assert that actual.endswith(unexpected_end) is False. + +##### Args: + + +* <b>`actual`</b>: str +* <b>`unexpected_end`</b>: str +* <b>`msg`</b>: Optional message to report on failure. + + - - - #### `tf.test.TestCase.assertNotEqual(first, second, msg=None)` {#TestCase.assertNotEqual} @@ -434,6 +651,20 @@ Included for symmetry with assertIsInstance. Fail the test if the text matches the regular expression. +- - - + +#### `tf.test.TestCase.assertNotStartsWith(actual, unexpected_start, msg=None)` {#TestCase.assertNotStartsWith} + +Assert that actual.startswith(unexpected_start) is False. + +##### Args: + + +* <b>`actual`</b>: str +* <b>`unexpected_start`</b>: str +* <b>`msg`</b>: Optional message to report on failure. + + - - - #### `tf.test.TestCase.assertProtoEquals(expected_message_maybe_ascii, message)` {#TestCase.assertProtoEquals} @@ -508,6 +739,38 @@ Asserts that the message in a raised exception matches a regexp. * <b>`kwargs`</b>: Extra kwargs. +- - - + +#### `tf.test.TestCase.assertRaisesWithLiteralMatch(expected_exception, expected_exception_message, callable_obj=None, *args, **kwargs)` {#TestCase.assertRaisesWithLiteralMatch} + +Asserts that the message in a raised exception equals the given string. + +Unlike assertRaisesRegexp, this method takes a literal string, not +a regular expression. + +with self.assertRaisesWithLiteralMatch(ExType, 'message'): + DoSomething() + +##### Args: + + +* <b>`expected_exception`</b>: Exception class expected to be raised. +* <b>`expected_exception_message`</b>: String message expected in the raised + exception. For a raise exception e, expected_exception_message must + equal str(e). +* <b>`callable_obj`</b>: Function to be called, or None to return a context. +* <b>`args`</b>: Extra args. +* <b>`kwargs`</b>: Extra kwargs. + +##### Returns: + + A context manager if callable_obj is None. Otherwise, None. + +##### Raises: + + self.failureException if callable_obj does not raise a macthing exception. + + - - - #### `tf.test.TestCase.assertRaisesWithPredicateMatch(exception_type, expected_err_re_or_predicate)` {#TestCase.assertRaisesWithPredicateMatch} @@ -532,6 +795,71 @@ predicate search. exception. +- - - + +#### `tf.test.TestCase.assertRaisesWithRegexpMatch(expected_exception, expected_regexp, callable_obj=None, *args, **kwargs)` {#TestCase.assertRaisesWithRegexpMatch} + +Asserts that the message in a raised exception matches the given regexp. + +This is just a wrapper around assertRaisesRegexp. Please use +assertRaisesRegexp instead of assertRaisesWithRegexpMatch. + +##### Args: + + +* <b>`expected_exception`</b>: Exception class expected to be raised. +* <b>`expected_regexp`</b>: Regexp (re pattern object or string) expected to be + found in error message. +* <b>`callable_obj`</b>: Function to be called, or None to return a context. +* <b>`args`</b>: Extra args. +* <b>`kwargs`</b>: Extra keyword args. + +##### Returns: + + A context manager if callable_obj is None. Otherwise, None. + +##### Raises: + + self.failureException if callable_obj does not raise a macthing exception. + + +- - - + +#### `tf.test.TestCase.assertRegexMatch(actual_str, regexes, message=None)` {#TestCase.assertRegexMatch} + +Asserts that at least one regex in regexes matches str. + + If possible you should use assertRegexpMatches, which is a simpler + version of this method. assertRegexpMatches takes a single regular + expression (a string or re compiled object) instead of a list. + + Notes: + 1. This function uses substring matching, i.e. the matching + succeeds if *any* substring of the error message matches *any* + regex in the list. This is more convenient for the user than + full-string matching. + + 2. If regexes is the empty list, the matching will always fail. + + 3. Use regexes=[''] for a regex that will always pass. + + 4. '.' matches any single character *except* the newline. To + match any character, use '(.| +)'. + + 5. '^' matches the beginning of each line, not just the beginning + of the string. Similarly, '$' matches the end of each line. + + 6. An exception will be thrown if regexes contains an invalid + regex. + + Args: + actual_str: The string we try to match with the items in regexes. + regexes: The regular expressions we want to match against str. + See "Notes" above for detailed notes on how this is interpreted. + message: The message to be printed if the test fails. + + - - - #### `tf.test.TestCase.assertRegexpMatches(text, expected_regexp, msg=None)` {#TestCase.assertRegexpMatches} @@ -539,6 +867,79 @@ predicate search. Fail the test unless the text matches the regular expression. +- - - + +#### `tf.test.TestCase.assertSameElements(expected_seq, actual_seq, msg=None)` {#TestCase.assertSameElements} + +Assert that two sequences have the same elements (in any order). + +This method, unlike assertItemsEqual, doesn't care about any +duplicates in the expected and actual sequences. + + >> assertSameElements([1, 1, 1, 0, 0, 0], [0, 1]) + # Doesn't raise an AssertionError + +If possible, you should use assertItemsEqual instead of +assertSameElements. + +##### Args: + + +* <b>`expected_seq`</b>: A sequence containing elements we are expecting. +* <b>`actual_seq`</b>: The sequence that we are testing. +* <b>`msg`</b>: The message to be printed if the test fails. + + +- - - + +#### `tf.test.TestCase.assertSameStructure(a, b, aname='a', bname='b', msg=None)` {#TestCase.assertSameStructure} + +Asserts that two values contain the same structural content. + +The two arguments should be data trees consisting of trees of dicts and +lists. They will be deeply compared by walking into the contents of dicts +and lists; other items will be compared using the == operator. +If the two structures differ in content, the failure message will indicate +the location within the structures where the first difference is found. +This may be helpful when comparing large structures. + +##### Args: + + +* <b>`a`</b>: The first structure to compare. +* <b>`b`</b>: The second structure to compare. +* <b>`aname`</b>: Variable name to use for the first structure in assertion messages. +* <b>`bname`</b>: Variable name to use for the second structure. +* <b>`msg`</b>: Additional text to include in the failure message. + + +- - - + +#### `tf.test.TestCase.assertSequenceAlmostEqual(expected_seq, actual_seq, places=None, msg=None, delta=None)` {#TestCase.assertSequenceAlmostEqual} + +An approximate equality assertion for ordered sequences. + +Fail if the two sequences are unequal as determined by their value +differences rounded to the given number of decimal places (default 7) and +comparing to zero, or by comparing that the difference between each value +in the two sequences is more than the given delta. + +Note that decimal places (from zero) are usually not the same as significant +digits (measured from the most signficant digit). + +If the two sequences compare equal then they will automatically compare +almost equal. + +##### Args: + + +* <b>`expected_seq`</b>: A sequence containing elements we are expecting. +* <b>`actual_seq`</b>: The sequence that we are testing. +* <b>`places`</b>: The number of decimal places to compare. +* <b>`msg`</b>: The message to be printed if the test fails. +* <b>`delta`</b>: The OK difference between compared values. + + - - - #### `tf.test.TestCase.assertSequenceEqual(seq1, seq2, msg=None, seq_type=None)` {#TestCase.assertSequenceEqual} @@ -559,6 +960,26 @@ which can be indexed, has a length, and has an equality operator. differences. +- - - + +#### `tf.test.TestCase.assertSequenceStartsWith(prefix, whole, msg=None)` {#TestCase.assertSequenceStartsWith} + +An equality assertion for the beginning of ordered sequences. + +If prefix is an empty sequence, it will raise an error unless whole is also +an empty sequence. + +If prefix is not a sequence, it will raise an error if the first element of +whole does not match. + +##### Args: + + +* <b>`prefix`</b>: A sequence expected at the beginning of the whole parameter. +* <b>`whole`</b>: The sequence in which to look for prefix. +* <b>`msg`</b>: Optional message to report on failure. + + - - - #### `tf.test.TestCase.assertSetEqual(set1, set2, msg=None)` {#TestCase.assertSetEqual} @@ -610,6 +1031,51 @@ Assert that actual.startswith(expected_start) is True. * <b>`msg`</b>: Optional message to report on failure. +- - - + +#### `tf.test.TestCase.assertTotallyOrdered(*groups, **kwargs)` {#TestCase.assertTotallyOrdered} + +Asserts that total ordering has been implemented correctly. + +For example, say you have a class A that compares only on its attribute x. +Comparators other than __lt__ are omitted for brevity. + +class A(object): + def __init__(self, x, y): + self.x = x + self.y = y + + def __hash__(self): + return hash(self.x) + + def __lt__(self, other): + try: + return self.x < other.x + except AttributeError: + return NotImplemented + +assertTotallyOrdered will check that instances can be ordered correctly. +For example, + +self.assertTotallyOrdered( + [None], # None should come before everything else. + [1], # Integers sort earlier. + [A(1, 'a')], + [A(2, 'b')], # 2 is after 1. + [A(3, 'c'), A(3, 'd')], # The second argument is irrelevant. + [A(4, 'z')], + ['foo']) # Strings sort last. + +##### Args: + + +* <b>`*groups`</b>: A list of groups of elements. Each group of elements is a list + of objects that are equal. The elements in each group must be less than + the elements in the group after it. For example, these groups are + totally ordered: [None], [1], [2, 2], [3]. +* <b>`**kwargs`</b>: optional msg keyword argument can be passed. + + - - - #### `tf.test.TestCase.assertTrue(expr, msg=None)` {#TestCase.assertTrue} @@ -632,6 +1098,13 @@ A tuple-specific equality assertion. differences. +- - - + +#### `tf.test.TestCase.assertUrlEqual(a, b, msg=None)` {#TestCase.assertUrlEqual} + +Asserts that urls are equal, ignoring ordering of query params. + + - - - #### `tf.test.TestCase.assert_(expr, msg=None)` {#TestCase.assert_} @@ -693,9 +1166,9 @@ tearDown. - - - -#### `tf.test.TestCase.fail(msg=None)` {#TestCase.fail} +#### `tf.test.TestCase.fail(msg=None, prefix=None)` {#TestCase.fail} -Fail immediately, with the given message. +Fail immediately with the given message, optionally prefixed. - - - @@ -747,6 +1220,13 @@ Fail immediately, with the given message. +- - - + +#### `tf.test.TestCase.getRecordedProperties()` {#TestCase.getRecordedProperties} + +Return any properties that the user has recorded. + + - - - #### `tf.test.TestCase.get_temp_dir()` {#TestCase.get_temp_dir} @@ -769,6 +1249,20 @@ pollute each others environment. +- - - + +#### `tf.test.TestCase.recordProperty(property_name, property_value)` {#TestCase.recordProperty} + +Record an arbitrary property for later use. + +##### Args: + + +* <b>`property_name`</b>: str, name of property to record; must be a valid XML + attribute name +* <b>`property_value`</b>: value of property; must be valid XML attribute value + + - - - #### `tf.test.TestCase.run(result=None)` {#TestCase.run} @@ -794,11 +1288,18 @@ Hook method for setting up class fixture before running tests in the class. #### `tf.test.TestCase.shortDescription()` {#TestCase.shortDescription} -Returns a one-line description of the test, or None if no -description has been provided. +Format both the test method name and the first line of its docstring. -The default implementation of this method returns the first line of -the specified test method's docstring. +If no docstring is given, only returns the method name. + +This method overrides unittest.TestCase.shortDescription(), which +only returns the first line of the docstring, obscuring the name +of the test upon failure. + +##### Returns: + + +* <b>`desc`</b>: A short description of a test method. - - - diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.linalg.LinearOperatorScaledIdentity.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.linalg.LinearOperatorScaledIdentity.md new file mode 100644 index 00000000000..9cef244fe44 --- /dev/null +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.linalg.LinearOperatorScaledIdentity.md @@ -0,0 +1,493 @@ +`LinearOperator` acting like a scaled [batch] identity matrix `A = c I`. + +This operator acts like a scaled [batch] identity matrix `A` with shape +`[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a +batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is +a scaled version of the `N x N` identity matrix. + +`LinearOperatorIdentity` is initialized with `num_rows`, and a `multiplier` +(a `Tensor`) of shape `[B1,...,Bb]`. `N` is set to `num_rows`, and the +`multiplier` determines the scale for each batch member. + +```python +# Create a 2 x 2 scaled identity matrix. +operator = LinearOperatorIdentity(num_rows=2, multiplier=3.) + +operator.to_dense() +==> [[3., 0.] + [0., 3.]] + +operator.shape +==> [2, 2] + +operator.log_determinant() +==> 2 * Log[3] + +x = ... Shape [2, 4] Tensor +operator.apply(x) +==> 3 * x + +y = tf.random_normal(shape=[3, 2, 4]) +# Note that y.shape is compatible with operator.shape because operator.shape +# is broadcast to [3, 2, 2]. +x = operator.solve(y) +==> 3 * x + +# Create a 2-batch of 2x2 identity matrices +operator = LinearOperatorIdentity(num_rows=2, multiplier=5.) +operator.to_dense() +==> [[[5., 0.] + [0., 5.]], + [[5., 0.] + [0., 5.]]] + +x = ... Shape [2, 2, 3] +operator.apply(x) +==> 5 * x + +# Here the operator and x have different batch_shape, and are broadcast. +x = ... Shape [1, 2, 3] +operator.apply(x) +==> 5 * x +``` + +### Shape compatibility + +This operator acts on [batch] matrix with compatible shape. +`x` is a batch matrix with compatible shape for `apply` and `solve` if + +``` +operator.shape = [B1,...,Bb] + [N, N], with b >= 0 +x.shape = [C1,...,Cc] + [N, R], +and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd] +``` + +### Performance + +* `operator.apply(x)` is `O(D1*...*Dd*N*R)` +* `operator.solve(x)` is `O(D1*...*Dd*N*R)` +* `operator.determinant()` is `O(D1*...*Dd)` + +#### Matrix property hints + +This `LinearOperator` is initialized with boolean flags of the form `is_X`, +for `X = non_singular, self_adjoint, positive_definite`. +These have the following meaning +* If `is_X == True`, callers should expect the operator to have the + property `X`. This is a promise that should be fulfilled, but is *not* a + runtime assert. For example, finite floating point precision may result + in these promises being violated. +* If `is_X == False`, callers should expect the operator to not have `X`. +* If `is_X == None` (the default), callers should have no expectation either + way. +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.__init__(num_rows, multiplier, is_non_singular=None, is_self_adjoint=None, is_positive_definite=None, assert_proper_shapes=False, name='LinearOperatorScaledIdentity')` {#LinearOperatorScaledIdentity.__init__} + +Initialize a `LinearOperatorScaledIdentity`. + +The `LinearOperatorScaledIdentity` is initialized with `num_rows`, which +determines the size of each identity matrix, and a `multiplier`, +which defines `dtype`, batch shape, and scale of each matrix. + +This operator is able to broadcast the leading (batch) dimensions. + +##### Args: + + +* <b>`num_rows`</b>: Scalar non-negative integer `Tensor`. Number of rows in the + corresponding identity matrix. +* <b>`multiplier`</b>: `Tensor` of shape `[B1,...,Bb]`, or `[]` (a scalar). +* <b>`is_non_singular`</b>: Expect that this operator is non-singular. +* <b>`is_self_adjoint`</b>: Expect that this operator is equal to its hermitian + transpose. +* <b>`is_positive_definite`</b>: Expect that this operator is positive definite. +* <b>`assert_proper_shapes`</b>: Python `bool`. If `False`, only perform static + checks that initialization and method arguments have proper shape. + If `True`, and static checks are inconclusive, add asserts to the graph. +* <b>`name`</b>: A name for this `LinearOperator` + +##### Raises: + + +* <b>`ValueError`</b>: If `num_rows` is determined statically to be non-scalar, or + negative. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.add_to_tensor(mat, name='add_to_tensor')` {#LinearOperatorScaledIdentity.add_to_tensor} + +Add matrix represented by this operator to `mat`. Equiv to `I + mat`. + +##### Args: + + +* <b>`mat`</b>: `Tensor` with same `dtype` and shape broadcastable to `self`. +* <b>`name`</b>: A name to give this `Op`. + +##### Returns: + + A `Tensor` with broadcast shape and same `dtype` as `self`. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.apply(x, adjoint=False, name='apply')` {#LinearOperatorScaledIdentity.apply} + +Transform `x` with left multiplication: `x --> Ax`. + +##### Args: + + +* <b>`x`</b>: `Tensor` with compatible shape and same `dtype` as `self`. + See class docstring for definition of compatibility. +* <b>`adjoint`</b>: Python `bool`. If `True`, left multiply by the adjoint. +* <b>`name`</b>: A name for this `Op. + +##### Returns: + + A `Tensor` with shape `[..., M, R]` and same `dtype` as `self`. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.assert_non_singular(name='assert_non_singular')` {#LinearOperatorScaledIdentity.assert_non_singular} + +Returns an `Op` that asserts this operator is non singular. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.assert_positive_definite(name='assert_positive_definite')` {#LinearOperatorScaledIdentity.assert_positive_definite} + +Returns an `Op` that asserts this operator is positive definite. + +Here, positive definite means the real part of all eigenvalues is positive. +We do not require the operator to be self-adjoint. + +##### Args: + + +* <b>`name`</b>: A name to give this `Op`. + +##### Returns: + + An `Op` that asserts this operator is positive definite. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.assert_self_adjoint(name='assert_self_adjoint')` {#LinearOperatorScaledIdentity.assert_self_adjoint} + +Returns an `Op` that asserts this operator is self-adjoint. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.batch_shape` {#LinearOperatorScaledIdentity.batch_shape} + +`TensorShape` of batch dimensions of this `LinearOperator`. + +If this operator acts like the batch matrix `A` with +`A.shape = [B1,...,Bb, M, N]`, then this returns +`TensorShape([B1,...,Bb])`, equivalent to `A.get_shape()[:-2]` + +##### Returns: + + `TensorShape`, statically determined, may be undefined. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.batch_shape_dynamic(name='batch_shape_dynamic')` {#LinearOperatorScaledIdentity.batch_shape_dynamic} + +Shape of batch dimensions of this operator, determined at runtime. + +If this operator acts like the batch matrix `A` with +`A.shape = [B1,...,Bb, M, N]`, then this returns a `Tensor` holding +`[B1,...,Bb]`. + +##### Args: + + +* <b>`name`</b>: A name for this `Op. + +##### Returns: + + `int32` `Tensor` + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.determinant(name='det')` {#LinearOperatorScaledIdentity.determinant} + +Determinant for every batch member. + +##### Args: + + +* <b>`name`</b>: A name for this `Op. + +##### Returns: + + `Tensor` with shape `self.batch_shape` and same `dtype` as `self`. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.domain_dimension` {#LinearOperatorScaledIdentity.domain_dimension} + +Dimension (in the sense of vector spaces) of the domain of this operator. + +If this operator acts like the batch matrix `A` with +`A.shape = [B1,...,Bb, M, N]`, then this returns `N`. + +##### Returns: + + `Dimension` object. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.domain_dimension_dynamic(name='domain_dimension_dynamic')` {#LinearOperatorScaledIdentity.domain_dimension_dynamic} + +Dimension (in the sense of vector spaces) of the domain of this operator. + +Determined at runtime. + +If this operator acts like the batch matrix `A` with +`A.shape = [B1,...,Bb, M, N]`, then this returns `N`. + +##### Args: + + +* <b>`name`</b>: A name for this `Op`. + +##### Returns: + + `int32` `Tensor` + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.dtype` {#LinearOperatorScaledIdentity.dtype} + +The `DType` of `Tensor`s handled by this `LinearOperator`. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.graph_parents` {#LinearOperatorScaledIdentity.graph_parents} + +List of graph dependencies of this `LinearOperator`. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.is_non_singular` {#LinearOperatorScaledIdentity.is_non_singular} + + + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.is_positive_definite` {#LinearOperatorScaledIdentity.is_positive_definite} + + + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.is_self_adjoint` {#LinearOperatorScaledIdentity.is_self_adjoint} + + + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.log_abs_determinant(name='log_abs_det')` {#LinearOperatorScaledIdentity.log_abs_determinant} + +Log absolute value of determinant for every batch member. + +##### Args: + + +* <b>`name`</b>: A name for this `Op. + +##### Returns: + + `Tensor` with shape `self.batch_shape` and same `dtype` as `self`. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.multiplier` {#LinearOperatorScaledIdentity.multiplier} + +The [batch] scalar `Tensor`, `c` in `cI`. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.name` {#LinearOperatorScaledIdentity.name} + +Name prepended to all ops created by this `LinearOperator`. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.range_dimension` {#LinearOperatorScaledIdentity.range_dimension} + +Dimension (in the sense of vector spaces) of the range of this operator. + +If this operator acts like the batch matrix `A` with +`A.shape = [B1,...,Bb, M, N]`, then this returns `M`. + +##### Returns: + + `Dimension` object. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.range_dimension_dynamic(name='range_dimension_dynamic')` {#LinearOperatorScaledIdentity.range_dimension_dynamic} + +Dimension (in the sense of vector spaces) of the range of this operator. + +Determined at runtime. + +If this operator acts like the batch matrix `A` with +`A.shape = [B1,...,Bb, M, N]`, then this returns `M`. + +##### Args: + + +* <b>`name`</b>: A name for this `Op`. + +##### Returns: + + `int32` `Tensor` + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.shape` {#LinearOperatorScaledIdentity.shape} + +`TensorShape` of this `LinearOperator`. + +If this operator acts like the batch matrix `A` with +`A.shape = [B1,...,Bb, M, N]`, then this returns +`TensorShape([B1,...,Bb, M, N])`, equivalent to `A.get_shape()`. + +##### Returns: + + `TensorShape`, statically determined, may be undefined. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.shape_dynamic(name='shape_dynamic')` {#LinearOperatorScaledIdentity.shape_dynamic} + +Shape of this `LinearOperator`, determined at runtime. + +If this operator acts like the batch matrix `A` with +`A.shape = [B1,...,Bb, M, N]`, then this returns a `Tensor` holding +`[B1,...,Bb, M, N]`, equivalent to `tf.shape(A)`. + +##### Args: + + +* <b>`name`</b>: A name for this `Op. + +##### Returns: + + `int32` `Tensor` + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.solve(rhs, adjoint=False, name='solve')` {#LinearOperatorScaledIdentity.solve} + +Solve `R` (batch) systems of equations exactly: `A X = rhs`. + +Examples: + +```python +# Create an operator acting like a 10 x 2 x 2 matrix. +operator = LinearOperator(...) +operator.shape # = 10 x 2 x 2 + +# Solve one linear system (R = 1) for every member of the length 10 batch. +RHS = ... # shape 10 x 2 x 1 +X = operator.solve(RHS) # shape 10 x 2 x 1 + +# Solve five linear systems (R = 5) for every member of the length 10 batch. +RHS = ... # shape 10 x 2 x 5 +X = operator.solve(RHS) +X[3, :, 2] # Solution to the linear system A[3, :, :] X = RHS[3, :, 2] +``` + +##### Args: + + +* <b>`rhs`</b>: `Tensor` with same `dtype` as this operator and compatible shape. + See class docstring for definition of compatibility. +* <b>`adjoint`</b>: Python `bool`. If `True`, solve the system involving the adjoint + of this `LinearOperator`. +* <b>`name`</b>: A name scope to use for ops added by this method. + +##### Returns: + + `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`. + +##### Raises: + + +* <b>`ValueError`</b>: If self.is_non_singular is False. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.tensor_rank` {#LinearOperatorScaledIdentity.tensor_rank} + +Rank (in the sense of tensors) of matrix corresponding to this operator. + +If this operator acts like the batch matrix `A` with +`A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`. + +##### Args: + + +* <b>`name`</b>: A name for this `Op. + +##### Returns: + + Python integer, or None if the tensor rank is undefined. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.tensor_rank_dynamic(name='tensor_rank_dynamic')` {#LinearOperatorScaledIdentity.tensor_rank_dynamic} + +Rank (in the sense of tensors) of matrix corresponding to this operator. + +If this operator acts like the batch matrix `A` with +`A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`. + +##### Args: + + +* <b>`name`</b>: A name for this `Op. + +##### Returns: + + `int32` `Tensor`, determined at runtime. + + +- - - + +#### `tf.contrib.linalg.LinearOperatorScaledIdentity.to_dense(name='to_dense')` {#LinearOperatorScaledIdentity.to_dense} + +Return a dense (batch) matrix representing this operator. + + diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.training.bucket.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.training.bucket.md index 8ddb64eac22..2eb7e72705e 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.training.bucket.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.training.bucket.md @@ -47,8 +47,10 @@ operations that depend on fixed batch_size would fail. * <b>`tensors`</b>: The list or dictionary of tensors, representing a single element, to bucket. Nested lists are not supported. * <b>`which_bucket`</b>: An `int32` scalar Tensor taking a value in `[0, num_buckets)`. -* <b>`batch_size`</b>: The new batch size pulled from the queue - (python int or int32 scalar). +* <b>`batch_size`</b>: The new batch size pulled from the queue (all queues will have + the same size). If a list is passed in then each bucket will have a + different batch_size. + (python int, int32 scalar or iterable of integers of length num_buckets). * <b>`num_buckets`</b>: A python integer, the number of buckets. * <b>`num_threads`</b>: An integer. The number of threads enqueuing `tensors`. * <b>`capacity`</b>: An integer. The maximum number of minibatches in the top queue, @@ -80,5 +82,6 @@ operations that depend on fixed batch_size would fail. * <b>`ValueError`</b>: If the `shapes` are not specified, and cannot be - inferred from the elements of `tensors`. + inferred from the elements of `tensors` or if batch_size is a sequence + but it's length != num_buckets. diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.scalar_summary.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.scalar_summary.md deleted file mode 100644 index 3ffd9260c7b..00000000000 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.scalar_summary.md +++ /dev/null @@ -1,22 +0,0 @@ -### `tf.scalar_summary(*args, **kwargs)` {#scalar_summary} - -Outputs a `Summary` protocol buffer with scalar values. (deprecated) - -THIS FUNCTION IS DEPRECATED. It will be removed after 2016-11-30. -Instructions for updating: -Please switch to tf.summary.scalar. Note that tf.summary.scalar uses the node name instead of the tag. This means that TensorFlow will automatically de-duplicate summary names based on the scope they are created in. Also, passing a tensor or list of tags to a scalar summary op is no longer supported. - - The input `tags` and `values` must have the same shape. The generated - summary has a summary value for each tag-value pair in `tags` and `values`. - - Args: - tags: A `string` `Tensor`. Tags for the summaries. - values: A real numeric Tensor. Values for the summaries. - collections: Optional list of graph collections keys. The new summary op is - added to these collections. Defaults to `[GraphKeys.SUMMARIES]`. - name: A name for the operation (optional). - - Returns: - A scalar `Tensor` of type `string`. The serialized `Summary` protocol - buffer. - diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.layers.batch_norm.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.layers.batch_norm.md index 2b23d99de2c..386d3a357c2 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.layers.batch_norm.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.layers.batch_norm.md @@ -33,7 +33,8 @@ can have speed penalty, specially in distributed settings. Lower `decay` value (recommend trying `decay`=0.9) if model experiences reasonably good training performance but poor validation and/or test performance. Try zero_debias_moving_mean=True for improved stability. -* <b>`center`</b>: If True, subtract `beta`. If False, `beta` is ignored. +* <b>`center`</b>: If True, add offset of `beta` to normalized tensor. If False, `beta` + is ignored. * <b>`scale`</b>: If True, multiply by `gamma`. If False, `gamma` is not used. When the next layer is linear (also e.g. `nn.relu`), this can be disabled since the scaling can be done by the next layer. diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.summary.SummaryDescription.RegisterExtension.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.summary.SummaryDescription.RegisterExtension.md deleted file mode 100644 index 3cfd7103d7e..00000000000 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.summary.SummaryDescription.RegisterExtension.md +++ /dev/null @@ -1,4 +0,0 @@ -#### `tf.summary.SummaryDescription.RegisterExtension(extension_handle)` {#SummaryDescription.RegisterExtension} - - - diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.contrib.distributions.bijector.Affine.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.contrib.distributions.bijector.Affine.md index 76f0dd4557a..ee19f33d47f 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.contrib.distributions.bijector.Affine.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.contrib.distributions.bijector.Affine.md @@ -96,6 +96,7 @@ specified then `scale += IdentityMatrix`. Otherwise specifying a `scale_diag` has shape [N1, N2, ... k, k], which represents a k x k lower triangular matrix. When `None` no `scale_tril` term is added to `scale`. + The upper triangular elements above the diagonal are ignored. * <b>`scale_perturb_factor`</b>: Numeric `Tensor` representing factor matrix with last two dimensions of shape `(k, r)`. When `None`, no rank-r update is added to `scale`. diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.histogram_summary.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.histogram_summary.md deleted file mode 100644 index 570d7b712c6..00000000000 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.histogram_summary.md +++ /dev/null @@ -1,26 +0,0 @@ -### `tf.histogram_summary(*args, **kwargs)` {#histogram_summary} - -Outputs a `Summary` protocol buffer with a histogram. (deprecated) - -THIS FUNCTION IS DEPRECATED. It will be removed after 2016-11-30. -Instructions for updating: -Please switch to tf.summary.histogram. Note that tf.summary.histogram uses the node name instead of the tag. This means that TensorFlow will automatically de-duplicate summary names based on their scope. - - The generated - [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) - has one summary value containing a histogram for `values`. - - This op reports an `InvalidArgument` error if any value is not finite. - - Args: - tag: A `string` `Tensor`. 0-D. Tag to use for the summary value. - values: A real numeric `Tensor`. Any shape. Values to use to - build the histogram. - collections: Optional list of graph collections keys. The new summary op is - added to these collections. Defaults to `[GraphKeys.SUMMARIES]`. - name: A name for the operation (optional). - - Returns: - A scalar `Tensor` of type `string`. The serialized `Summary` protocol - buffer. - diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.image.total_variation.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.image.total_variation.md new file mode 100644 index 00000000000..03fec86c85e --- /dev/null +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.image.total_variation.md @@ -0,0 +1,40 @@ +### `tf.image.total_variation(images, name=None)` {#total_variation} + +Calculate and return the total variation for one or more images. + +The total variation is the sum of the absolute differences for neighboring +pixel-values in the input images. This measures how much noise is in the +images. + +This can be used as a loss-function during optimization so as to suppress +noise in images. If you have a batch of images, then you should calculate +the scalar loss-value as the sum: +`loss = tf.reduce_sum(tf.image.total_variation(images))` + +This implements the anisotropic 2-D version of the formula described here: + +https://en.wikipedia.org/wiki/Total_variation_denoising + +##### Args: + + +* <b>`images`</b>: 4-D Tensor of shape `[batch, height, width, channels]` or + 3-D Tensor of shape `[height, width, channels]`. + + +* <b>`name`</b>: A name for the operation (optional). + +##### Raises: + + +* <b>`ValueError`</b>: if images.shape is not a 3-D or 4-D vector. + +##### Returns: + + The total variation of `images`. + + If `images` was 4-D, return a 1-D float Tensor of shape `[batch]` with the + total variation for each image in the batch. + If `images` was 3-D, return a scalar float with the total variation for + that image. + diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.merge_summary.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.merge_summary.md deleted file mode 100644 index ccb984f5abe..00000000000 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.merge_summary.md +++ /dev/null @@ -1,27 +0,0 @@ -### `tf.merge_summary(*args, **kwargs)` {#merge_summary} - -Merges summaries. (deprecated) - -THIS FUNCTION IS DEPRECATED. It will be removed after 2016-11-30. -Instructions for updating: -Please switch to tf.summary.merge. - - This op creates a - [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) - protocol buffer that contains the union of all the values in the input - summaries. - - When the Op is run, it reports an `InvalidArgument` error if multiple values - in the summaries to merge use the same tag. - - Args: - inputs: A list of `string` `Tensor` objects containing serialized `Summary` - protocol buffers. - collections: Optional list of graph collections keys. The new summary op is - added to these collections. Defaults to `[GraphKeys.SUMMARIES]`. - name: A name for the operation (optional). - - Returns: - A scalar `Tensor` of type `string`. The serialized `Summary` protocol - buffer resulting from the merging. - diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.summary.SummaryDescription.FromString.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.summary.SummaryDescription.FromString.md deleted file mode 100644 index 24a3b3f10c3..00000000000 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.summary.SummaryDescription.FromString.md +++ /dev/null @@ -1,4 +0,0 @@ -#### `tf.summary.SummaryDescription.FromString(s)` {#SummaryDescription.FromString} - - - diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.initialize_all_tables.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.initialize_all_tables.md index 8293a3c9449..4309820b84d 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.initialize_all_tables.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.initialize_all_tables.md @@ -1,6 +1,10 @@ -### `tf.initialize_all_tables(name='init_all_tables')` {#initialize_all_tables} +### `tf.initialize_all_tables(*args, **kwargs)` {#initialize_all_tables} -Returns an Op that initializes all tables of the default graph. +Returns an Op that initializes all tables of the default graph. (deprecated) + +THIS FUNCTION IS DEPRECATED. It will be removed after 2017-03-02. +Instructions for updating: +Use `tf.tables_initializer` instead. ##### Args: diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.summary.TaggedRunMetadata.RegisterExtension.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.summary.TaggedRunMetadata.RegisterExtension.md deleted file mode 100644 index f2d0c042d77..00000000000 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.summary.TaggedRunMetadata.RegisterExtension.md +++ /dev/null @@ -1,4 +0,0 @@ -#### `tf.summary.TaggedRunMetadata.RegisterExtension(extension_handle)` {#TaggedRunMetadata.RegisterExtension} - - - diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.contrib.layers.layer_norm.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.contrib.layers.layer_norm.md index c2d6c88d2e8..726426d9a90 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.contrib.layers.layer_norm.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.contrib.layers.layer_norm.md @@ -13,7 +13,8 @@ Can be used as a normalizer function for conv2d and fully_connected. * <b>`inputs`</b>: a tensor with 2 or more dimensions. The normalization occurs over all but the first dimension. -* <b>`center`</b>: If True, subtract `beta`. If False, `beta` is ignored. +* <b>`center`</b>: If True, add offset of `beta` to normalized tensor. If False, `beta` + is ignored. * <b>`scale`</b>: If True, multiply by `gamma`. If False, `gamma` is not used. When the next layer is linear (also e.g. `nn.relu`), this can be disabled since the scaling can be done by the next layer. diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.train.SummaryWriter.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.train.SummaryWriter.md deleted file mode 100644 index e9bdda200f9..00000000000 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.train.SummaryWriter.md +++ /dev/null @@ -1,207 +0,0 @@ - -- - - - -#### `tf.train.SummaryWriter.__init__(*args, **kwargs)` {#SummaryWriter.__init__} - -Creates a `SummaryWriter` and an event file. (deprecated) - -THIS FUNCTION IS DEPRECATED. It will be removed after 2016-11-30. -Instructions for updating: -Please switch to tf.summary.FileWriter. The interface and behavior is the same; this is just a rename. - - This class is deprecated, and should be replaced with tf.summary.FileWriter. - - On construction the summary writer creates a new event file in `logdir`. - This event file will contain `Event` protocol buffers constructed when you - call one of the following functions: `add_summary()`, `add_session_log()`, - `add_event()`, or `add_graph()`. - - If you pass a `Graph` to the constructor it is added to - the event file. (This is equivalent to calling `add_graph()` later). - - TensorBoard will pick the graph from the file and display it graphically so - you can interactively explore the graph you built. You will usually pass - the graph from the session in which you launched it: - - ```python - ...create a graph... - # Launch the graph in a session. - sess = tf.Session() - # Create a summary writer, add the 'graph' to the event file. - writer = tf.train.SummaryWriter(<some-directory>, sess.graph) - ``` - - The other arguments to the constructor control the asynchronous writes to - the event file: - - * `flush_secs`: How often, in seconds, to flush the added summaries - and events to disk. - * `max_queue`: Maximum number of summaries or events pending to be - written to disk before one of the 'add' calls block. - - Args: - logdir: A string. Directory where event file will be written. - graph: A `Graph` object, such as `sess.graph`. - max_queue: Integer. Size of the queue for pending events and summaries. - flush_secs: Number. How often, in seconds, to flush the - pending events and summaries to disk. - graph_def: DEPRECATED: Use the `graph` argument instead. - - -- - - - -#### `tf.train.SummaryWriter.add_event(event)` {#SummaryWriter.add_event} - -Adds an event to the event file. - -##### Args: - - -* <b>`event`</b>: An `Event` protocol buffer. - - -- - - - -#### `tf.train.SummaryWriter.add_graph(graph, global_step=None, graph_def=None)` {#SummaryWriter.add_graph} - -Adds a `Graph` to the event file. - -The graph described by the protocol buffer will be displayed by -TensorBoard. Most users pass a graph in the constructor instead. - -##### Args: - - -* <b>`graph`</b>: A `Graph` object, such as `sess.graph`. -* <b>`global_step`</b>: Number. Optional global step counter to record with the - graph. -* <b>`graph_def`</b>: DEPRECATED. Use the `graph` parameter instead. - -##### Raises: - - -* <b>`ValueError`</b>: If both graph and graph_def are passed to the method. - - -- - - - -#### `tf.train.SummaryWriter.add_meta_graph(meta_graph_def, global_step=None)` {#SummaryWriter.add_meta_graph} - -Adds a `MetaGraphDef` to the event file. - -The `MetaGraphDef` allows running the given graph via -`saver.import_meta_graph()`. - -##### Args: - - -* <b>`meta_graph_def`</b>: A `MetaGraphDef` object, often as retured by - `saver.export_meta_graph()`. -* <b>`global_step`</b>: Number. Optional global step counter to record with the - graph. - -##### Raises: - - -* <b>`TypeError`</b>: If both `meta_graph_def` is not an instance of `MetaGraphDef`. - - -- - - - -#### `tf.train.SummaryWriter.add_run_metadata(run_metadata, tag, global_step=None)` {#SummaryWriter.add_run_metadata} - -Adds a metadata information for a single session.run() call. - -##### Args: - - -* <b>`run_metadata`</b>: A `RunMetadata` protobuf object. -* <b>`tag`</b>: The tag name for this metadata. -* <b>`global_step`</b>: Number. Optional global step counter to record with the - StepStats. - -##### Raises: - - -* <b>`ValueError`</b>: If the provided tag was already used for this type of event. - - -- - - - -#### `tf.train.SummaryWriter.add_session_log(session_log, global_step=None)` {#SummaryWriter.add_session_log} - -Adds a `SessionLog` protocol buffer to the event file. - -This method wraps the provided session in an `Event` protocol buffer -and adds it to the event file. - -##### Args: - - -* <b>`session_log`</b>: A `SessionLog` protocol buffer. -* <b>`global_step`</b>: Number. Optional global step value to record with the - summary. - - -- - - - -#### `tf.train.SummaryWriter.add_summary(summary, global_step=None)` {#SummaryWriter.add_summary} - -Adds a `Summary` protocol buffer to the event file. - -This method wraps the provided summary in an `Event` protocol buffer -and adds it to the event file. - -You can pass the result of evaluating any summary op, using -[`Session.run()`](client.md#Session.run) or -[`Tensor.eval()`](framework.md#Tensor.eval), to this -function. Alternatively, you can pass a `tf.Summary` protocol -buffer that you populate with your own data. The latter is -commonly done to report evaluation results in event files. - -##### Args: - - -* <b>`summary`</b>: A `Summary` protocol buffer, optionally serialized as a string. -* <b>`global_step`</b>: Number. Optional global step value to record with the - summary. - - -- - - - -#### `tf.train.SummaryWriter.close()` {#SummaryWriter.close} - -Flushes the event file to disk and close the file. - -Call this method when you do not need the summary writer anymore. - - -- - - - -#### `tf.train.SummaryWriter.flush()` {#SummaryWriter.flush} - -Flushes the event file to disk. - -Call this method to make sure that all pending events have been written to -disk. - - -- - - - -#### `tf.train.SummaryWriter.get_logdir()` {#SummaryWriter.get_logdir} - -Returns the directory where event file will be written. - - -- - - - -#### `tf.train.SummaryWriter.reopen()` {#SummaryWriter.reopen} - -Reopens the EventFileWriter. - -Can be called after `close()` to add more events in the same directory. -The events will go into a new events file. - -Does nothing if the EventFileWriter was not closed. - - diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.training.bucket_by_sequence_length.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.training.bucket_by_sequence_length.md index 85b5f08c3ca..69dc1384e4b 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.training.bucket_by_sequence_length.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.training.bucket_by_sequence_length.md @@ -13,8 +13,10 @@ bucket the given `input_length` belongs to. See the documentation for * <b>`input_length`</b>: `int32` scalar `Tensor`, the sequence length of tensors. * <b>`tensors`</b>: The list or dictionary of tensors, representing a single element, to bucket. Nested lists are not supported. -* <b>`batch_size`</b>: The new batch size pulled from the queue - (python int or int32 scalar). +* <b>`batch_size`</b>: The new batch size pulled from the queue (all queues will have + the same size). If a list is passed in then each bucket will have a + different batch_size. + (python int, int32 scalar or iterable of integers of length num_buckets). * <b>`bucket_boundaries`</b>: int list, increasing non-negative numbers. The edges of the buckets to use when bucketing tensors. Two extra buckets are created, one for `input_length < bucket_boundaries[0]` and @@ -49,5 +51,6 @@ bucket the given `input_length` belongs to. See the documentation for * <b>`TypeError`</b>: if `bucket_boundaries` is not a list of python integers. * <b>`ValueError`</b>: if `bucket_boundaries` is empty or contains non-increasing - values. + values or if batch_size is a list and it's length doesn't equal the number + of buckets. diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.audio_summary.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.audio_summary.md deleted file mode 100644 index c5830ab5504..00000000000 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.audio_summary.md +++ /dev/null @@ -1,37 +0,0 @@ -### `tf.audio_summary(*args, **kwargs)` {#audio_summary} - -Outputs a `Summary` protocol buffer with audio. (deprecated) - -THIS FUNCTION IS DEPRECATED. It will be removed after 2016-11-30. -Instructions for updating: -Please switch to tf.summary.audio. Note that tf.summary.histogram uses the node name instead of the tag. This means that TensorFlow will automatically de-duplicate summary names based on the scope they are created in. - - The summary has up to `max_outputs` summary values containing audio. The - audio is built from `tensor` which must be 3-D with shape `[batch_size, - frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are - assumed to be in the range of `[-1.0, 1.0]` with a sample rate of - `sample_rate`. - - The `tag` argument is a scalar `Tensor` of type `string`. It is used to - build the `tag` of the summary values: - - * If `max_outputs` is 1, the summary value tag is '*tag*/audio'. - * If `max_outputs` is greater than 1, the summary value tags are - generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc. - - Args: - tag: A scalar `Tensor` of type `string`. Used to build the `tag` - of the summary values. - tensor: A 3-D `float32` `Tensor` of shape `[batch_size, frames, channels]` - or a 2-D `float32` `Tensor` of shape `[batch_size, frames]`. - sample_rate: A Scalar `float32` `Tensor` indicating the sample rate of the - signal in hertz. - max_outputs: Max number of batch elements to generate audio for. - collections: Optional list of ops.GraphKeys. The collections to add the - summary to. Defaults to [ops.GraphKeys.SUMMARIES] - name: A name for the operation (optional). - - Returns: - A scalar `Tensor` of type `string`. The serialized `Summary` protocol - buffer. - diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.distributions.softplus_inverse.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.distributions.softplus_inverse.md new file mode 100644 index 00000000000..6f97b1f9594 --- /dev/null +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.distributions.softplus_inverse.md @@ -0,0 +1,20 @@ +### `tf.contrib.distributions.softplus_inverse(x, name=None)` {#softplus_inverse} + +Computes the inverse softplus, i.e., x = softplus_inverse(softplus(x)). + +Mathematically this op is equivalent to: + +```none +softplus_inverse = log(exp(x) - 1.) +``` + +##### Args: + + +* <b>`x`</b>: `Tensor`. Non-negative (not enforced), floating-point. +* <b>`name`</b>: A name for the operation (optional). + +##### Returns: + + `Tensor`. Has the same type/shape as input `x`. + diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.summary.TaggedRunMetadata.FromString.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.summary.TaggedRunMetadata.FromString.md deleted file mode 100644 index 613f4ebd73d..00000000000 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.summary.TaggedRunMetadata.FromString.md +++ /dev/null @@ -1,4 +0,0 @@ -#### `tf.summary.TaggedRunMetadata.FromString(s)` {#TaggedRunMetadata.FromString} - - - diff --git a/tensorflow/g3doc/api_docs/python/image.md b/tensorflow/g3doc/api_docs/python/image.md index d218ba024a5..baef42db057 100644 --- a/tensorflow/g3doc/api_docs/python/image.md +++ b/tensorflow/g3doc/api_docs/python/image.md @@ -1474,3 +1474,49 @@ false and no bounding boxes are supplied, an error is raised. Provide as input to `tf.image.draw_bounding_boxes`. + +## Denoising + +- - - + +### `tf.image.total_variation(images, name=None)` {#total_variation} + +Calculate and return the total variation for one or more images. + +The total variation is the sum of the absolute differences for neighboring +pixel-values in the input images. This measures how much noise is in the +images. + +This can be used as a loss-function during optimization so as to suppress +noise in images. If you have a batch of images, then you should calculate +the scalar loss-value as the sum: +`loss = tf.reduce_sum(tf.image.total_variation(images))` + +This implements the anisotropic 2-D version of the formula described here: + +https://en.wikipedia.org/wiki/Total_variation_denoising + +##### Args: + + +* <b>`images`</b>: 4-D Tensor of shape `[batch, height, width, channels]` or + 3-D Tensor of shape `[height, width, channels]`. + + +* <b>`name`</b>: A name for the operation (optional). + +##### Raises: + + +* <b>`ValueError`</b>: if images.shape is not a 3-D or 4-D vector. + +##### Returns: + + The total variation of `images`. + + If `images` was 4-D, return a 1-D float Tensor of shape `[batch]` with the + total variation for each image in the batch. + If `images` was 3-D, return a scalar float with the total variation for + that image. + + diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md index 449e582d190..cc3dc0a0e57 100644 --- a/tensorflow/g3doc/api_docs/python/index.md +++ b/tensorflow/g3doc/api_docs/python/index.md @@ -119,6 +119,7 @@ * [`scatter_sub`](../../api_docs/python/state_ops.md#scatter_sub) * [`scatter_update`](../../api_docs/python/state_ops.md#scatter_update) * [`sparse_mask`](../../api_docs/python/state_ops.md#sparse_mask) + * [`tables_initializer`](../../api_docs/python/state_ops.md#tables_initializer) * [`trainable_variables`](../../api_docs/python/state_ops.md#trainable_variables) * [`truncated_normal_initializer`](../../api_docs/python/state_ops.md#truncated_normal_initializer) * [`uniform_unit_scaling_initializer`](../../api_docs/python/state_ops.md#uniform_unit_scaling_initializer) @@ -415,6 +416,7 @@ * [`rgb_to_hsv`](../../api_docs/python/image.md#rgb_to_hsv) * [`rot90`](../../api_docs/python/image.md#rot90) * [`sample_distorted_bounding_box`](../../api_docs/python/image.md#sample_distorted_bounding_box) + * [`total_variation`](../../api_docs/python/image.md#total_variation) * [`transpose_image`](../../api_docs/python/image.md#transpose_image) * **[Sparse Tensors](../../api_docs/python/sparse_ops.md)**: @@ -769,6 +771,7 @@ * [`Poisson`](../../api_docs/python/contrib.distributions.md#Poisson) * [`QuantizedDistribution`](../../api_docs/python/contrib.distributions.md#QuantizedDistribution) * [`RegisterKL`](../../api_docs/python/contrib.distributions.md#RegisterKL) + * [`softplus_inverse`](../../api_docs/python/contrib.distributions.md#softplus_inverse) * [`StudentT`](../../api_docs/python/contrib.distributions.md#StudentT) * [`StudentTWithAbsDfSoftplusSigma`](../../api_docs/python/contrib.distributions.md#StudentTWithAbsDfSoftplusSigma) * [`TransformedDistribution`](../../api_docs/python/contrib.distributions.md#TransformedDistribution) @@ -1030,6 +1033,7 @@ * [`LinearOperatorDiag`](../../api_docs/python/contrib.linalg.md#LinearOperatorDiag) * [`LinearOperatorIdentity`](../../api_docs/python/contrib.linalg.md#LinearOperatorIdentity) * [`LinearOperatorMatrix`](../../api_docs/python/contrib.linalg.md#LinearOperatorMatrix) + * [`LinearOperatorScaledIdentity`](../../api_docs/python/contrib.linalg.md#LinearOperatorScaledIdentity) * [`LinearOperatorTriL`](../../api_docs/python/contrib.linalg.md#LinearOperatorTriL) * **[Losses (contrib)](../../api_docs/python/contrib.losses.md)**: diff --git a/tensorflow/g3doc/api_docs/python/state_ops.md b/tensorflow/g3doc/api_docs/python/state_ops.md index 9c2e6da666d..2db192fddd8 100644 --- a/tensorflow/g3doc/api_docs/python/state_ops.md +++ b/tensorflow/g3doc/api_docs/python/state_ops.md @@ -3472,7 +3472,28 @@ The `Graph` that contains the values, indices, and shape tensors. - - - -### `tf.initialize_all_tables(name='init_all_tables')` {#initialize_all_tables} +### `tf.initialize_all_tables(*args, **kwargs)` {#initialize_all_tables} + +Returns an Op that initializes all tables of the default graph. (deprecated) + +THIS FUNCTION IS DEPRECATED. It will be removed after 2017-03-02. +Instructions for updating: +Use `tf.tables_initializer` instead. + +##### Args: + + +* <b>`name`</b>: Optional name for the initialization op. + +##### Returns: + + An Op that initializes all tables. Note that if there are + not tables the returned Op is a NoOp. + + +- - - + +### `tf.tables_initializer(name='init_all_tables')` {#tables_initializer} Returns an Op that initializes all tables of the default graph. diff --git a/tensorflow/g3doc/api_docs/python/summary.md b/tensorflow/g3doc/api_docs/python/summary.md index be029f42906..8d344036dbc 100644 --- a/tensorflow/g3doc/api_docs/python/summary.md +++ b/tensorflow/g3doc/api_docs/python/summary.md @@ -485,187 +485,6 @@ metadata is stored in its NodeDef. This method retrieves the description. ### `class tf.summary.SummaryDescription` {#SummaryDescription} -- - - - -#### `tf.summary.SummaryDescription.ByteSize()` {#SummaryDescription.ByteSize} - - - - -- - - - -#### `tf.summary.SummaryDescription.Clear()` {#SummaryDescription.Clear} - - - - -- - - - -#### `tf.summary.SummaryDescription.ClearExtension(extension_handle)` {#SummaryDescription.ClearExtension} - - - - -- - - - -#### `tf.summary.SummaryDescription.ClearField(field_name)` {#SummaryDescription.ClearField} - - - - -- - - - -#### `tf.summary.SummaryDescription.CopyFrom(other_msg)` {#SummaryDescription.CopyFrom} - -Copies the content of the specified message into the current message. - -The method clears the current message and then merges the specified -message using MergeFrom. - -##### Args: - - -* <b>`other_msg`</b>: Message to copy into the current one. - - -- - - - -#### `tf.summary.SummaryDescription.DiscardUnknownFields()` {#SummaryDescription.DiscardUnknownFields} - - - - -- - - - -#### `tf.summary.SummaryDescription.FindInitializationErrors()` {#SummaryDescription.FindInitializationErrors} - -Finds required fields which are not initialized. - -##### Returns: - - A list of strings. Each string is a path to an uninitialized field from - the top-level message, e.g. "foo.bar[5].baz". - - -- - - - -#### `tf.summary.SummaryDescription.FromString(s)` {#SummaryDescription.FromString} - - - - -- - - - -#### `tf.summary.SummaryDescription.HasExtension(extension_handle)` {#SummaryDescription.HasExtension} - - - - -- - - - -#### `tf.summary.SummaryDescription.HasField(field_name)` {#SummaryDescription.HasField} - - - - -- - - - -#### `tf.summary.SummaryDescription.IsInitialized(errors=None)` {#SummaryDescription.IsInitialized} - -Checks if all required fields of a message are set. - -##### Args: - - -* <b>`errors`</b>: A list which, if provided, will be populated with the field - paths of all missing required fields. - -##### Returns: - - True iff the specified message has all required fields set. - - -- - - - -#### `tf.summary.SummaryDescription.ListFields()` {#SummaryDescription.ListFields} - - - - -- - - - -#### `tf.summary.SummaryDescription.MergeFrom(msg)` {#SummaryDescription.MergeFrom} - - - - -- - - - -#### `tf.summary.SummaryDescription.MergeFromString(serialized)` {#SummaryDescription.MergeFromString} - - - - -- - - - -#### `tf.summary.SummaryDescription.ParseFromString(serialized)` {#SummaryDescription.ParseFromString} - -Parse serialized protocol buffer data into this message. - -Like MergeFromString(), except we clear the object first and -do not return the value that MergeFromString returns. - - -- - - - -#### `tf.summary.SummaryDescription.RegisterExtension(extension_handle)` {#SummaryDescription.RegisterExtension} - - - - -- - - - -#### `tf.summary.SummaryDescription.SerializePartialToString()` {#SummaryDescription.SerializePartialToString} - - - - -- - - - -#### `tf.summary.SummaryDescription.SerializeToString()` {#SummaryDescription.SerializeToString} - - - - -- - - - -#### `tf.summary.SummaryDescription.SetInParent()` {#SummaryDescription.SetInParent} - -Sets the _cached_byte_size_dirty bit to true, -and propagates this to our listener iff this was a state change. - - -- - - - -#### `tf.summary.SummaryDescription.WhichOneof(oneof_name)` {#SummaryDescription.WhichOneof} - -Returns the name of the currently set field inside a oneof, or None. - - -- - - - -#### `tf.summary.SummaryDescription.__deepcopy__(memo=None)` {#SummaryDescription.__deepcopy__} - - - - -- - - - -#### `tf.summary.SummaryDescription.__eq__(other)` {#SummaryDescription.__eq__} - - - - - - - #### `tf.summary.SummaryDescription.__getstate__()` {#SummaryDescription.__getstate__} @@ -673,249 +492,12 @@ Returns the name of the currently set field inside a oneof, or None. Support the pickle protocol. -- - - - -#### `tf.summary.SummaryDescription.__hash__()` {#SummaryDescription.__hash__} - - - - -- - - - -#### `tf.summary.SummaryDescription.__init__(**kwargs)` {#SummaryDescription.__init__} - - - - -- - - - -#### `tf.summary.SummaryDescription.__ne__(other_msg)` {#SummaryDescription.__ne__} - - - - -- - - - -#### `tf.summary.SummaryDescription.__repr__()` {#SummaryDescription.__repr__} - - - - -- - - - -#### `tf.summary.SummaryDescription.__setstate__(state)` {#SummaryDescription.__setstate__} - -Support the pickle protocol. - - -- - - - -#### `tf.summary.SummaryDescription.__str__()` {#SummaryDescription.__str__} - - - - -- - - - -#### `tf.summary.SummaryDescription.__unicode__()` {#SummaryDescription.__unicode__} - - - - -- - - - -#### `tf.summary.SummaryDescription.type_hint` {#SummaryDescription.type_hint} - -Magic attribute generated for "type_hint" proto field. - - - - - ### `class tf.summary.TaggedRunMetadata` {#TaggedRunMetadata} -- - - - -#### `tf.summary.TaggedRunMetadata.ByteSize()` {#TaggedRunMetadata.ByteSize} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.Clear()` {#TaggedRunMetadata.Clear} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.ClearExtension(extension_handle)` {#TaggedRunMetadata.ClearExtension} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.ClearField(field_name)` {#TaggedRunMetadata.ClearField} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.CopyFrom(other_msg)` {#TaggedRunMetadata.CopyFrom} - -Copies the content of the specified message into the current message. - -The method clears the current message and then merges the specified -message using MergeFrom. - -##### Args: - - -* <b>`other_msg`</b>: Message to copy into the current one. - - -- - - - -#### `tf.summary.TaggedRunMetadata.DiscardUnknownFields()` {#TaggedRunMetadata.DiscardUnknownFields} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.FindInitializationErrors()` {#TaggedRunMetadata.FindInitializationErrors} - -Finds required fields which are not initialized. - -##### Returns: - - A list of strings. Each string is a path to an uninitialized field from - the top-level message, e.g. "foo.bar[5].baz". - - -- - - - -#### `tf.summary.TaggedRunMetadata.FromString(s)` {#TaggedRunMetadata.FromString} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.HasExtension(extension_handle)` {#TaggedRunMetadata.HasExtension} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.HasField(field_name)` {#TaggedRunMetadata.HasField} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.IsInitialized(errors=None)` {#TaggedRunMetadata.IsInitialized} - -Checks if all required fields of a message are set. - -##### Args: - - -* <b>`errors`</b>: A list which, if provided, will be populated with the field - paths of all missing required fields. - -##### Returns: - - True iff the specified message has all required fields set. - - -- - - - -#### `tf.summary.TaggedRunMetadata.ListFields()` {#TaggedRunMetadata.ListFields} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.MergeFrom(msg)` {#TaggedRunMetadata.MergeFrom} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.MergeFromString(serialized)` {#TaggedRunMetadata.MergeFromString} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.ParseFromString(serialized)` {#TaggedRunMetadata.ParseFromString} - -Parse serialized protocol buffer data into this message. - -Like MergeFromString(), except we clear the object first and -do not return the value that MergeFromString returns. - - -- - - - -#### `tf.summary.TaggedRunMetadata.RegisterExtension(extension_handle)` {#TaggedRunMetadata.RegisterExtension} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.SerializePartialToString()` {#TaggedRunMetadata.SerializePartialToString} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.SerializeToString()` {#TaggedRunMetadata.SerializeToString} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.SetInParent()` {#TaggedRunMetadata.SetInParent} - -Sets the _cached_byte_size_dirty bit to true, -and propagates this to our listener iff this was a state change. - - -- - - - -#### `tf.summary.TaggedRunMetadata.WhichOneof(oneof_name)` {#TaggedRunMetadata.WhichOneof} - -Returns the name of the currently set field inside a oneof, or None. - - -- - - - -#### `tf.summary.TaggedRunMetadata.__deepcopy__(memo=None)` {#TaggedRunMetadata.__deepcopy__} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.__eq__(other)` {#TaggedRunMetadata.__eq__} - - - - - - - #### `tf.summary.TaggedRunMetadata.__getstate__()` {#TaggedRunMetadata.__getstate__} @@ -923,67 +505,4 @@ Returns the name of the currently set field inside a oneof, or None. Support the pickle protocol. -- - - - -#### `tf.summary.TaggedRunMetadata.__hash__()` {#TaggedRunMetadata.__hash__} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.__init__(**kwargs)` {#TaggedRunMetadata.__init__} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.__ne__(other_msg)` {#TaggedRunMetadata.__ne__} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.__repr__()` {#TaggedRunMetadata.__repr__} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.__setstate__(state)` {#TaggedRunMetadata.__setstate__} - -Support the pickle protocol. - - -- - - - -#### `tf.summary.TaggedRunMetadata.__str__()` {#TaggedRunMetadata.__str__} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.__unicode__()` {#TaggedRunMetadata.__unicode__} - - - - -- - - - -#### `tf.summary.TaggedRunMetadata.run_metadata` {#TaggedRunMetadata.run_metadata} - -Magic attribute generated for "run_metadata" proto field. - - -- - - - -#### `tf.summary.TaggedRunMetadata.tag` {#TaggedRunMetadata.tag} - -Magic attribute generated for "tag" proto field. - - diff --git a/tensorflow/g3doc/api_docs/python/test.md b/tensorflow/g3doc/api_docs/python/test.md index 265e4028d0f..c95f9718894 100644 --- a/tensorflow/g3doc/api_docs/python/test.md +++ b/tensorflow/g3doc/api_docs/python/test.md @@ -213,6 +213,125 @@ Checks that for all elements of farray1 and farray2 * <b>`err`</b>: a float value. +- - - + +#### `tf.test.TestCase.assertBetween(value, minv, maxv, msg=None)` {#TestCase.assertBetween} + +Asserts that value is between minv and maxv (inclusive). + + +- - - + +#### `tf.test.TestCase.assertCommandFails(command, regexes, env=None, close_fds=True, msg=None)` {#TestCase.assertCommandFails} + +Asserts a shell command fails and the error matches a regex in a list. + +##### Args: + + +* <b>`command`</b>: List or string representing the command to run. +* <b>`regexes`</b>: the list of regular expression strings. +* <b>`env`</b>: Dictionary of environment variable settings. +* <b>`close_fds`</b>: Whether or not to close all open fd's in the child after + forking. +* <b>`msg`</b>: Optional message to report on failure. + + +- - - + +#### `tf.test.TestCase.assertCommandSucceeds(command, regexes=('',), env=None, close_fds=True, msg=None)` {#TestCase.assertCommandSucceeds} + +Asserts that a shell command succeeds (i.e. exits with code 0). + +##### Args: + + +* <b>`command`</b>: List or string representing the command to run. +* <b>`regexes`</b>: List of regular expression byte strings that match success. +* <b>`env`</b>: Dictionary of environment variable settings. +* <b>`close_fds`</b>: Whether or not to close all open fd's in the child after + forking. +* <b>`msg`</b>: Optional message to report on failure. + + +- - - + +#### `tf.test.TestCase.assertContainsExactSubsequence(container, subsequence, msg=None)` {#TestCase.assertContainsExactSubsequence} + +Assert that "container" contains "subsequence" as an exact subsequence. + +Asserts that "container" contains all the elements of "subsequence", in +order, and without other elements interspersed. For example, [1, 2, 3] is an +exact subsequence of [0, 0, 1, 2, 3, 0] but not of [0, 0, 1, 2, 0, 3, 0]. + +##### Args: + + +* <b>`container`</b>: the list we're testing for subsequence inclusion. +* <b>`subsequence`</b>: the list we hope will be an exact subsequence of container. +* <b>`msg`</b>: Optional message to report on failure. + + +- - - + +#### `tf.test.TestCase.assertContainsInOrder(strings, target, msg=None)` {#TestCase.assertContainsInOrder} + +Asserts that the strings provided are found in the target in order. + +This may be useful for checking HTML output. + +##### Args: + + +* <b>`strings`</b>: A list of strings, such as [ 'fox', 'dog' ] +* <b>`target`</b>: A target string in which to look for the strings, such as + 'The quick brown fox jumped over the lazy dog'. +* <b>`msg`</b>: Optional message to report on failure. + + +- - - + +#### `tf.test.TestCase.assertContainsSubsequence(container, subsequence, msg=None)` {#TestCase.assertContainsSubsequence} + +Assert that "container" contains "subsequence" as a subsequence. + +Asserts that "container" contains all the elements of "subsequence", in +order, but possibly with other elements interspersed. For example, [1, 2, 3] +is a subsequence of [0, 0, 1, 2, 0, 3, 0] but not of [0, 0, 1, 3, 0, 2, 0]. + +##### Args: + + +* <b>`container`</b>: the list we're testing for subsequence inclusion. +* <b>`subsequence`</b>: the list we hope will be a subsequence of container. +* <b>`msg`</b>: Optional message to report on failure. + + +- - - + +#### `tf.test.TestCase.assertContainsSubset(expected_subset, actual_set, msg=None)` {#TestCase.assertContainsSubset} + +Checks whether actual iterable is a superset of expected iterable. + + +- - - + +#### `tf.test.TestCase.assertCountEqual(*args, **kwargs)` {#TestCase.assertCountEqual} + +An unordered sequence specific comparison. + +Equivalent to assertItemsEqual(). This method is a compatibility layer +for Python 3k, since 2to3 does not convert assertItemsEqual() calls into +assertCountEqual() calls. + +##### Args: + + +* <b>`expected_seq`</b>: A sequence containing elements we are expecting. +* <b>`actual_seq`</b>: The sequence that we are testing. +* <b>`msg`</b>: The message to be printed if the test fails. + + - - - #### `tf.test.TestCase.assertDeviceEqual(device1, device2)` {#TestCase.assertDeviceEqual} @@ -235,9 +354,48 @@ Checks whether actual is a superset of expected. - - - -#### `tf.test.TestCase.assertDictEqual(d1, d2, msg=None)` {#TestCase.assertDictEqual} +#### `tf.test.TestCase.assertDictEqual(a, b, msg=None)` {#TestCase.assertDictEqual} + +Raises AssertionError if a and b are not equal dictionaries. + +##### Args: +* <b>`a`</b>: A dict, the expected value. +* <b>`b`</b>: A dict, the actual value. +* <b>`msg`</b>: An optional str, the associated message. + +##### Raises: + + +* <b>`AssertionError`</b>: if the dictionaries are not equal. + + +- - - + +#### `tf.test.TestCase.assertEmpty(container, msg=None)` {#TestCase.assertEmpty} + +Assert that an object has zero length. + +##### Args: + + +* <b>`container`</b>: Anything that implements the collections.Sized interface. +* <b>`msg`</b>: Optional message to report on failure. + + +- - - + +#### `tf.test.TestCase.assertEndsWith(actual, expected_end, msg=None)` {#TestCase.assertEndsWith} + +Assert that actual.endswith(expected_end) is True. + +##### Args: + + +* <b>`actual`</b>: str +* <b>`expected_end`</b>: str +* <b>`msg`</b>: Optional message to report on failure. - - - @@ -322,10 +480,11 @@ Included for symmetry with assertIsNone. - - - -#### `tf.test.TestCase.assertItemsEqual(expected_seq, actual_seq, msg=None)` {#TestCase.assertItemsEqual} +#### `tf.test.TestCase.assertItemsEqual(*args, **kwargs)` {#TestCase.assertItemsEqual} -An unordered sequence specific comparison. It asserts that -actual_seq and expected_seq have the same element counts. +An unordered sequence specific comparison. + +It asserts that actual_seq and expected_seq have the same element counts. Equivalent to:: self.assertEqual(Counter(iter(actual_seq)), @@ -338,6 +497,30 @@ Asserts that each element has the same count in both sequences. - [0, 1, 1] and [1, 0, 1] compare equal. - [0, 0, 1] and [0, 1] compare unequal. +##### Args: + + +* <b>`expected_seq`</b>: A sequence containing elements we are expecting. +* <b>`actual_seq`</b>: The sequence that we are testing. +* <b>`msg`</b>: The message to be printed if the test fails. + + +- - - + +#### `tf.test.TestCase.assertJsonEqual(first, second, msg=None)` {#TestCase.assertJsonEqual} + +Asserts that the JSON objects defined in two strings are equal. + +A summary of the differences will be included in the failure message +using assertSameStructure. + +##### Args: + + +* <b>`first`</b>: A string contining JSON to decode and compare to second. +* <b>`second`</b>: A string contining JSON to decode and compare to first. +* <b>`msg`</b>: Additional text to include in the failure message. + - - - @@ -407,6 +590,13 @@ if not. * <b>`msg`</b>: An optional string message to append to the failure message. +- - - + +#### `tf.test.TestCase.assertNoCommonElements(expected_seq, actual_seq, msg=None)` {#TestCase.assertNoCommonElements} + +Checks whether actual iterable and expected iterable are disjoint. + + - - - #### `tf.test.TestCase.assertNotAlmostEqual(first, second, places=None, msg=None, delta=None)` {#TestCase.assertNotAlmostEqual} @@ -437,6 +627,33 @@ as significant digits (measured from the most signficant digit). Objects that are equal automatically fail. +- - - + +#### `tf.test.TestCase.assertNotEmpty(container, msg=None)` {#TestCase.assertNotEmpty} + +Assert that an object has non-zero length. + +##### Args: + + +* <b>`container`</b>: Anything that implements the collections.Sized interface. +* <b>`msg`</b>: Optional message to report on failure. + + +- - - + +#### `tf.test.TestCase.assertNotEndsWith(actual, unexpected_end, msg=None)` {#TestCase.assertNotEndsWith} + +Assert that actual.endswith(unexpected_end) is False. + +##### Args: + + +* <b>`actual`</b>: str +* <b>`unexpected_end`</b>: str +* <b>`msg`</b>: Optional message to report on failure. + + - - - #### `tf.test.TestCase.assertNotEqual(first, second, msg=None)` {#TestCase.assertNotEqual} @@ -474,6 +691,20 @@ Included for symmetry with assertIsInstance. Fail the test if the text matches the regular expression. +- - - + +#### `tf.test.TestCase.assertNotStartsWith(actual, unexpected_start, msg=None)` {#TestCase.assertNotStartsWith} + +Assert that actual.startswith(unexpected_start) is False. + +##### Args: + + +* <b>`actual`</b>: str +* <b>`unexpected_start`</b>: str +* <b>`msg`</b>: Optional message to report on failure. + + - - - #### `tf.test.TestCase.assertProtoEquals(expected_message_maybe_ascii, message)` {#TestCase.assertProtoEquals} @@ -548,6 +779,38 @@ Asserts that the message in a raised exception matches a regexp. * <b>`kwargs`</b>: Extra kwargs. +- - - + +#### `tf.test.TestCase.assertRaisesWithLiteralMatch(expected_exception, expected_exception_message, callable_obj=None, *args, **kwargs)` {#TestCase.assertRaisesWithLiteralMatch} + +Asserts that the message in a raised exception equals the given string. + +Unlike assertRaisesRegexp, this method takes a literal string, not +a regular expression. + +with self.assertRaisesWithLiteralMatch(ExType, 'message'): + DoSomething() + +##### Args: + + +* <b>`expected_exception`</b>: Exception class expected to be raised. +* <b>`expected_exception_message`</b>: String message expected in the raised + exception. For a raise exception e, expected_exception_message must + equal str(e). +* <b>`callable_obj`</b>: Function to be called, or None to return a context. +* <b>`args`</b>: Extra args. +* <b>`kwargs`</b>: Extra kwargs. + +##### Returns: + + A context manager if callable_obj is None. Otherwise, None. + +##### Raises: + + self.failureException if callable_obj does not raise a macthing exception. + + - - - #### `tf.test.TestCase.assertRaisesWithPredicateMatch(exception_type, expected_err_re_or_predicate)` {#TestCase.assertRaisesWithPredicateMatch} @@ -572,6 +835,71 @@ predicate search. exception. +- - - + +#### `tf.test.TestCase.assertRaisesWithRegexpMatch(expected_exception, expected_regexp, callable_obj=None, *args, **kwargs)` {#TestCase.assertRaisesWithRegexpMatch} + +Asserts that the message in a raised exception matches the given regexp. + +This is just a wrapper around assertRaisesRegexp. Please use +assertRaisesRegexp instead of assertRaisesWithRegexpMatch. + +##### Args: + + +* <b>`expected_exception`</b>: Exception class expected to be raised. +* <b>`expected_regexp`</b>: Regexp (re pattern object or string) expected to be + found in error message. +* <b>`callable_obj`</b>: Function to be called, or None to return a context. +* <b>`args`</b>: Extra args. +* <b>`kwargs`</b>: Extra keyword args. + +##### Returns: + + A context manager if callable_obj is None. Otherwise, None. + +##### Raises: + + self.failureException if callable_obj does not raise a macthing exception. + + +- - - + +#### `tf.test.TestCase.assertRegexMatch(actual_str, regexes, message=None)` {#TestCase.assertRegexMatch} + +Asserts that at least one regex in regexes matches str. + + If possible you should use assertRegexpMatches, which is a simpler + version of this method. assertRegexpMatches takes a single regular + expression (a string or re compiled object) instead of a list. + + Notes: + 1. This function uses substring matching, i.e. the matching + succeeds if *any* substring of the error message matches *any* + regex in the list. This is more convenient for the user than + full-string matching. + + 2. If regexes is the empty list, the matching will always fail. + + 3. Use regexes=[''] for a regex that will always pass. + + 4. '.' matches any single character *except* the newline. To + match any character, use '(.| +)'. + + 5. '^' matches the beginning of each line, not just the beginning + of the string. Similarly, '$' matches the end of each line. + + 6. An exception will be thrown if regexes contains an invalid + regex. + + Args: + actual_str: The string we try to match with the items in regexes. + regexes: The regular expressions we want to match against str. + See "Notes" above for detailed notes on how this is interpreted. + message: The message to be printed if the test fails. + + - - - #### `tf.test.TestCase.assertRegexpMatches(text, expected_regexp, msg=None)` {#TestCase.assertRegexpMatches} @@ -579,6 +907,79 @@ predicate search. Fail the test unless the text matches the regular expression. +- - - + +#### `tf.test.TestCase.assertSameElements(expected_seq, actual_seq, msg=None)` {#TestCase.assertSameElements} + +Assert that two sequences have the same elements (in any order). + +This method, unlike assertItemsEqual, doesn't care about any +duplicates in the expected and actual sequences. + + >> assertSameElements([1, 1, 1, 0, 0, 0], [0, 1]) + # Doesn't raise an AssertionError + +If possible, you should use assertItemsEqual instead of +assertSameElements. + +##### Args: + + +* <b>`expected_seq`</b>: A sequence containing elements we are expecting. +* <b>`actual_seq`</b>: The sequence that we are testing. +* <b>`msg`</b>: The message to be printed if the test fails. + + +- - - + +#### `tf.test.TestCase.assertSameStructure(a, b, aname='a', bname='b', msg=None)` {#TestCase.assertSameStructure} + +Asserts that two values contain the same structural content. + +The two arguments should be data trees consisting of trees of dicts and +lists. They will be deeply compared by walking into the contents of dicts +and lists; other items will be compared using the == operator. +If the two structures differ in content, the failure message will indicate +the location within the structures where the first difference is found. +This may be helpful when comparing large structures. + +##### Args: + + +* <b>`a`</b>: The first structure to compare. +* <b>`b`</b>: The second structure to compare. +* <b>`aname`</b>: Variable name to use for the first structure in assertion messages. +* <b>`bname`</b>: Variable name to use for the second structure. +* <b>`msg`</b>: Additional text to include in the failure message. + + +- - - + +#### `tf.test.TestCase.assertSequenceAlmostEqual(expected_seq, actual_seq, places=None, msg=None, delta=None)` {#TestCase.assertSequenceAlmostEqual} + +An approximate equality assertion for ordered sequences. + +Fail if the two sequences are unequal as determined by their value +differences rounded to the given number of decimal places (default 7) and +comparing to zero, or by comparing that the difference between each value +in the two sequences is more than the given delta. + +Note that decimal places (from zero) are usually not the same as significant +digits (measured from the most signficant digit). + +If the two sequences compare equal then they will automatically compare +almost equal. + +##### Args: + + +* <b>`expected_seq`</b>: A sequence containing elements we are expecting. +* <b>`actual_seq`</b>: The sequence that we are testing. +* <b>`places`</b>: The number of decimal places to compare. +* <b>`msg`</b>: The message to be printed if the test fails. +* <b>`delta`</b>: The OK difference between compared values. + + - - - #### `tf.test.TestCase.assertSequenceEqual(seq1, seq2, msg=None, seq_type=None)` {#TestCase.assertSequenceEqual} @@ -599,6 +1000,26 @@ which can be indexed, has a length, and has an equality operator. differences. +- - - + +#### `tf.test.TestCase.assertSequenceStartsWith(prefix, whole, msg=None)` {#TestCase.assertSequenceStartsWith} + +An equality assertion for the beginning of ordered sequences. + +If prefix is an empty sequence, it will raise an error unless whole is also +an empty sequence. + +If prefix is not a sequence, it will raise an error if the first element of +whole does not match. + +##### Args: + + +* <b>`prefix`</b>: A sequence expected at the beginning of the whole parameter. +* <b>`whole`</b>: The sequence in which to look for prefix. +* <b>`msg`</b>: Optional message to report on failure. + + - - - #### `tf.test.TestCase.assertSetEqual(set1, set2, msg=None)` {#TestCase.assertSetEqual} @@ -650,6 +1071,51 @@ Assert that actual.startswith(expected_start) is True. * <b>`msg`</b>: Optional message to report on failure. +- - - + +#### `tf.test.TestCase.assertTotallyOrdered(*groups, **kwargs)` {#TestCase.assertTotallyOrdered} + +Asserts that total ordering has been implemented correctly. + +For example, say you have a class A that compares only on its attribute x. +Comparators other than __lt__ are omitted for brevity. + +class A(object): + def __init__(self, x, y): + self.x = x + self.y = y + + def __hash__(self): + return hash(self.x) + + def __lt__(self, other): + try: + return self.x < other.x + except AttributeError: + return NotImplemented + +assertTotallyOrdered will check that instances can be ordered correctly. +For example, + +self.assertTotallyOrdered( + [None], # None should come before everything else. + [1], # Integers sort earlier. + [A(1, 'a')], + [A(2, 'b')], # 2 is after 1. + [A(3, 'c'), A(3, 'd')], # The second argument is irrelevant. + [A(4, 'z')], + ['foo']) # Strings sort last. + +##### Args: + + +* <b>`*groups`</b>: A list of groups of elements. Each group of elements is a list + of objects that are equal. The elements in each group must be less than + the elements in the group after it. For example, these groups are + totally ordered: [None], [1], [2, 2], [3]. +* <b>`**kwargs`</b>: optional msg keyword argument can be passed. + + - - - #### `tf.test.TestCase.assertTrue(expr, msg=None)` {#TestCase.assertTrue} @@ -672,6 +1138,13 @@ A tuple-specific equality assertion. differences. +- - - + +#### `tf.test.TestCase.assertUrlEqual(a, b, msg=None)` {#TestCase.assertUrlEqual} + +Asserts that urls are equal, ignoring ordering of query params. + + - - - #### `tf.test.TestCase.assert_(expr, msg=None)` {#TestCase.assert_} @@ -733,9 +1206,9 @@ tearDown. - - - -#### `tf.test.TestCase.fail(msg=None)` {#TestCase.fail} +#### `tf.test.TestCase.fail(msg=None, prefix=None)` {#TestCase.fail} -Fail immediately, with the given message. +Fail immediately with the given message, optionally prefixed. - - - @@ -787,6 +1260,13 @@ Fail immediately, with the given message. +- - - + +#### `tf.test.TestCase.getRecordedProperties()` {#TestCase.getRecordedProperties} + +Return any properties that the user has recorded. + + - - - #### `tf.test.TestCase.get_temp_dir()` {#TestCase.get_temp_dir} @@ -809,6 +1289,20 @@ pollute each others environment. +- - - + +#### `tf.test.TestCase.recordProperty(property_name, property_value)` {#TestCase.recordProperty} + +Record an arbitrary property for later use. + +##### Args: + + +* <b>`property_name`</b>: str, name of property to record; must be a valid XML + attribute name +* <b>`property_value`</b>: value of property; must be valid XML attribute value + + - - - #### `tf.test.TestCase.run(result=None)` {#TestCase.run} @@ -834,11 +1328,18 @@ Hook method for setting up class fixture before running tests in the class. #### `tf.test.TestCase.shortDescription()` {#TestCase.shortDescription} -Returns a one-line description of the test, or None if no -description has been provided. +Format both the test method name and the first line of its docstring. -The default implementation of this method returns the first line of -the specified test method's docstring. +If no docstring is given, only returns the method name. + +This method overrides unittest.TestCase.shortDescription(), which +only returns the first line of the docstring, obscuring the name +of the test upon failure. + +##### Returns: + + +* <b>`desc`</b>: A short description of a test method. - - - diff --git a/tensorflow/g3doc/api_docs/python/tf_debug.md b/tensorflow/g3doc/api_docs/python/tf_debug.md index 01d52d2c691..28fc9ec502a 100644 --- a/tensorflow/g3doc/api_docs/python/tf_debug.md +++ b/tensorflow/g3doc/api_docs/python/tf_debug.md @@ -1069,34 +1069,16 @@ Constructor of DumpingDebugWrapperSession. `session_root`. The subdirectories' names has the following pattern: run_<epoch_time_stamp>_<uuid> E.g., run_1480734393835964_ad4c953a85444900ae79fc1b652fb324 -* <b>`watch_fn`</b>: (`Callable`) A Callable of the following signature: - ``` - def watch_fn(fetches, feeds): - # Args: - # fetches: the fetches to the `Session.run()` call. - # feeds: the feeds to the `Session.run()` call. - # - # Returns: (node_name_regex_whitelist, op_type_regex_whitelist) - # debug_ops: (str or list of str) Debug op(s) to be used by the - # debugger in this run() call. - # node_name_regex_whitelist: Regular-expression whitelist for node - # name. Same as the corresponding arg to `debug_util.watch_graph`. - # op_type_regex_whiteslit: Regular-expression whitelist for op type. - # Same as the corresponding arg to `debug_util.watch_graph`. - # - # Both or either can be None. If both are set, the two whitelists - # will operate in a logical AND relation. This is consistent with - # `debug_utils.watch_graph()`. - ``` +* <b>`watch_fn`</b>: (`Callable`) A Callable that can be used to define per-run + debug ops and watched tensors. See the doc of + `NonInteractiveDebugWrapperSession.__init__()` for details. * <b>`log_usage`</b>: (`bool`) whether the usage of this class is to be logged. ##### Raises: * <b>`ValueError`</b>: If `session_root` is an existing and non-empty directory or - if - `session_root` is a file. -* <b>`TypeError`</b>: If a non-None `watch_fn` is specified and it is not callable. + if `session_root` is a file. - - - diff --git a/tensorflow/g3doc/resources/versions.md b/tensorflow/g3doc/resources/versions.md index 34a8e6bc308..a8e211b5c8b 100644 --- a/tensorflow/g3doc/resources/versions.md +++ b/tensorflow/g3doc/resources/versions.md @@ -2,10 +2,9 @@ ## Semantic Versioning 2.0 -Once we reach version 1.0, TensorFlow will follow Semantic Versioning 2.0 -([semver](http://semver.org)) for its public API. Each release version of -TensorFlow has the form `MAJOR.MINOR.PATCH`. Changes to the each number have -the following meaning: +TensorFlow follows Semantic Versioning 2.0 ([semver](http://semver.org)) for its +public API. Each release version of TensorFlow has the form `MAJOR.MINOR.PATCH`. +Changes to the each number have the following meaning: * **MAJOR**: Backwards incompatible changes. Code and data that worked with a previous major release will not necessarily work with a new release. @@ -20,23 +19,23 @@ the following meaning: * **PATCH**: Backwards compatible bug fixes. -Before 1.0, semver allows backwards incompatible changes at any time. However, -to support users now, we will use the format `0.MAJOR.MINOR` (shifted one step -to the right). Thus 0.5.0 to 0.6.0 may be backwards incompatible, but 0.6.0 to -0.6.1 will include only backwards compatible features and bug fixes. - -At some point (especially as we approach 1.0) we will likely use prerelease -versions such as X.Y.Z-alpha.1, but we do not yet have specific plans (beyond -the restrictions of semver). - - ## Public API -Only the C, C++, and Python public APIs of TensorFlow are backwards compatible -across minor and patch versions. The public APIs consist of +Only the public APIs of TensorFlow are backwards compatible across minor and +patch versions. The public APIs consist of -* The documented [Python](../api_docs/python), [C++](../api_docs/cc) and - the [C](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/c_api.h) APIs. +* The documented public [Python](../api_docs/python) API, excluding `tf.contrib`. + This includes all public functions and classes (with names not starting with + `_`) in the tensorflow module and its submodules. Note that the code in + the `examples/` to `tools/` directories is not reachable through the + tensorflow Python module and is thus not covered by the compatibility + guarantee. + + If a symbol is available through the tensorflow Python module or its + submodules, but is not documented, then it is _not_ considered part of the + public API. + +* The [C API](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/c_api.h). * The following protocol buffer files: [`attr_value`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/attr_value.proto), @@ -50,23 +49,18 @@ across minor and patch versions. The public APIs consist of [`tensor_shape`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/tensor_shape.proto), and [`types`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/types.proto). -The public C++ API is exposed through the header files in -[`tensorflow/core/public`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/public). +## Other Languages -The public Python API is unfortunately **not** everything available through the -tensorflow python module and its submodules, since we do not yet use `__all__` -everywhere ([#421](https://github.com/tensorflow/tensorflow/issues/421)). -Please refer to the documentation to determine whether a given Python feature -is part of the public API. For now, the protocol buffers are defined in -[`tensorflow/core/framework/*.proto`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/framework) -([#484](https://github.com/tensorflow/tensorflow/issues/484)). +In addition to Python and C, TensorFlow also provides APIs for: -> The [Java](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java) -> ([#5](https://github.com/tensorflow/tensorflow/issues/5)) and -> [Go](https://godoc.org/github.com/tensorflow/tensorflow/tensorflow/go) APIs -> are experimental and are **not** covered by the versioning scheme at this time. -> They are not guaranteed to backward compatible between releases. +- [C++](../api_docs/cc) (exposed through header files in +[`tensorflow/cc`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/cc). +- [Java](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java) +([#5](https://github.com/tensorflow/tensorflow/issues/5)), and +- [Go](https://godoc.org/github.com/tensorflow/tensorflow/tensorflow/go) +However, these three are **not** covered by the versioning scheme at this time +and can be changed in backward incompatible ways between releases. ## Details That Are Not Public @@ -76,7 +70,7 @@ fixes require it: * **Details of composite ops:** Many public functions in Python expand to several primitive ops in the graph, and these details will be part of any - graphs saved to disk as GraphDefs. These details are allowed to change for + graphs saved to disk as `GraphDef`s. These details are allowed to change for minor releases. In particular, regressions tests that check for exact matching between graphs are likely to break across minor releases, even though the behavior of the graph should be unchanged and existing checkpoints will @@ -98,7 +92,7 @@ fixes require it: such intended changes will be documented. -## Compatibility for Graphs and Checkpoints {#graphs} +## Compatibility for Graphs and Checkpoints Many users of TensorFlow will be saving graphs and trained models to disk for later evaluation or more training, often changing versions of TensorFlow in the @@ -145,11 +139,3 @@ provide tools for automatically converting graphs to a newer supported For developer-level details about `GraphDef` versioning, including how to evolve the versions to account for changes, see [TensorFlow Data Versioning](data_versions.md). - - -## C++ ABI Compatibility - -Only patch releases will be binary compatible at the C++ level. That is, minor -releases are backwards compatible in terms of behavior but may require a -recompile for downstream C++ code. As always, backwards compatibility is only -provided for the public C++ API. diff --git a/tensorflow/g3doc/tutorials/index.md b/tensorflow/g3doc/tutorials/index.md index edc1f6b5a44..505f1b42706 100644 --- a/tensorflow/g3doc/tutorials/index.md +++ b/tensorflow/g3doc/tutorials/index.md @@ -8,37 +8,33 @@ digit images. ### MNIST For ML Beginners -If you're new to machine learning, we recommend starting here. You'll learn +If you're new to machine learning, we recommend starting here. You'll learn about a classic problem, handwritten digit classification (MNIST), and get a gentle introduction to multiclass classification. [View Tutorial](../tutorials/mnist/beginners/index.md) - ### Deep MNIST for Experts If you're already familiar with other deep learning software packages, and are -already familiar with MNIST, this tutorial will give you a very brief primer -on TensorFlow. +already familiar with MNIST, this tutorial will give you a very brief primer on +TensorFlow. [View Tutorial](../tutorials/mnist/pros/index.md) ### TensorFlow Mechanics 101 This is a technical tutorial, where we walk you through the details of using -TensorFlow infrastructure to train models at scale. We use MNIST as the -example. +TensorFlow infrastructure to train models at scale. We use MNIST as the example. [View Tutorial](../tutorials/mnist/tf/index.md) - ## Easy ML with tf.contrib.learn ### tf.contrib.learn Quickstart A quick introduction to tf.contrib.learn, a high-level API for TensorFlow. -Build, train, and evaluate a neural network with just a few lines of -code. +Build, train, and evaluate a neural network with just a few lines of code. [View Tutorial](../tutorials/tflearn/index.md) @@ -73,19 +69,27 @@ Monitor API to audit the in-progress training of a neural network. ### Building Input Functions with tf.contrib.learn This tutorial introduces you to creating input functions in tf.contrib.learn, -and walks you through implementing an `input_fn` to train a neural network -for predicting median house values. +and walks you through implementing an `input_fn` to train a neural network for +predicting median house values. [View Tutorial](../tutorials/input_fn/index.md) ### Creating Estimators in tf.contrib.learn -This tutorial covers how to create your own `Estimator` using the building blocks -provided in tf.contrib.learn. You'll build a model to predict the ages of abalones -based on their physical measurements. +This tutorial covers how to create your own `Estimator` using the building +blocks provided in tf.contrib.learn. You'll build a model to predict the ages of +abalones based on their physical measurements. [View Tutorial](../tutorials/estimators/index.md) +### A Guide to TF Layers: Building a Convolutional Neural Network + +This tutorial introduces you to building neural networks in TensorFlow using the +`tf.layers` module. You'll build a convolutional neural network `Estimator` to +recognize the handwritten digits in the MNIST data set. + +[View Tutorial](../tutorials/layers/index.md) + ## TensorFlow Serving ### TensorFlow Serving @@ -95,7 +99,6 @@ serving machine learning models, designed for production environments. [View Tutorial](../tutorials/tfserve/index.md) - ## Image Processing ### Convolutional Neural Networks @@ -109,8 +112,8 @@ representations of visual content. ### Image Recognition -How to run object recognition using a convolutional neural network -trained on ImageNet Challenge data and label set. +How to run object recognition using a convolutional neural network trained on +ImageNet Challenge data and label set. [View Tutorial](../tutorials/image_recognition/index.md) @@ -120,8 +123,8 @@ Building on the Inception recognition model, we will release a TensorFlow version of the [Deep Dream](https://github.com/google/deepdream) neural network visual hallucination software. -[View Tutorial](https://nbviewer.jupyter.org/github/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/deepdream/deepdream.ipynb) - +[View +Tutorial](https://nbviewer.jupyter.org/github/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/deepdream/deepdream.ipynb) ## Language and Sequence Processing @@ -138,14 +141,14 @@ embeddings). ### Recurrent Neural Networks An introduction to RNNs, wherein we train an LSTM network to predict the next -word in an English sentence. (A task sometimes called language modeling.) +word in an English sentence. (A task sometimes called language modeling.) [View Tutorial](../tutorials/recurrent/index.md) ### Sequence-to-Sequence Models A follow on to the RNN tutorial, where we assemble a sequence-to-sequence model -for machine translation. You will learn to build your own English-to-French +for machine translation. You will learn to build your own English-to-French translator, entirely machine learned, end-to-end. [View Tutorial](../tutorials/seq2seq/index.md) @@ -157,19 +160,18 @@ TensorFlow. [View Tutorial](../tutorials/syntaxnet/index.md) - ## Non-ML Applications ### Mandelbrot Set TensorFlow can be used for computation that has nothing to do with machine -learning. Here's a naive implementation of Mandelbrot set visualization. +learning. Here's a naive implementation of Mandelbrot set visualization. [View Tutorial](../tutorials/mandelbrot/index.md) ### Partial Differential Equations -As another example of non-machine learning computation, we offer an example of -a naive PDE simulation of raindrops landing on a pond. +As another example of non-machine learning computation, we offer an example of a +naive PDE simulation of raindrops landing on a pond. [View Tutorial](../tutorials/pdes/index.md) diff --git a/tensorflow/g3doc/tutorials/layers/index.md b/tensorflow/g3doc/tutorials/layers/index.md index 2d0071a31ac..387b6e0dfa9 100644 --- a/tensorflow/g3doc/tutorials/layers/index.md +++ b/tensorflow/g3doc/tutorials/layers/index.md @@ -45,7 +45,7 @@ evaluate the convolutional neural network. The complete, final code can be here](https://www.tensorflow.org/code/tensorflow/examples/tutorials/layers/cnn_mnist.py). <p class="note"><b>NOTE:</b> Before proceeding, make sure you've -<a href="https://www.tensorflow.org/get_started/os_setup">installed the latest +<a href="../../get_started/os_setup.md">installed the latest version of TensorFlow</a> on your machine.</p> ## Intro to Convolutional Neural Networks @@ -87,9 +87,9 @@ is equal to 1). We can interpret the softmax values for a given image as relative measurements of how likely it is that the image falls into each target class. -NOTE: For a more comprehensive walkthrough of CNN architecture, see Stanford -University's [Convolutional Neural Networks for Visual Recognition course -materials](http://cs231n.github.io/convolutional-networks/). +<p class="note"><b>NOTE:</b> For a more comprehensive walkthrough of CNN +architecture, see Stanford University's <a href="http://cs231n.github.io/convolutional-networks/"> +Convolutional Neural Networks for Visual Recognition course materials</a>.</p> ## Building the CNN MNIST Classifier {#building-cnn-classifier} @@ -506,7 +506,7 @@ if mode == learn.ModeKeys.TRAIN: <p class="note"><b>NOTE:</b> For a more in-depth look at configuring training ops for Estimator model functions, see <a href="../estimators/index.md#defining_the_training_op_for_the_model">"Defining the training op for the model"</a> in the -<a href="../estimators/index.md">"Creating Estimations in tf.contrib.learn"]</a> tutorial.</p> +<a href="../estimators/index.md">"Creating Estimations in tf.contrib.learn"</a> tutorial.</p> ### Generate Predictions {#generate-predictions} @@ -541,7 +541,7 @@ using [`tf.nn.softmax()`](../../api_docs/python/nn.md#softmax): tf.nn.softmax(logits, name="softmax_tensor") ``` -<p class="note"><b>NOTE:</b We use the `name` argument to explicitly name this operation `softmax_tensor`, so we can reference it later. (We'll set up logging for the softmax values in <a href="#set-up-a-logging-hook">Set Up a Logging Hook</a>.)</p> +<p class="note"><b>NOTE:</b> We use the `name` argument to explicitly name this operation `softmax_tensor`, so we can reference it later. (We'll set up logging for the softmax values in <a href="#set-up-a-logging-hook">Set Up a Logging Hook</a>.)</p> We compile our predictions in a dict as follows: diff --git a/tensorflow/g3doc/tutorials/leftnav_files b/tensorflow/g3doc/tutorials/leftnav_files index a75e62f5e36..77ec0a0f39f 100644 --- a/tensorflow/g3doc/tutorials/leftnav_files +++ b/tensorflow/g3doc/tutorials/leftnav_files @@ -10,6 +10,7 @@ wide_and_deep/index.md monitors/index.md input_fn/index.md estimators/index.md +layers/index.md ### TensorFlow Serving tfserve/index.md ### Image Processing diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 3fe7cc1bb7f..626447d7833 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -45,6 +45,7 @@ py_library( ":check_ops", ":client", ":client_testlib", + ":cloud_ops", ":confusion_matrix", ":control_flow_ops", ":errors", @@ -75,6 +76,7 @@ py_library( ":ops", ":test_ops", # TODO: Break testing code out into separate rule. ":util", + ":weights_broadcast_ops", "//third_party/py/numpy", "//tensorflow/python/ops/losses", "//tensorflow/python/saved_model", @@ -120,6 +122,38 @@ py_library( deps = [":platform_benchmark"], ) +py_library( + name = "cloud_ops", + srcs = [ + "ops/cloud/__init__.py", + "ops/cloud/bigquery_reader_ops.py", + "ops/cloud/cloud.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":cloud_ops_gen", + ":framework_for_generated_wrappers", + ], +) + +tf_py_test( + name = "bigquery_reader_ops_test", + size = "small", + srcs = ["ops/cloud/bigquery_reader_ops_test.py"], + additional_deps = [ + ":array_ops", + ":client_testlib", + ":cloud_ops", + ":data_flow_ops", + ":io_ops", + ":parsing_ops", + ":util", + "//tensorflow/core/kernels/cloud:bigquery_reader_ops", + "//tensorflow/core:cloud_ops_op_lib", + ], + tags = ["manual"], +) + tf_py_test( name = "resource_loader_test", size = "small", @@ -791,6 +825,11 @@ tf_gen_op_wrapper_private_py( visibility = ["//learning/brain/python/ops:__pkg__"], ) +tf_gen_op_wrapper_private_py( + name = "cloud_ops_gen", + require_shape_functions = True, +) + tf_gen_op_wrapper_private_py( name = "control_flow_ops_gen", require_shape_functions = True, @@ -1515,6 +1554,21 @@ py_library( ], ) +py_library( + name = "weights_broadcast_ops", + srcs = [ + "ops/weights_broadcast_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":array_ops", + ":control_flow_ops", + ":framework", + ":math_ops", + ":sets", + ], +) + py_library( name = "metrics", srcs = [ @@ -1537,6 +1591,7 @@ py_library( ":util", ":variable_scope", ":variables", + ":weights_broadcast_ops", ], ) diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index e2ed3f3f8e1..39ec56ca327 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -261,13 +261,13 @@ _allowed_symbols.extend([ # Remove all extra symbols that don't have a docstring or are not explicitly # referenced in the whitelist. -remove_undocumented(__name__, _allowed_symbols, - [framework_lib, array_ops, client_lib, check_ops, - compat, constant_op, control_flow_ops, confusion_matrix_m, - functional_ops, histogram_ops, io_ops, losses, math_ops, - metrics, nn, resource_loader, sets, script_ops, - session_ops, sparse_ops, state_ops, string_ops, summary, - tensor_array_ops, train, layers]) +remove_undocumented(__name__, _allowed_symbols, [ + framework_lib, array_ops, check_ops, client_lib, compat, constant_op, + control_flow_ops, confusion_matrix_m, functional_ops, histogram_ops, io_ops, + losses, math_ops, metrics, nn, resource_loader, sets, script_ops, + session_ops, sparse_ops, state_ops, string_ops, summary, tensor_array_ops, + train, layers +]) # Special dunders that we choose to export: _exported_dunders = set([ diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index 4b968c150d9..0aa5ce0a60b 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -193,6 +193,13 @@ py_library( ], ) +py_library( + name = "grpc_wrapper", + srcs = ["wrappers/grpc_wrapper.py"], + srcs_version = "PY2AND3", + deps = [":framework"], +) + py_library( name = "local_cli_wrapper", srcs = ["wrappers/local_cli_wrapper.py"], @@ -551,6 +558,7 @@ py_test( deps = [ ":dumping_wrapper", ":hooks", + ":stepper", "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:framework_for_generated_wrappers", diff --git a/tensorflow/python/debug/wrappers/dumping_wrapper.py b/tensorflow/python/debug/wrappers/dumping_wrapper.py index d0e2c3ea20e..dc3f8468c61 100644 --- a/tensorflow/python/debug/wrappers/dumping_wrapper.py +++ b/tensorflow/python/debug/wrappers/dumping_wrapper.py @@ -28,7 +28,7 @@ from tensorflow.python.debug.wrappers import framework from tensorflow.python.platform import gfile -class DumpingDebugWrapperSession(framework.BaseDebugWrapperSession): +class DumpingDebugWrapperSession(framework.NonInteractiveDebugWrapperSession): """Debug Session wrapper that dumps debug data to filesystem.""" def __init__(self, sess, session_root, watch_fn=None, log_usage=True): @@ -45,44 +45,21 @@ class DumpingDebugWrapperSession(framework.BaseDebugWrapperSession): `session_root`. The subdirectories' names has the following pattern: run_<epoch_time_stamp>_<uuid> E.g., run_1480734393835964_ad4c953a85444900ae79fc1b652fb324 - watch_fn: (`Callable`) A Callable of the following signature: - ``` - def watch_fn(fetches, feeds): - # Args: - # fetches: the fetches to the `Session.run()` call. - # feeds: the feeds to the `Session.run()` call. - # - # Returns: (node_name_regex_whitelist, op_type_regex_whitelist) - # debug_ops: (str or list of str) Debug op(s) to be used by the - # debugger in this run() call. - # node_name_regex_whitelist: Regular-expression whitelist for node - # name. Same as the corresponding arg to `debug_util.watch_graph`. - # op_type_regex_whiteslit: Regular-expression whitelist for op type. - # Same as the corresponding arg to `debug_util.watch_graph`. - # - # Both or either can be None. If both are set, the two whitelists - # will operate in a logical AND relation. This is consistent with - # `debug_utils.watch_graph()`. - ``` + watch_fn: (`Callable`) A Callable that can be used to define per-run + debug ops and watched tensors. See the doc of + `NonInteractiveDebugWrapperSession.__init__()` for details. log_usage: (`bool`) whether the usage of this class is to be logged. Raises: ValueError: If `session_root` is an existing and non-empty directory or - if - `session_root` is a file. - TypeError: If a non-None `watch_fn` is specified and it is not callable. + if `session_root` is a file. """ if log_usage: pass # No logging for open-source. - framework.BaseDebugWrapperSession.__init__(self, sess) - - self._watch_fn = None - if watch_fn is not None: - if not callable(watch_fn): - raise TypeError("watch_fn is not callable") - self._watch_fn = watch_fn + framework.NonInteractiveDebugWrapperSession.__init__( + self, sess, watch_fn=watch_fn) if gfile.Exists(session_root): if not gfile.IsDirectory(session_root): @@ -94,47 +71,21 @@ class DumpingDebugWrapperSession(framework.BaseDebugWrapperSession): session_root) self._session_root = session_root - def on_session_init(self, request): - """See doc of BaseDebugWrapperSession.on_run_start.""" + def _prepare_run_debug_urls(self, fetches, feed_dict): + """Implementation of abstrat method in superclass. - return framework.OnSessionInitResponse( - framework.OnSessionInitAction.PROCEED) - - def on_run_start(self, request): - """See doc of BaseDebugWrapperSession.on_run_start.""" - - (debug_urls, debug_ops, node_name_regex_whitelist, - op_type_regex_whitelist) = self._prepare_run_watch_config( - request.fetches, request.feed_dict) - - return framework.OnRunStartResponse( - framework.OnRunStartAction.DEBUG_RUN, - debug_urls, - debug_ops=debug_ops, - node_name_regex_whitelist=node_name_regex_whitelist, - op_type_regex_whitelist=op_type_regex_whitelist) - - def _prepare_run_watch_config(self, fetches, feed_dict): - """Get the debug_urls, and node/op whitelists for the current run() call. - - Prepares a directory with a fixed naming pattern. Saves Event proto files - of names `_tfdbg_run_fetches_info` and `_tfdbg_run_feed_keys_info` in the - directory to save information about the `fetches` and `feed_dict.keys()` - used in this `run()` call, respectively. + See doc of `NonInteractiveDebugWrapperSession.__prepare_run_debug_urls()` + for details. This implentation creates a run-specific subdirectory under + self._session_root and stores information regarding run `fetches` and + `feed_dict.keys()` in the subdirectory. Args: - fetches: Same as the `fetches` argument to `Session.run()`. - feed_dict: Same as the `feed_dict argument` to `Session.run()`. + fetches: Same as the `fetches` argument to `Session.run()` + feed_dict: Same as the `feed_dict` argument to `Session.run()` Returns: - debug_urls: (str or list of str) Debug URLs for the current run() call. - Currently, the list consists of only one URL that is a file:// URL. - debug_ops: (str or list of str) Debug op(s) to be used by the - debugger. - node_name_regex_whitelist: (str or regex) Regular-expression whitelist for - node name. Same as the same-name argument to debug_utils.watch_graph. - op_type_regex_whitelist: (str or regex) Regular-expression whitelist for - op type. Same as the same-name argument to debug_utils.watch_graph. + debug_urls: (`str` or `list` of `str`) file:// debug URLs to be used in + this `Session.run()` call. """ # Add a UUID to accommodate the possibility of concurrent run() calls. @@ -160,24 +111,4 @@ class DumpingDebugWrapperSession(framework.BaseDebugWrapperSession): with gfile.Open(os.path.join(feed_keys_path), "wb") as f: f.write(feed_keys_event.SerializeToString()) - debug_ops, node_name_regex_whitelist, op_type_regex_whitelist = ( - "DebugIdentity", None, None) - if self._watch_fn is not None: - debug_ops, node_name_regex_whitelist, op_type_regex_whitelist = ( - self._watch_fn(fetches, feed_dict)) - - return (["file://" + run_dir], debug_ops, node_name_regex_whitelist, - op_type_regex_whitelist) - - def on_run_end(self, request): - """See doc of BaseDebugWrapperSession.on_run_end.""" - - return framework.OnRunEndResponse() - - def invoke_node_stepper(self, - node_stepper, - restore_variable_values_on_exit=True): - """See doc of BaseDebugWrapperSession.invoke_node_stepper.""" - - return NotImplementedError( - "DumpingDebugWrapperSession does not support node-stepper mode.") + return ["file://" + run_dir] diff --git a/tensorflow/python/debug/wrappers/dumping_wrapper_test.py b/tensorflow/python/debug/wrappers/dumping_wrapper_test.py index 568c55e8ef4..2be1077e28e 100644 --- a/tensorflow/python/debug/wrappers/dumping_wrapper_test.py +++ b/tensorflow/python/debug/wrappers/dumping_wrapper_test.py @@ -24,6 +24,7 @@ import tempfile from tensorflow.python.client import session from tensorflow.python.debug import debug_data +from tensorflow.python.debug import stepper from tensorflow.python.debug.wrappers import dumping_wrapper from tensorflow.python.debug.wrappers import hooks from tensorflow.python.framework import constant_op @@ -259,6 +260,16 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase): self.assertEqual(repr(self.inc_v), dump.run_fetches_info) self.assertEqual(repr(None), dump.run_feed_keys_info) + def testCallingInvokeNodeStepperOnDumpingWrapperRaisesException(self): + sess = dumping_wrapper.DumpingDebugWrapperSession( + self.sess, session_root=self.session_root, log_usage=False) + node_stepper = stepper.NodeStepper(self.sess, self.inc_v) + with self.assertRaisesRegexp( + NotImplementedError, + r"NonInteractiveDebugWrapperSession does not support node-stepper " + r"mode\."): + sess.invoke_node_stepper(node_stepper) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py index cbc1fa26032..145008e9024 100644 --- a/tensorflow/python/debug/wrappers/framework.py +++ b/tensorflow/python/debug/wrappers/framework.py @@ -325,10 +325,17 @@ class BaseDebugWrapperSession(session.SessionInterface): Raises: ValueError: On invalid `OnSessionInitAction` value. + NotImplementedError: If a non-DirectSession sess object is received. """ _check_type(sess, session.BaseSession) + # TODO(cais): Remove this check once tfdbg is integrated with GrpcSession. + if sess.sess_str: + raise NotImplementedError( + "Non-DirectSession support is not available from TensorFlow " + "Debugger yet (sess_str=%s)" % sess.sess_str) + # The session being wrapped. self._sess = sess @@ -564,3 +571,120 @@ class BaseDebugWrapperSession(session.SessionInterface): The same return values as the `Session.run()` call on the same fetches as the NodeStepper. """ + + +class NonInteractiveDebugWrapperSession(BaseDebugWrapperSession): + """Base class for non-interactive (i.e., non-CLI) debug wrapper sessions.""" + + def __init__(self, sess, watch_fn=None): + """Constructor of DumpingDebugWrapperSession. + + Args: + sess: The TensorFlow `Session` object being wrapped. + watch_fn: (`Callable`) A Callable of the following signature: + ``` + def watch_fn(fetches, feeds): + # Args: + # fetches: the fetches to the `Session.run()` call. + # feeds: the feeds to the `Session.run()` call. + # + # Returns: (node_name_regex_whitelist, op_type_regex_whitelist) + # debug_ops: (str or list of str) Debug op(s) to be used by the + # debugger in this run() call. + # node_name_regex_whitelist: Regular-expression whitelist for node + # name. Same as the corresponding arg to `debug_util.watch_graph`. + # op_type_regex_whiteslit: Regular-expression whitelist for op type. + # Same as the corresponding arg to `debug_util.watch_graph`. + # + # Both or either can be None. If both are set, the two whitelists + # will operate in a logical AND relation. This is consistent with + # `debug_utils.watch_graph()`. + ``` + + Raises: + TypeError: If a non-None `watch_fn` is specified and it is not callable. + """ + + BaseDebugWrapperSession.__init__(self, sess) + + self._watch_fn = None + if watch_fn is not None: + if not callable(watch_fn): + raise TypeError("watch_fn is not callable") + self._watch_fn = watch_fn + + def on_session_init(self, request): + """See doc of BaseDebugWrapperSession.on_run_start.""" + + return OnSessionInitResponse(OnSessionInitAction.PROCEED) + + @abc.abstractmethod + def _prepare_run_debug_urls(self, fetches, feed_dict): + """Abstract method to be implemented by concrete subclasses. + + This method prepares the run-specific debug URL(s). + + Args: + fetches: Same as the `fetches` argument to `Session.run()` + feed_dict: Same as the `feed_dict` argument to `Session.run()` + + Returns: + debug_urls: (`str` or `list` of `str`) Debug URLs to be used in + this `Session.run()` call. + """ + + def on_run_start(self, request): + """See doc of BaseDebugWrapperSession.on_run_start.""" + + (debug_urls, debug_ops, node_name_regex_whitelist, + op_type_regex_whitelist) = self._prepare_run_watch_config( + request.fetches, request.feed_dict) + + return OnRunStartResponse( + OnRunStartAction.DEBUG_RUN, + debug_urls, + debug_ops=debug_ops, + node_name_regex_whitelist=node_name_regex_whitelist, + op_type_regex_whitelist=op_type_regex_whitelist) + + def _prepare_run_watch_config(self, fetches, feed_dict): + """Get the debug_urls, and node/op whitelists for the current run() call. + + Args: + fetches: Same as the `fetches` argument to `Session.run()`. + feed_dict: Same as the `feed_dict argument` to `Session.run()`. + + Returns: + debug_urls: (str or list of str) Debug URLs for the current run() call. + Currently, the list consists of only one URL that is a file:// URL. + debug_ops: (str or list of str) Debug op(s) to be used by the + debugger. + node_name_regex_whitelist: (str or regex) Regular-expression whitelist for + node name. Same as the same-name argument to debug_utils.watch_graph. + op_type_regex_whitelist: (str or regex) Regular-expression whitelist for + op type. Same as the same-name argument to debug_utils.watch_graph. + """ + + debug_urls = self._prepare_run_debug_urls(fetches, feed_dict) + debug_ops = "DebugIdentity" + node_name_regex_whitelist = None + op_type_regex_whitelist = None + if self._watch_fn is not None: + debug_ops, node_name_regex_whitelist, op_type_regex_whitelist = ( + self._watch_fn(fetches, feed_dict)) + + return (debug_urls, debug_ops, node_name_regex_whitelist, + op_type_regex_whitelist) + + def on_run_end(self, request): + """See doc of BaseDebugWrapperSession.on_run_end.""" + + return OnRunEndResponse() + + def invoke_node_stepper(self, + node_stepper, + restore_variable_values_on_exit=True): + """See doc of BaseDebugWrapperSession.invoke_node_stepper.""" + + raise NotImplementedError( + "NonInteractiveDebugWrapperSession does not support node-stepper mode.") diff --git a/tensorflow/python/debug/wrappers/framework_test.py b/tensorflow/python/debug/wrappers/framework_test.py index 09d635cf3fa..d56c7057f66 100644 --- a/tensorflow/python/debug/wrappers/framework_test.py +++ b/tensorflow/python/debug/wrappers/framework_test.py @@ -311,6 +311,18 @@ class DebugWrapperSessionTest(test_util.TensorFlowTestCase): self._observer) wrapper.close() + def testUsingNonDirectSessionRaisesNotImplementedError(self): + # TODO(cais): Remove this test once tfdbg is integrated with GrpcSession. + fake_non_direct_session = session.Session() + fake_non_direct_session._target = "foo" + + with self.assertRaisesRegexp( + NotImplementedError, + r"Non-DirectSession support is not available from TensorFlow Debugger " + r"yet \(sess_str=foo\)"): + TestDebugWrapperSession( + fake_non_direct_session, self._dump_root, self._observer) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/debug/wrappers/grpc_wrapper.py b/tensorflow/python/debug/wrappers/grpc_wrapper.py new file mode 100644 index 00000000000..a56abc2c19d --- /dev/null +++ b/tensorflow/python/debug/wrappers/grpc_wrapper.py @@ -0,0 +1,91 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Debugger wrapper session that sends debug data to file:// URLs.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Google-internal import(s). +from tensorflow.python.debug.wrappers import framework + + +class GrpcDebugWrapperSession(framework.NonInteractiveDebugWrapperSession): + """Debug Session wrapper that send debug data to gRPC stream(s).""" + + _GRPC_URL_PREFIX = "grpc://" + + def __init__(self, + sess, + grpc_debug_server_addresses, + watch_fn=None, + log_usage=True): + """Constructor of DumpingDebugWrapperSession. + + Args: + sess: The TensorFlow `Session` object being wrapped. + grpc_debug_server_addresses: (`str` or `list` of `str`) Single or a list + of the gRPC debug server addresses, in the format of + <host:port>, without the "grpc://" prefix. For example: + "localhost:7000", + ["localhost:7000", "192.168.0.2:8000"] + watch_fn: (`Callable`) A Callable that can be used to define per-run + debug ops and watched tensors. See the doc of + `NonInteractiveDebugWrapperSession.__init__()` for details. + log_usage: (`bool`) whether the usage of this class is to be logged. + + Raises: + TypeError: If `grpc_debug_server_addresses` is not a `str` or a `list` + of `str`. + """ + + if log_usage: + pass # No logging for open-source. + + framework.NonInteractiveDebugWrapperSession.__init__( + self, sess, watch_fn=watch_fn) + + if isinstance(grpc_debug_server_addresses, str): + self._grpc_debug_server_urls = [ + self._GRPC_URL_PREFIX + grpc_debug_server_addresses + ] + elif isinstance(grpc_debug_server_addresses, list): + self._grpc_debug_server_urls = [] + for address in grpc_debug_server_addresses: + if not isinstance(address, str): + raise TypeError( + "Expected type str in list grpc_debug_server_addresses, " + "received type %s" % type(address)) + self._grpc_debug_server_urls.append(self._GRPC_URL_PREFIX + address) + else: + raise TypeError( + "Expected type str or list in grpc_debug_server_addresses, " + "received type %s" % type(grpc_debug_server_addresses)) + + def _prepare_run_debug_urls(self, fetches, feed_dict): + """Implementation of abstract method in superclass. + + See doc of `NonInteractiveDebugWrapperSession.__prepare_run_debug_urls()` + for details. + + Args: + fetches: Same as the `fetches` argument to `Session.run()` + feed_dict: Same as the `feed_dict` argument to `Session.run()` + + Returns: + debug_urls: (`str` or `list` of `str`) file:// debug URLs to be used in + this `Session.run()` call. + """ + + return self._grpc_debug_server_urls diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py index 5f6f066cdb2..7c387e1da44 100644 --- a/tensorflow/python/framework/gen_docs_combined.py +++ b/tensorflow/python/framework/gen_docs_combined.py @@ -169,6 +169,7 @@ def all_libraries(module_to_name, members, documented): "Inputs and Readers", exclude_symbols=["LookupTableBase", "HashTable", "initialize_all_tables", + "tables_initializer", "parse_single_sequence_example", "string_to_hash_bucket"], prefix=PREFIX_TEXT), diff --git a/tensorflow/python/framework/meta_graph.py b/tensorflow/python/framework/meta_graph.py index 37d8734a0f5..49406eedf30 100644 --- a/tensorflow/python/framework/meta_graph.py +++ b/tensorflow/python/framework/meta_graph.py @@ -151,12 +151,8 @@ def ops_used_by_graph_def(graph_def): mark_op_as_used(node.op) while functions_to_process: fun = functions_to_process.pop() - if fun.node_def: - for node in fun.node_def: - mark_op_as_used(node.op) - else: # TODO(josh11b): Eventually remove this case. - for node in fun.node: - mark_op_as_used(node.op) + for node in fun.node_def: + mark_op_as_used(node.op) return [op for op in used_ops if op not in name_to_function] diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index de03c6ac7fe..dfee19c16fb 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -638,6 +638,11 @@ def _ConstantValue(tensor): return np.concatenate(values, axis=dim) elif tensor.op.type == "Pack": values = [] + # Some imported GraphDefs have Pack ops with zero inputs. Those are invalid + # and shouldn't be produced, but to deal sensibly with them here we check + # and return None. + if not tensor.op.inputs: + return None for x in tensor.op.inputs: value = constant_value(x) if value is None: diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 906db3a6622..f4c3dcf99fb 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -341,17 +341,6 @@ tf_py_test( ], ) -tf_py_test( - name = "inplace_ops_test", - size = "small", - srcs = ["inplace_ops_test.py"], - additional_deps = [ - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:nn_ops", - ], -) - tf_py_test( name = "io_ops_test", size = "small", @@ -2467,6 +2456,24 @@ tf_py_test( ], ) +tf_py_test( + name = "weights_broadcast_test", + size = "small", + srcs = ["weights_broadcast_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform", + "//tensorflow/python:weights_broadcast_ops", + ], + shard_count = 3, +) + tf_py_test( name = "metrics_test", size = "medium", diff --git a/tensorflow/python/kernel_tests/inplace_ops_test.py b/tensorflow/python/kernel_tests/inplace_ops_test.py deleted file mode 100644 index 97bb5c43647..00000000000 --- a/tensorflow/python/kernel_tests/inplace_ops_test.py +++ /dev/null @@ -1,194 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for inplace_ops.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.client import session -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.ops import array_ops -from tensorflow.python.platform import test - - -class InplaceOpsTest(test.TestCase): - - def testBasicUpdate(self): - for dtype in [dtypes.float32, dtypes.int32]: - with self.test_session(use_gpu=True): - x = array_ops.ones([7, 3], dtype) - y = np.ones([7, 3], dtype.as_numpy_dtype) - self.assertAllClose(x.eval(), y) - x = array_ops._alias_inplace_update(x, [3], - array_ops.ones([1, 3], dtype)) - y[3, :] = 1 - self.assertAllClose(x.eval(), y) - x = array_ops._alias_inplace_update(x, [-1], - array_ops.ones([1, 3], dtype) * 2) - y[-1, :] = 2 - self.assertAllClose(x.eval(), y) - x = array_ops._alias_inplace_update(x, 5, array_ops.ones([3], dtype) * - 7) - y[5, :] = 7 - self.assertAllClose(x.eval(), y) - - def testBasicAdd(self): - for dtype in [dtypes.float32, dtypes.int32]: - with self.test_session(use_gpu=True): - x = array_ops.ones([7, 3], dtype) - y = np.ones([7, 3], dtype.as_numpy_dtype) - self.assertAllClose(x.eval(), y) - x = array_ops._alias_inplace_add(x, [3], array_ops.ones([1, 3], dtype)) - y[3, :] += 1 - self.assertAllClose(x.eval(), y) - x = array_ops._alias_inplace_add(x, [-1], - array_ops.ones([1, 3], dtype) * 2) - y[-1, :] += 2 - self.assertAllClose(x.eval(), y) - x = array_ops._alias_inplace_add(x, 5, array_ops.ones([3], dtype) * 7) - y[5, :] += 7 - self.assertAllClose(x.eval(), y) - x = array_ops._alias_inplace_add(x, None, - array_ops.ones([7, 3], dtype) * 99) - y[:, :] += 99 - self.assertAllClose(x.eval(), y) - - def testBasicSub(self): - for dtype in [dtypes.float32, dtypes.int32]: - with self.test_session(use_gpu=True): - x = array_ops.ones([7, 3], dtype) - y = np.ones([7, 3], dtype.as_numpy_dtype) - self.assertAllClose(x.eval(), y) - x = array_ops._alias_inplace_subtract(x, [3], - array_ops.ones([1, 3], dtype)) - y[3, :] -= 1 - self.assertAllClose(x.eval(), y) - x = array_ops._alias_inplace_subtract(x, [-1], - array_ops.ones([1, 3], dtype) * 2) - y[-1, :] -= 2 - self.assertAllClose(x.eval(), y) - x = array_ops._alias_inplace_subtract(x, 5, - array_ops.ones([3], dtype) * 7) - y[5, :] -= 7 - self.assertAllClose(x.eval(), y) - x = array_ops._alias_inplace_subtract(x, None, - array_ops.ones([7, 3], dtype) * - 99) - y[:, :] -= 99 - self.assertAllClose(x.eval(), y) - - def testRandom(self): - with self.test_session(use_gpu=True): - d0, d1, d2 = 100, 3, 5 - x = array_ops.zeros([d0, d1, d2]) - y = np.zeros([d0, d1, d2]) - for _ in range(20): - idx = np.random.choice(d0, d0 / 10, replace=False) - val = np.random.randint(10, size=(d0 / 10, d1, d2)) - op = np.random.randint(3) - if op == 0: - x = array_ops._alias_inplace_update(x, idx, val) - y[idx, :] = val - elif op == 1: - x = array_ops._alias_inplace_add(x, idx, val) - y[idx, :] += val - elif op == 2: - x = array_ops._alias_inplace_subtract(x, idx, val) - y[idx, :] -= val - self.assertAllClose(x.eval(), y) - - def testRandom1D(self): - with self.test_session(use_gpu=True): - d0 = 100 - x = array_ops.zeros([d0]) - y = np.zeros([d0]) - for _ in range(20): - idx = np.random.choice(d0, d0 / 10, replace=False) - val = np.random.randint(10, size=(d0 / 10)) - op = np.random.randint(3) - if op == 0: - x = array_ops._alias_inplace_update(x, idx, val) - y[idx] = val - elif op == 1: - x = array_ops._alias_inplace_add(x, idx, val) - y[idx] += val - elif op == 2: - x = array_ops._alias_inplace_subtract(x, idx, val) - y[idx] -= val - self.assertAllClose(x.eval(), y) - - def testError(self): - with self.test_session(): - with self.assertRaisesRegexp(errors.InvalidArgumentError, - "must be a vector"): - _ = array_ops._alias_inplace_update([[1.]], [[0]], [[10]]).eval() - with self.assertRaisesRegexp(errors.InvalidArgumentError, - "value and update shape doesn't match"): - _ = array_ops._alias_inplace_update([[1.]], [0], [10]).eval() - with self.assertRaisesRegexp(errors.InvalidArgumentError, - "loc and update shape doesn't match"): - _ = array_ops._alias_inplace_update([[1.]], [0, 1], [[10]]).eval() - - def testEmpty(self): - # Not much to test except the output a empty should have the shape - # and dtype we specify. - for dtype in [dtypes.float32, dtypes.float64, dtypes.int32]: - with self.test_session(use_gpu=True): - test_shapes = [(), (1,), (2, 3), (0, 2), (2, 3, 5), (2, 0, 5)] - for shape in test_shapes: - val = array_ops._empty(shape, dtype).eval() - self.assertEqual(val.shape, shape) - self.assertEqual(val.dtype, dtype.as_numpy_dtype) - val = array_ops._empty(shape, dtype, init=True).eval() - self.assertEqual(val.shape, shape) - self.assertEqual(val.dtype, dtype.as_numpy_dtype) - self.assertAllEqual(val, np.zeros(shape, dtype.as_numpy_dtype)) - val = array_ops._empty_like(array_ops.zeros(shape, dtype)).eval() - self.assertEqual(val.shape, shape) - self.assertEqual(val.dtype, dtype.as_numpy_dtype) - val = array_ops._empty_like( - array_ops.zeros(shape, dtype), init=True).eval() - self.assertEqual(val.shape, shape) - self.assertEqual(val.dtype, dtype.as_numpy_dtype) - self.assertAllEqual(val, np.zeros(shape, dtype.as_numpy_dtype)) - - val = array_ops._empty((1, 2), dtypes.string, init=True).eval() - self.assertEqual(val.tolist(), [[b"", b""]]) - - val = array_ops._empty((1, 2), dtypes.string, init=False).eval() - self.assertEqual(val.tolist(), [[b"", b""]]) - - def testEmptyStateful(self): - with session.Session("") as sess: - v1 = array_ops.placeholder(dtypes.float32, shape=[]) - v2 = array_ops.placeholder(dtypes.float32, shape=[]) - - a = array_ops._empty((1,), dtypes.float32, init=False) - b = array_ops._empty((1,), dtypes.float32, init=False) - - a = array_ops._alias_inplace_update(a, 0, v1) - b = array_ops._alias_inplace_update(b, 0, v2) - - res1, res2 = sess.run([a, b], feed_dict={v1: 1, v2: 2}) - self.assertEqual(res1, 1) - self.assertEqual(res2, 2) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/python/kernel_tests/metrics_test.py b/tensorflow/python/kernel_tests/metrics_test.py index 07d805b90ec..fc021c897a0 100644 --- a/tensorflow/python/kernel_tests/metrics_test.py +++ b/tensorflow/python/kernel_tests/metrics_test.py @@ -297,22 +297,27 @@ class MeanTest(test.TestCase): def testInvalidWeights(self): values_placeholder = array_ops.placeholder(dtype=dtypes_lib.float32) - values = _test_values((3, 2, 4)) + values = _test_values((3, 2, 4, 1)) invalid_weights = ( (1,), (1, 1), + (1, 1, 1), (3, 2), - (2, 4), - (1, 1, 1, 1), - (3, 2, 4, 1),) + (3, 2, 4), + (2, 4, 1), + (4, 2, 4, 1), + (3, 3, 4, 1), + (3, 2, 5, 1), + (3, 2, 4, 2), + (1, 1, 1, 1, 1)) + expected_error_msg = 'weights can not be broadcast to values' for invalid_weight in invalid_weights: # Static shapes. - with self.assertRaisesRegexp(ValueError, 'must have rank in.*0.*3'): + with self.assertRaisesRegexp(ValueError, expected_error_msg): metrics.mean(values, invalid_weight) # Dynamic shapes. - with self.assertRaisesRegexp( - errors_impl.OpError, 'must have rank in.*0.*3'): + with self.assertRaisesRegexp(errors_impl.OpError, expected_error_msg): with self.test_session(): _, update_op = metrics.mean(values_placeholder, invalid_weight) variables.local_variables_initializer().run() diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py index bad6a2fc78a..b470919dcac 100644 --- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py +++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py @@ -52,7 +52,7 @@ def _FlatOuterDims(tensor, ndims=2): def _NumpyScatterNd(ref, indices, updates, op): ixdim = indices.shape[-1] - num_updates = indices.size / ixdim + num_updates = indices.size // ixdim total_nd = len(ref.shape) slice_size = 1 for i in range(ixdim, total_nd): diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py index 7c1842a9cd4..72edf482704 100644 --- a/tensorflow/python/kernel_tests/variable_scope_test.py +++ b/tensorflow/python/kernel_tests/variable_scope_test.py @@ -453,6 +453,25 @@ class VariableScopeTest(test.TestCase): variable_scope.get_variable("w", []).name, "defaultScope1_2/layer/w:0") + def testVarOpScopeUniqueNamesWithJump(self): + with self.test_session(): + with variable_scope.variable_scope("default") as default: + with variable_scope.variable_scope(None, "layer"): + self.assertEqual( + variable_scope.get_variable("w", []).name, + "default/layer/w:0") + with variable_scope.variable_scope(None, "layer"): + self.assertEqual( + variable_scope.get_variable("w", []).name, + "default/layer_1/w:0") + with variable_scope.variable_scope(default): + pass + # No matter the jump in the middle, unique numbering continues. + with variable_scope.variable_scope(None, "layer"): + self.assertEqual( + variable_scope.get_variable("w", []).name, + "default/layer_2/w:0") + def testVarOpScopeReuse(self): with self.test_session(): with variable_scope.variable_scope("outer") as outer: diff --git a/tensorflow/python/kernel_tests/weights_broadcast_test.py b/tensorflow/python/kernel_tests/weights_broadcast_test.py new file mode 100644 index 00000000000..eda2856e0bd --- /dev/null +++ b/tensorflow/python/kernel_tests/weights_broadcast_test.py @@ -0,0 +1,276 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for broadcast rules.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import dtypes as dtypes_lib +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import weights_broadcast_ops +from tensorflow.python.platform import test + + +def _test_values(shape): + return np.reshape(np.cumsum(np.ones(shape), dtype=np.int32), newshape=shape) + + +class AssertBroadcastableTest(test.TestCase): + + def setUp(self): + ops.reset_default_graph() + + def _test_valid(self, weights, values): + static_op = weights_broadcast_ops.assert_broadcastable( + weights=weights, values=values) + weights_placeholder = array_ops.placeholder(dtypes_lib.float32) + values_placeholder = array_ops.placeholder(dtypes_lib.float32) + dynamic_op = weights_broadcast_ops.assert_broadcastable( + weights=weights_placeholder, values=values_placeholder) + with self.test_session(): + static_op.run() + dynamic_op.run(feed_dict={ + weights_placeholder: weights, + values_placeholder: values, + }) + + def testScalar(self): + self._test_valid(weights=5, values=_test_values((3, 2, 4))) + + def test1x1x1(self): + self._test_valid( + weights=np.asarray((5,)).reshape((1, 1, 1)), + values=_test_values((3, 2, 4))) + + def test1x1xN(self): + self._test_valid( + weights=np.asarray((5, 7, 11, 3)).reshape((1, 1, 4)), + values=_test_values((3, 2, 4))) + + def test1xNx1(self): + self._test_valid( + weights=np.asarray((5, 11)).reshape((1, 2, 1)), + values=_test_values((3, 2, 4))) + + def test1xNxN(self): + self._test_valid( + weights=np.asarray((5, 7, 11, 3, 2, 13, 7, 5)).reshape((1, 2, 4)), + values=_test_values((3, 2, 4))) + + def testNx1x1(self): + self._test_valid( + weights=np.asarray((5, 7, 11)).reshape((3, 1, 1)), + values=_test_values((3, 2, 4))) + + def testNx1xN(self): + self._test_valid( + weights=np.asarray(( + 5, 7, 11, 3, 2, 12, 7, 5, 2, 17, 11, 3)).reshape((3, 1, 4)), + values=_test_values((3, 2, 4))) + + def testNxNxN(self): + self._test_valid( + weights=np.asarray(( + 5, 7, 11, 3, 2, 12, 7, 5, 2, 17, 11, 3, + 2, 17, 11, 3, 5, 7, 11, 3, 2, 12, 7, 5)).reshape((3, 2, 4)), + values=_test_values((3, 2, 4))) + + def _test_invalid(self, weights, values): + error_msg = 'weights can not be broadcast to values' + with self.assertRaisesRegexp(ValueError, error_msg): + weights_broadcast_ops.assert_broadcastable(weights=weights, values=values) + weights_placeholder = array_ops.placeholder(dtypes_lib.float32) + values_placeholder = array_ops.placeholder(dtypes_lib.float32) + dynamic_op = weights_broadcast_ops.assert_broadcastable( + weights=weights_placeholder, values=values_placeholder) + with self.test_session(): + with self.assertRaisesRegexp(errors_impl.OpError, error_msg): + dynamic_op.run(feed_dict={ + weights_placeholder: weights, + values_placeholder: values, + }) + + def testInvalid1(self): + self._test_invalid(weights=np.asarray((5,)), values=_test_values((3, 2, 4))) + + def testInvalid1x1(self): + self._test_invalid( + weights=np.asarray((5,)).reshape((1, 1)), + values=_test_values((3, 2, 4))) + + def testInvalidPrefixMatch(self): + self._test_invalid( + weights=np.asarray((5, 7, 11, 3, 2, 12)).reshape((3, 2)), + values=_test_values((3, 2, 4))) + + def testInvalidSuffixMatch(self): + self._test_invalid( + weights=np.asarray((5, 7, 11, 3, 2, 12, 7, 5)).reshape((2, 4)), + values=_test_values((3, 2, 4))) + + def testInvalidOnesExtraDim(self): + self._test_invalid( + weights=np.asarray((5,)).reshape((1, 1, 1, 1)), + values=_test_values((3, 2, 4))) + + def testInvalidPrefixMatchExtraDim(self): + self._test_invalid( + weights=np.asarray(( + 5, 7, 11, 3, 2, 12, 7, 5, 2, 17, 11, 3, + 2, 17, 11, 3, 5, 7, 11, 3, 2, 12, 7, 5)).reshape((3, 2, 4, 1)), + values=_test_values((3, 2, 4))) + + def testInvalidSuffixMatchExtraDim(self): + self._test_invalid( + weights=np.asarray(( + 5, 7, 11, 3, 2, 12, 7, 5, 2, 17, 11, 3, + 2, 17, 11, 3, 5, 7, 11, 3, 2, 12, 7, 5)).reshape((1, 3, 2, 4)), + values=_test_values((3, 2, 4))) + + +class BroadcastWeightsTest(test.TestCase): + + def setUp(self): + ops.reset_default_graph() + + def _test_valid(self, weights, values, expected): + static_op = weights_broadcast_ops.broadcast_weights( + weights=weights, values=values) + weights_placeholder = array_ops.placeholder(dtypes_lib.float32) + values_placeholder = array_ops.placeholder(dtypes_lib.float32) + dynamic_op = weights_broadcast_ops.broadcast_weights( + weights=weights_placeholder, values=values_placeholder) + with self.test_session(): + self.assertAllEqual(expected, static_op.eval()) + self.assertAllEqual(expected, dynamic_op.eval(feed_dict={ + weights_placeholder: weights, + values_placeholder: values, + })) + + def testScalar(self): + self._test_valid( + weights=5, + values=_test_values((3, 2, 4)), + expected=5 * np.ones((3, 2, 4))) + + def test1x1x1(self): + self._test_valid( + weights=np.asarray((5,)).reshape((1, 1, 1)), + values=_test_values((3, 2, 4)), + expected=5 * np.ones((3, 2, 4))) + + def test1x1xN(self): + weights = np.asarray((5, 7, 11, 3)).reshape((1, 1, 4)) + self._test_valid( + weights=weights, + values=_test_values((3, 2, 4)), + expected=np.tile(weights, reps=(3, 2, 1))) + + def test1xNx1(self): + weights = np.asarray((5, 11)).reshape((1, 2, 1)) + self._test_valid( + weights=weights, + values=_test_values((3, 2, 4)), + expected=np.tile(weights, reps=(3, 1, 4))) + + def test1xNxN(self): + weights = np.asarray((5, 7, 11, 3, 2, 13, 7, 5)).reshape((1, 2, 4)) + self._test_valid( + weights=weights, + values=_test_values((3, 2, 4)), + expected=np.tile(weights, reps=(3, 1, 1))) + + def testNx1x1(self): + weights = np.asarray((5, 7, 11)).reshape((3, 1, 1)) + self._test_valid( + weights=weights, + values=_test_values((3, 2, 4)), + expected=np.tile(weights, reps=(1, 2, 4))) + + def testNx1xN(self): + weights = np.asarray(( + 5, 7, 11, 3, 2, 12, 7, 5, 2, 17, 11, 3)).reshape((3, 1, 4)) + self._test_valid( + weights=weights, + values=_test_values((3, 2, 4)), + expected=np.tile(weights, reps=(1, 2, 1))) + + def testNxNxN(self): + weights = np.asarray(( + 5, 7, 11, 3, 2, 12, 7, 5, 2, 17, 11, 3, + 2, 17, 11, 3, 5, 7, 11, 3, 2, 12, 7, 5)).reshape((3, 2, 4)) + self._test_valid( + weights=weights, values=_test_values((3, 2, 4)), expected=weights) + + def _test_invalid(self, weights, values): + error_msg = 'weights can not be broadcast to values' + with self.assertRaisesRegexp(ValueError, error_msg): + weights_broadcast_ops.broadcast_weights(weights=weights, values=values) + weights_placeholder = array_ops.placeholder(dtypes_lib.float32) + values_placeholder = array_ops.placeholder(dtypes_lib.float32) + dynamic_op = weights_broadcast_ops.broadcast_weights( + weights=weights_placeholder, values=values_placeholder) + with self.test_session(): + with self.assertRaisesRegexp(errors_impl.OpError, error_msg): + dynamic_op.eval(feed_dict={ + weights_placeholder: weights, + values_placeholder: values, + }) + + def testInvalid1(self): + self._test_invalid(weights=np.asarray((5,)), values=_test_values((3, 2, 4))) + + def testInvalid1x1(self): + self._test_invalid( + weights=np.asarray((5,)).reshape((1, 1)), + values=_test_values((3, 2, 4))) + + def testInvalidPrefixMatch(self): + self._test_invalid( + weights=np.asarray((5, 7, 11, 3, 2, 12)).reshape((3, 2)), + values=_test_values((3, 2, 4))) + + def testInvalidSuffixMatch(self): + self._test_invalid( + weights=np.asarray((5, 7, 11, 3, 2, 12, 7, 5)).reshape((2, 4)), + values=_test_values((3, 2, 4))) + + def testInvalidOnesExtraDim(self): + self._test_invalid( + weights=np.asarray((5,)).reshape((1, 1, 1, 1)), + values=_test_values((3, 2, 4))) + + def testInvalidPrefixMatchExtraDim(self): + self._test_invalid( + weights=np.asarray(( + 5, 7, 11, 3, 2, 12, 7, 5, 2, 17, 11, 3, + 2, 17, 11, 3, 5, 7, 11, 3, 2, 12, 7, 5)).reshape((3, 2, 4, 1)), + values=_test_values((3, 2, 4))) + + def testInvalidSuffixMatchExtraDim(self): + self._test_invalid( + weights=np.asarray(( + 5, 7, 11, 3, 2, 12, 7, 5, 2, 17, 11, 3, + 2, 17, 11, 3, 5, 7, 11, 3, 2, 12, 7, 5)).reshape((1, 3, 2, 4)), + values=_test_values((3, 2, 4))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 1201c0705b1..fc47fc325f3 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -742,194 +742,6 @@ def _SliceHelperVar(var, slice_spec): ops.Tensor._override_operator("__getitem__", _SliceHelper) -def _inplace_helper(value, loc, update, op): - """Applies an inplace op on `value` at `loc` with `update`. - - op is one of gen_array_ops._inplace_update, gen_array_ops._inplace_add, or - gen_array_ops._inplace_subtract. - - If `loc` is None, `value` and `update` must be the same size. - ``` - value op update - ``` - - If `loc` is a scalar, `value` has rank 1 higher than `update` - ``` - value[i, :] op update - ``` - - If `loc` is a vector, `value` has the same rank as `update` - ``` - value[loc, :] op update - ``` - - Args: - value: A `Tensor` object that will be updated in-place. - loc: None, scalar or 1-D `Tensor`. - update: A `Tensor` of rank one less than `value` if `loc` is a scalar, - otherwise of rank equal to `value` that contains the new values - for `value`. - op: One of gen_array_ops._inplace_update, ._inplace_add, ._inplace_subtract - Returns: - output: `value` that has been updated accordingly. - """ - - value = ops.convert_to_tensor(value) - update = ops.convert_to_tensor(update, value.dtype) - - if loc is None: - # Full tensor - return reshape(op(reshape(value, [1, -1]), - gen_math_ops.cast([0], dtypes.int64), - reshape(update, [1, -1])), - shape(value)) - - loc = gen_math_ops.cast(loc, dtypes.int64) - - if loc.get_shape().ndims == 0: - # Single 0-dim update - return op(value, reshape(loc, [1]), expand_dims(update, 0)) - - return op(value, loc, update) - - -def _empty(output_shape, dtype, init=False): - """Creates an empty Tensor with shape `output_shape` and type `dtype`. - - The memory can optionally be initialized. This is usually useful in - conjunction with in-place operations. - - Args: - output_shape: 1-D `Tensor` indicating the shape of the output. - dtype: The element type of the returned tensor. - init: `bool` indicating whether or not to zero the allocated memory. - - Returns: - output: An empty Tensor of the specified type. - """ - return gen_array_ops._empty(output_shape, dtype, init=init) - - -def _empty_like(value, init=None): - """Creates an empty Tensor with the same shape and type `dtype` as value. - - The memory can optionally be initialized. This op is usually useful in - conjunction with in-place operations. - - Args: - value: A `Tensor` whose shape will be used. - init: Initalize the returned tensor with the default value of - `value.dtype()` if True. Otherwise do not initialize. - - Returns: - output: An empty Tensor of the specified shape and type. - """ - value = ops.convert_to_tensor(value) - return gen_array_ops._empty(shape(value), value.dtype, init=init) - - -def _alias_inplace_update(value, loc, update): - """Updates input `value` at `loc` with `update`. Aliases value. - - If `loc` is None, `value` and `update` must be the same size. - ``` - value = update - ``` - - If `loc` is a scalar, `value` has rank 1 higher than `update` - ``` - value[i, :] = update - ``` - - If `loc` is a vector, `value` has the same rank as `update` - ``` - value[loc, :] = update - ``` - - Warning: If you use this function you will almost certainly want to add - a control dependency as done in the implementation of parallel_stack to - avoid race conditions. - Args: - value: A `Tensor` object that will be updated in-place. - loc: None, scalar or 1-D `Tensor`. - update: A `Tensor` of rank one less than `value` if `loc` is a scalar, - otherwise of rank equal to `value` that contains the new values - for `value`. - Returns: - output: `value` that has been updated accordingly. - """ - - return _inplace_helper(value, loc, update, gen_array_ops._inplace_update) - - -def _alias_inplace_add(value, loc, update): - """Updates input `value` at `loc` with `update`. Aliases value. - - If `loc` is None, `value` and `update` must be the same size. - ``` - value += update - ``` - - If `loc` is a scalar, `value` has rank 1 higher than `update` - ``` - value[i, :] += update - ``` - - If `loc` is a vector, `value` has the same rank as `update` - ``` - value[loc, :] += update - ``` - - Warning: If you use this function you will almost certainly want to add - a control dependency as done in the implementation of parallel_stack to - avoid race conditions. - Args: - value: A `Tensor` object that will be updated in-place. - loc: None, scalar or 1-D `Tensor`. - update: A `Tensor` of rank one less than `value` if `loc` is a scalar, - otherwise of rank equal to `value` that contains the new values - for `value`. - Returns: - output: `value` that has been updated accordingly. - """ - - return _inplace_helper(value, loc, update, gen_array_ops._inplace_add) - - -def _alias_inplace_subtract(value, loc, update): - """Updates input `value` at `loc` with `update`. Aliases value. - - If `loc` is None, `value` and `update` must be the same size. - ``` - value -= update - ``` - - If `loc` is a scalar, `value` has rank 1 higher than `update` - ``` - value[i, :] -= update - ``` - - If `loc` is a vector, `value` has the same rank as `update` - ``` - value[loc, :] -= update - ``` - - Warning: If you use this function you will almost certainly want to add - a control dependency as done in the implementation of parallel_stack to - avoid race conditions. - Args: - value: A `Tensor` object that will be updated in-place. - loc: None, Scalar or 1-D `Tensor`. - update: A `Tensor` of rank one less than `value` if `loc` is a scalar, - otherwise of rank equal to `value` that contains the new values - for `value`. - Returns: - output: `value` that has been updated accordingly. - """ - - return _inplace_helper(value, loc, update, gen_array_ops._inplace_subtract) - - def parallel_stack(values, name="parallel_stack"): """Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor in parallel. @@ -972,16 +784,9 @@ def parallel_stack(values, name="parallel_stack"): output_shape = tensor_shape.TensorShape([len(values)]) output_shape = output_shape.concatenate(value_shape) - - outputs = _empty(output_shape, values[0].dtype) - output_ops = [] - for i in range(len(values)): - with ops.colocate_with(outputs): - output_op = _alias_inplace_update(outputs, i, values[i]) - output_ops.append(output_op) - with ops.control_dependencies(output_ops): - outputs = identity(outputs) - return outputs + # expand_dims converts concat to stack. + return gen_array_ops._parallel_concat( + [expand_dims(value, 0) for value in values], shape=output_shape) def stack(values, axis=0, name="stack"): """Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor. diff --git a/tensorflow/python/ops/cloud/__init__.py b/tensorflow/python/ops/cloud/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorflow/python/ops/cloud/bigquery_reader_ops.py b/tensorflow/python/ops/cloud/bigquery_reader_ops.py new file mode 100644 index 00000000000..7786aea025a --- /dev/null +++ b/tensorflow/python/ops/cloud/bigquery_reader_ops.py @@ -0,0 +1,157 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""BigQuery reading support for TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_cloud_ops +from tensorflow.python.ops import io_ops + + +class BigQueryReader(io_ops.ReaderBase): + """A Reader that outputs keys and tf.Example values from a BigQuery table. + + Note(1): This op is currently not linked into the binary. It will be linked + by default after more perf testing. + + Note(2): This op currently returns example proto as its output. This is not + final and we are experimenting with adding support for returning csv. Support + for example proto may be deprecated after that. + + Example use: + ```python + # Assume a BigQuery has the following schema, + # name STRING, + # age INT, + # state STRING + + # Create the parse_examples list of features. + features = dict( + name=tf.FixedLenFeature([1], tf.string), + age=tf.FixedLenFeature([1], tf.int32), + state=tf.FixedLenFeature([1], dtype=tf.string, default_value="UNK")) + + # Create a Reader. + reader = bigquery_reader_ops.BigQueryReader(project_id=PROJECT, + dataset_id=DATASET, + table_id=TABLE, + timestamp_millis=TIME, + num_partitions=NUM_PARTITIONS, + features=features) + + # Populate a queue with the BigQuery Table partitions. + queue = tf.training.string_input_producer(reader.partitions()) + + # Read and parse examples. + row_id, examples_serialized = reader.read(queue) + examples = tf.parse_example(examples_serialized, features=features) + + # Process the Tensors examples["name"], examples["age"], etc... + ``` + + Note that to create a reader a snapshot timestamp is necessary. This + will enable the reader to look at a consistent snapshot of the table. + For more information, see 'Table Decorators' in BigQuery docs. + + See ReaderBase for supported methods. + """ + + def __init__(self, + project_id, + dataset_id, + table_id, + timestamp_millis, + num_partitions, + features=None, + columns=None, + test_end_point=None, + name=None): + """Creates a BigQueryReader. + + Args: + project_id: GCP project ID. + dataset_id: BigQuery dataset ID. + table_id: BigQuery table ID. + timestamp_millis: timestamp to snapshot the table in milliseconds since + the epoch. Relative (negative or zero) snapshot times are not allowed. + For more details, see 'Table Decorators' in BigQuery docs. + num_partitions: Number of non-overlapping partitions to read from. + features: parse_example compatible dict from keys to `VarLenFeature` and + `FixedLenFeature` objects. Keys are read as columns from the db. + columns: list of columns to read, can be set iff features is None. + test_end_point: Used only for testing purposes (optional). + name: a name for the operation (optional). + + Raises: + TypeError: - If features is neither None nor a dict or + - If columns is is neither None nor a list or + - If both features and columns are None or set. + """ + if (features is None) == (columns is None): + raise TypeError("exactly one of features and columns must be set.") + + if features is not None: + if not isinstance(features, dict): + raise TypeError("features must be a dict.") + self._columns = list(features.keys()) + elif columns is not None: + if not isinstance(columns, list): + raise TypeError("columns must be a list.") + self._columns = columns + + self._project_id = project_id + self._dataset_id = dataset_id + self._table_id = table_id + self._timestamp_millis = timestamp_millis + self._num_partitions = num_partitions + self._test_end_point = test_end_point + + reader = gen_cloud_ops.big_query_reader( + name=name, + project_id=self._project_id, + dataset_id=self._dataset_id, + table_id=self._table_id, + timestamp_millis=self._timestamp_millis, + columns=self._columns, + test_end_point=self._test_end_point) + super(BigQueryReader, self).__init__(reader) + + def partitions(self, name=None): + """Returns serialized BigQueryTablePartition messages. + + These messages represent a non-overlapping division of a table for a + bulk read. + + Args: + name: a name for the operation (optional). + + Returns: + `1-D` string `Tensor` of serialized `BigQueryTablePartition` messages. + """ + return gen_cloud_ops.generate_big_query_reader_partitions( + name=name, + project_id=self._project_id, + dataset_id=self._dataset_id, + table_id=self._table_id, + timestamp_millis=self._timestamp_millis, + num_partitions=self._num_partitions, + test_end_point=self._test_end_point, + columns=self._columns) + + +ops.NotDifferentiable("BigQueryReader") diff --git a/tensorflow/python/ops/cloud/bigquery_reader_ops_test.py b/tensorflow/python/ops/cloud/bigquery_reader_ops_test.py new file mode 100644 index 00000000000..196991f68a7 --- /dev/null +++ b/tensorflow/python/ops/cloud/bigquery_reader_ops_test.py @@ -0,0 +1,282 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for BigQueryReader Op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import re +import threading + +from six.moves import SimpleHTTPServer +from six.moves import socketserver + +from tensorflow.core.example import example_pb2 +from tensorflow.core.framework import types_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops.cloud import cloud +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import compat + +_PROJECT = "test-project" +_DATASET = "test-dataset" +_TABLE = "test-table" +# List representation of the test rows in the 'test-table' in BigQuery. +# The schema for each row is: [int64, string, float]. +# The values for rows are generated such that some columns have null values. The +# general formula here is: +# - The int64 column is present in every row. +# - The string column is only avaiable in even rows. +# - The float column is only available in every third row. +_ROWS = [[0, "s_0", 0.1], [1, None, None], [2, "s_2", None], [3, None, 3.1], + [4, "s_4", None], [5, None, None], [6, "s_6", 6.1], [7, None, None], + [8, "s_8", None], [9, None, 9.1]] +# Schema for 'test-table'. +# The schema currently has three columns: int64, string, and float +_SCHEMA = { + "kind": "bigquery#table", + "id": "test-project:test-dataset.test-table", + "schema": { + "fields": [{ + "name": "int64_col", + "type": "INTEGER", + "mode": "NULLABLE" + }, { + "name": "string_col", + "type": "STRING", + "mode": "NULLABLE" + }, { + "name": "float_col", + "type": "FLOAT", + "mode": "NULLABLE" + }] + } +} + + +def _ConvertRowToExampleProto(row): + """Converts the input row to an Example proto. + + Args: + row: Input Row instance. + + Returns: + An Example proto initialized with row values. + """ + + example = example_pb2.Example() + example.features.feature["int64_col"].int64_list.value.append(row[0]) + if row[1] is not None: + example.features.feature["string_col"].bytes_list.value.append( + compat.as_bytes(row[1])) + if row[2] is not None: + example.features.feature["float_col"].float_list.value.append(row[2]) + return example + + +class FakeBigQueryServer(threading.Thread): + """Fake http server to return schema and data for sample table.""" + + def __init__(self, address, port): + """Creates a FakeBigQueryServer. + + Args: + address: Server address + port: Server port. Pass 0 to automatically pick an empty port. + """ + threading.Thread.__init__(self) + self.handler = BigQueryRequestHandler + self.httpd = socketserver.TCPServer((address, port), self.handler) + + def run(self): + self.httpd.serve_forever() + + def shutdown(self): + self.httpd.shutdown() + self.httpd.socket.close() + + +class BigQueryRequestHandler(SimpleHTTPServer.SimpleHTTPRequestHandler): + """Responds to BigQuery HTTP requests. + + Attributes: + num_rows: num_rows in the underlying table served by this class. + """ + + num_rows = 0 + + def do_GET(self): + if "data?maxResults=" not in self.path: + # This is a schema request. + _SCHEMA["numRows"] = self.num_rows + response = json.dumps(_SCHEMA) + else: + # This is a data request. + # + # Extract max results and start index. + max_results = int(re.findall(r"maxResults=(\d+)", self.path)[0]) + start_index = int(re.findall(r"startIndex=(\d+)", self.path)[0]) + + # Send the rows as JSON. + rows = [] + for row in _ROWS[start_index:start_index + max_results]: + row_json = { + "f": [{ + "v": str(row[0]) + }, { + "v": str(row[1]) if row[1] is not None else None + }, { + "v": str(row[2]) if row[2] is not None else None + }] + } + rows.append(row_json) + response = json.dumps({ + "kind": "bigquery#table", + "id": "test-project:test-dataset.test-table", + "rows": rows + }) + self.send_response(200) + self.end_headers() + self.wfile.write(compat.as_bytes(response)) + + +def _SetUpQueue(reader): + """Sets up a queue for a reader.""" + queue = data_flow_ops.FIFOQueue(8, [types_pb2.DT_STRING], shapes=()) + key, value = reader.read(queue) + queue.enqueue_many(reader.partitions()).run() + queue.close().run() + return key, value + + +class BigQueryReaderOpsTest(test.TestCase): + + def setUp(self): + super(BigQueryReaderOpsTest, self).setUp() + self.server = FakeBigQueryServer("127.0.0.1", 0) + self.server.start() + logging.info("server address is %s:%s", self.server.httpd.server_address[0], + self.server.httpd.server_address[1]) + + def tearDown(self): + self.server.shutdown() + super(BigQueryReaderOpsTest, self).tearDown() + + def _ReadAndCheckRowsUsingFeatures(self, num_rows): + self.server.handler.num_rows = num_rows + + with self.test_session() as sess: + feature_configs = { + "int64_col": + parsing_ops.FixedLenFeature( + [1], dtype=dtypes.int64), + "string_col": + parsing_ops.FixedLenFeature( + [1], dtype=dtypes.string, default_value="s_default"), + } + reader = cloud.BigQueryReader( + project_id=_PROJECT, + dataset_id=_DATASET, + table_id=_TABLE, + num_partitions=4, + features=feature_configs, + timestamp_millis=1, + test_end_point=("%s:%s" % (self.server.httpd.server_address[0], + self.server.httpd.server_address[1]))) + + key, value = _SetUpQueue(reader) + + seen_rows = [] + features = parsing_ops.parse_example( + array_ops.reshape(value, [1]), feature_configs) + for _ in range(num_rows): + int_value, str_value = sess.run( + [features["int64_col"], features["string_col"]]) + + # Parse values returned from the session. + self.assertEqual(int_value.shape, (1, 1)) + self.assertEqual(str_value.shape, (1, 1)) + int64_col = int_value[0][0] + string_col = str_value[0][0] + seen_rows.append(int64_col) + + # Compare. + expected_row = _ROWS[int64_col] + self.assertEqual(int64_col, expected_row[0]) + self.assertEqual( + compat.as_str(string_col), ("s_%d" % int64_col) if expected_row[1] + else "s_default") + + self.assertItemsEqual(seen_rows, range(num_rows)) + + with self.assertRaisesOpError("is closed and has insufficient elements " + "\\(requested 1, current size 0\\)"): + sess.run([key, value]) + + def testReadingSingleRowUsingFeatures(self): + self._ReadAndCheckRowsUsingFeatures(1) + + def testReadingMultipleRowsUsingFeatures(self): + self._ReadAndCheckRowsUsingFeatures(10) + + def testReadingMultipleRowsUsingColumns(self): + num_rows = 10 + self.server.handler.num_rows = num_rows + + with self.test_session() as sess: + reader = cloud.BigQueryReader( + project_id=_PROJECT, + dataset_id=_DATASET, + table_id=_TABLE, + num_partitions=4, + columns=["int64_col", "float_col", "string_col"], + timestamp_millis=1, + test_end_point=("%s:%s" % (self.server.httpd.server_address[0], + self.server.httpd.server_address[1]))) + key, value = _SetUpQueue(reader) + seen_rows = [] + for row_index in range(num_rows): + returned_row_id, example_proto = sess.run([key, value]) + example = example_pb2.Example() + example.ParseFromString(example_proto) + self.assertIn("int64_col", example.features.feature) + feature = example.features.feature["int64_col"] + self.assertEqual(len(feature.int64_list.value), 1) + int64_col = feature.int64_list.value[0] + seen_rows.append(int64_col) + + # Create our expected Example. + expected_example = example_pb2.Example() + expected_example = _ConvertRowToExampleProto(_ROWS[int64_col]) + + # Compare. + self.assertProtoEquals(example, expected_example) + self.assertEqual(row_index, int(returned_row_id)) + + self.assertItemsEqual(seen_rows, range(num_rows)) + + with self.assertRaisesOpError("is closed and has insufficient elements " + "\\(requested 1, current size 0\\)"): + sess.run([key, value]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/ops/cloud/cloud.py b/tensorflow/python/ops/cloud/cloud.py new file mode 100644 index 00000000000..eb917a987e9 --- /dev/null +++ b/tensorflow/python/ops/cloud/cloud.py @@ -0,0 +1,31 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Import cloud ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.python.ops.cloud.bigquery_reader_ops import * +# pylint: enable=wildcard-import + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = ['BigQueryReader'] +remove_undocumented(__name__, _allowed_symbols, [sys.modules[__name__]]) diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py index 16a9f5d96f7..72f0454e30c 100644 --- a/tensorflow/python/ops/data_flow_ops.py +++ b/tensorflow/python/ops/data_flow_ops.py @@ -39,6 +39,7 @@ from tensorflow.python.ops import math_ops # pylint: disable=wildcard-import from tensorflow.python.ops.gen_data_flow_ops import * # pylint: enable=wildcard-import +from tensorflow.python.util.deprecation import deprecated def _as_type_list(dtypes): @@ -1053,9 +1054,23 @@ class Barrier(object): self._barrier_ref, name=name) +@deprecated("2017-03-02", "Use `tf.tables_initializer` instead.") def initialize_all_tables(name="init_all_tables"): """Returns an Op that initializes all tables of the default graph. + Args: + name: Optional name for the initialization op. + + Returns: + An Op that initializes all tables. Note that if there are + not tables the returned Op is a NoOp. + """ + return tables_initializer(name) + + +def tables_initializer(name="init_all_tables"): + """Returns an Op that initializes all tables of the default graph. + Args: name: Optional name for the initialization op. diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt index abe5c538d0a..4b1b9815ca8 100644 --- a/tensorflow/python/ops/hidden_ops.txt +++ b/tensorflow/python/ops/hidden_ops.txt @@ -18,6 +18,7 @@ MirrorPadGrad OneHot Pack Pad +ParallelConcat Placeholder RefIdentity Reverse diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py index c5adcf609d1..a00625d0835 100644 --- a/tensorflow/python/ops/metrics_impl.py +++ b/tensorflow/python/ops/metrics_impl.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables +from tensorflow.python.ops import weights_broadcast_ops def _local_variable(initial_value, validate_shape=True, name=None): @@ -171,41 +172,6 @@ def _create_local(name, shape, collections=None, validate_shape=True, validate_shape=validate_shape) -def _assert_weights_rank(weights, values): - return check_ops.assert_rank_in(weights, (0, array_ops.rank(values))) - - -def _broadcast_weights(weights, values): - """Broadcast `weights` to the same shape as `values`. - - This returns a version of `weights` following the same broadcast rules as - `multiply(weights, values)`. When computing a weighted average, use this - function to broadcast `weights` before summing them; e.g., - `reduce_sum(w * v) / reduce_sum(_broadcast_weights(w, v))`. - - Args: - weights: `Tensor` whose rank is either 0, or the same rank as `values`, and - must be broadcastable to `values` (i.e., all dimensions must be either - `1`, or the same as the corresponding `values` dimension). - values: `Tensor` of any shape. - - Returns: - `weights` broadcast to `values` shape. - - Raises: - ValueError: if `weights` rank is invalid. - """ - weights_shape = weights.get_shape() - values_shape = values.get_shape() - if (weights_shape.is_fully_defined() and - values_shape.is_fully_defined() and - weights_shape.is_compatible_with(values_shape)): - return weights - with ops.control_dependencies((_assert_weights_rank(weights, values),)): - return math_ops.multiply( - weights, array_ops.ones_like(values), name='broadcast_weights') - - def _safe_div(numerator, denominator, name): """Divides two values, returning 0 if the denominator is <= 0. @@ -292,7 +258,8 @@ def mean(values, weights=None, metrics_collections=None, if weights is None: num_values = math_ops.to_float(array_ops.size(values)) else: - weights = _broadcast_weights(math_ops.to_float(weights), values) + weights = weights_broadcast_ops.broadcast_weights( + math_ops.to_float(weights), values) values = math_ops.multiply(values, weights) num_values = math_ops.reduce_sum(weights) @@ -451,7 +418,8 @@ def _confusion_matrix_at_thresholds( label_is_neg = math_ops.logical_not(label_is_pos) if weights is not None: - weights = _broadcast_weights(math_ops.to_float(weights), predictions) + weights = weights_broadcast_ops.broadcast_weights( + math_ops.to_float(weights), predictions) weights_tiled = array_ops.tile(array_ops.reshape( weights, [1, -1]), [num_thresholds, 1]) thresh_tiled.get_shape().assert_is_compatible_with( @@ -1002,7 +970,8 @@ def mean_tensor(values, weights=None, metrics_collections=None, num_values = array_ops.ones_like(values) if weights is not None: - weights = _broadcast_weights(math_ops.to_float(weights), values) + weights = weights_broadcast_ops.broadcast_weights( + math_ops.to_float(weights), values) values = math_ops.multiply(values, weights) num_values = math_ops.multiply(num_values, weights) @@ -1580,7 +1549,8 @@ def _sparse_true_positive_at_k(labels, tp = sets.set_size(sets.set_intersection(predictions_idx, labels)) tp = math_ops.to_double(tp) if weights is not None: - with ops.control_dependencies((_assert_weights_rank(weights, tp),)): + with ops.control_dependencies(( + weights_broadcast_ops.assert_broadcastable(weights, tp),)): weights = math_ops.to_double(weights) tp = math_ops.multiply(tp, weights) return tp @@ -1675,7 +1645,8 @@ def _sparse_false_negative_at_k(labels, aminusb=False)) fn = math_ops.to_double(fn) if weights is not None: - with ops.control_dependencies((_assert_weights_rank(weights, fn),)): + with ops.control_dependencies(( + weights_broadcast_ops.assert_broadcastable(weights, fn),)): weights = math_ops.to_double(weights) fn = math_ops.multiply(fn, weights) return fn @@ -2292,7 +2263,7 @@ def sparse_average_precision_at_k(labels, average_precision = _sparse_average_precision_at_k( predictions=predictions, labels=labels, k=k) if weights is not None: - weights = _broadcast_weights( + weights = weights_broadcast_ops.broadcast_weights( math_ops.to_double(weights), average_precision) average_precision = math_ops.multiply(average_precision, weights) @@ -2367,7 +2338,8 @@ def _sparse_false_positive_at_k(labels, predictions_idx, labels, aminusb=True)) fp = math_ops.to_double(fp) if weights is not None: - with ops.control_dependencies((_assert_weights_rank(weights, fp),)): + with ops.control_dependencies(( + weights_broadcast_ops.assert_broadcastable(weights, fp),)): weights = math_ops.to_double(weights) fp = math_ops.multiply(fp, weights) return fp diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py index 3c7182f7dc1..ff1d2a6951b 100644 --- a/tensorflow/python/ops/state_ops.py +++ b/tensorflow/python/ops/state_ops.py @@ -106,6 +106,7 @@ automatically by the optimizers in most cases. ### Read-only Lookup Tables @@initialize_all_tables +@@tables_initializer ## Exporting and Importing Meta Graphs diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 2b007d545b7..ddba73f7e9a 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -21,6 +21,7 @@ from __future__ import print_function import collections as collections_lib import contextlib +import copy import functools import traceback @@ -182,21 +183,21 @@ class _VariableStore(object): """Create a variable store.""" self._vars = {} # A dictionary of the stored TensorFlow variables. self._partitioned_vars = {} # A dict of the stored PartitionedVariables. - self._variable_scopes_count = {} # Count re-used variable scopes. + self.variable_scopes_count = {} # Count re-used variable scopes. def open_variable_scope(self, scope_name): - if scope_name in self._variable_scopes_count: - self._variable_scopes_count[scope_name] += 1 + if scope_name in self.variable_scopes_count: + self.variable_scopes_count[scope_name] += 1 else: - self._variable_scopes_count[scope_name] = 1 + self.variable_scopes_count[scope_name] = 1 def close_variable_subscopes(self, scope_name): - for k in self._variable_scopes_count: + for k in self.variable_scopes_count: if not scope_name or k.startswith(scope_name + "/"): - self._variable_scopes_count[k] = 0 + self.variable_scopes_count[k] = 0 def variable_scope_count(self, scope_name): - return self._variable_scopes_count.get(scope_name, 0) + return self.variable_scopes_count.get(scope_name, 0) def get_variable(self, name, shape=None, dtype=dtypes.float32, initializer=None, regularizer=None, reuse=None, @@ -1222,6 +1223,7 @@ def _pure_variable_scope(name_or_scope, try: var_store.open_variable_scope(new_name) if isinstance(name_or_scope, VariableScope): + old_subscopes = copy.copy(var_store.variable_scopes_count) name_scope = name_or_scope._name_scope # pylint: disable=protected-access # Handler for the case when we jump to a shared scope. # We create a new VariableScope (default_varscope[0]) that contains @@ -1280,6 +1282,9 @@ def _pure_variable_scope(name_or_scope, yield default_varscope[0] finally: var_store.close_variable_subscopes(new_name) + # If jumping out from a non-prolonged scope, restore counts. + if isinstance(name_or_scope, VariableScope): + var_store.variable_scopes_count = old_subscopes default_varscope[0] = old diff --git a/tensorflow/python/ops/weights_broadcast_ops.py b/tensorflow/python/ops/weights_broadcast_ops.py new file mode 100644 index 00000000000..d41f5ca54b3 --- /dev/null +++ b/tensorflow/python/ops/weights_broadcast_ops.py @@ -0,0 +1,168 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Weight broadcasting operations. + +In `tf.losses` and `tf.metrics`, we support limited weight broadcasting. This +file includes operations for those broadcasting rules. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import sets + + +def _has_valid_dims(weights_shape, values_shape): + with ops.name_scope( + None, "has_invalid_dims", (weights_shape, values_shape)) as scope: + values_shape_2d = array_ops.expand_dims(values_shape, -1) + valid_dims = array_ops.concat_v2( + (values_shape_2d, array_ops.ones_like(values_shape_2d)), axis=1) + weights_shape_2d = array_ops.expand_dims(weights_shape, -1) + invalid_dims = sets.set_difference(weights_shape_2d, valid_dims) + num_invalid_dims = array_ops.size( + invalid_dims.values, name="num_invalid_dims") + return math_ops.equal(0, num_invalid_dims, name=scope) + + +def _has_valid_nonscalar_shape( + weights_rank, weights_shape, values_rank, values_shape): + with ops.name_scope( + None, "has_valid_nonscalar_shape", + (weights_rank, weights_shape, values_rank, values_shape)) as scope: + is_same_rank = math_ops.equal( + values_rank, weights_rank, name="is_same_rank") + return control_flow_ops.cond( + is_same_rank, + lambda: _has_valid_dims(weights_shape, values_shape), + lambda: is_same_rank, + name=scope) + + +_ASSERT_BROADCASTABLE_ERROR_PREFIX = "weights can not be broadcast to values." + + +def assert_broadcastable(weights, values): + """Asserts `weights` can be broadcast to `values`. + + In `tf.losses` and `tf.metrics`, we support limited weight broadcasting. We + let weights be either scalar, or the same rank as the target values, with each + dimension either 1, or the same as the corresponding values dimension. + + Args: + weights: `Tensor` of weights. + values: `Tensor` of values to which weights are applied. + + Returns: + `Operation` raising `InvalidArgumentError` if `weights` has incorrect shape. + `no_op` if static checks determine `weights` has correct shape. + + Raises: + ValueError: If static checks determine `weights` has incorrect shape. + """ + with ops.name_scope(None, "assert_broadcastable", (weights, values)) as scope: + with ops.name_scope(None, "weights", (weights,)) as weights_scope: + weights = ops.convert_to_tensor(weights, name=weights_scope) + weights_shape = array_ops.shape(weights, name="shape") + weights_rank = array_ops.rank(weights, name="rank") + weights_rank_static = tensor_util.constant_value(weights_rank) + + with ops.name_scope(None, "values", (values,)) as values_scope: + values = ops.convert_to_tensor(values, name=values_scope) + values_shape = array_ops.shape(values, name="shape") + values_rank = array_ops.rank(values, name="rank") + values_rank_static = tensor_util.constant_value(values_rank) + + # Try static checks. + if weights_rank_static is not None and values_rank_static is not None: + if weights_rank_static == 0: + return control_flow_ops.no_op(name="static_scalar_check_success") + if weights_rank_static != values_rank_static: + raise ValueError( + "%s values.rank=%s. weights.rank=%s." % ( + _ASSERT_BROADCASTABLE_ERROR_PREFIX, values_rank_static, + weights_rank_static)) + weights_shape_static = tensor_util.constant_value(weights_shape) + values_shape_static = tensor_util.constant_value(values_shape) + if weights_shape_static is not None and values_shape_static is not None: + # Sanity check, this should always be true since we checked rank above. + ndims = len(values_shape_static) + assert ndims == len(weights_shape_static) + + for i in range(ndims): + if weights_shape_static[i] not in (1, values_shape_static[i]): + raise ValueError( + "%s Mismatch at dim %s. values.shape=%s weights.shape=%s." % ( + _ASSERT_BROADCASTABLE_ERROR_PREFIX, i, values_shape_static, + weights_shape_static)) + return control_flow_ops.no_op(name="static_dims_check_success") + + # Dynamic checks. + is_scalar = math_ops.equal(0, weights_rank, name="is_scalar") + data = ( + _ASSERT_BROADCASTABLE_ERROR_PREFIX, + "weights.shape=", weights.name, weights_shape, + "values.shape=", values.name, values_shape, + "is_scalar=", is_scalar, + ) + is_valid_shape = control_flow_ops.cond( + is_scalar, + lambda: is_scalar, + lambda: _has_valid_nonscalar_shape( # pylint: disable=g-long-lambda + weights_rank, weights_shape, values_rank, values_shape), + name="is_valid_shape") + return control_flow_ops.Assert(is_valid_shape, data, name=scope) + + +def broadcast_weights(weights, values): + """Broadcast `weights` to the same shape as `values`. + + This returns a version of `weights` following the same broadcast rules as + `mul(weights, values)`, but limited to the weights shapes allowed by + `assert_broadcastable`. When computing a weighted average, use this function + to broadcast `weights` before summing them; e.g., + `reduce_sum(w * v) / reduce_sum(_broadcast_weights(w, v))`. + + Args: + weights: `Tensor` whose shape is broadcastable to `values` according to the + rules of `assert_broadcastable`. + values: `Tensor` of any shape. + + Returns: + `weights` broadcast to `values` shape according to the rules of + `assert_broadcastable`. + """ + with ops.name_scope(None, "broadcast_weights", (weights, values)) as scope: + values = ops.convert_to_tensor(values, name="values") + weights = ops.convert_to_tensor( + weights, dtype=values.dtype.base_dtype, name="weights") + + # Try static check for exact match. + weights_shape = weights.get_shape() + values_shape = values.get_shape() + if (weights_shape.is_fully_defined() and + values_shape.is_fully_defined() and + weights_shape.is_compatible_with(values_shape)): + return weights + + with ops.control_dependencies((assert_broadcastable(weights, values),)): + return math_ops.multiply( + weights, array_ops.ones_like(values), name=scope) diff --git a/tensorflow/python/saved_model/main_op.py b/tensorflow/python/saved_model/main_op.py index 5d8c0db2d83..3f25dc137e3 100644 --- a/tensorflow/python/saved_model/main_op.py +++ b/tensorflow/python/saved_model/main_op.py @@ -39,7 +39,7 @@ def main_op(): """ init = variables.global_variables_initializer() init_local = variables.local_variables_initializer() - init_tables = tf_data_flow_ops.initialize_all_tables() + init_tables = tf_data_flow_ops.tables_initializer() return control_flow_ops.group(init, init_local, init_tables) diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index ffdd533fd9a..30b9ccf922b 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -237,7 +237,7 @@ class Scaffold(object): @staticmethod def _default_local_init_op(): return control_flow_ops.group(variables.local_variables_initializer(), - data_flow_ops.initialize_all_tables()) + data_flow_ops.tables_initializer()) def MonitoredTrainingSession(master='', # pylint: disable=invalid-name diff --git a/tensorflow/python/training/supervisor.py b/tensorflow/python/training/supervisor.py index 8e399fb46f8..aa5081870e2 100644 --- a/tensorflow/python/training/supervisor.py +++ b/tensorflow/python/training/supervisor.py @@ -440,7 +440,7 @@ class Supervisor(object): ops.GraphKeys.LOCAL_INIT_OP) if local_init_op is None: op_list = [variables.local_variables_initializer(), - data_flow_ops.initialize_all_tables()] + data_flow_ops.tables_initializer()] if op_list: local_init_op = control_flow_ops.group(*op_list) ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op) diff --git a/tensorflow/tensorboard/components/tf_image_dashboard/tf-image-loader.html b/tensorflow/tensorboard/components/tf_image_dashboard/tf-image-loader.html index 357655e2582..fdf2c4494f7 100644 --- a/tensorflow/tensorboard/components/tf_image_dashboard/tf-image-loader.html +++ b/tensorflow/tensorboard/components/tf_image_dashboard/tf-image-loader.html @@ -16,7 +16,6 @@ limitations under the License. --> <link rel="import" href="../polymer/polymer.html"> -<link rel="import" href="../paper-slider/paper-slider.html"> <link rel="import" href="../tf-imports/lodash.html"> <link rel="import" href="../tf-imports/d3.html"> @@ -29,29 +28,15 @@ future for loading older images. <dom-module id="tf-image-loader"> <template> <div id="image-annotation"> - <template is="dom-if" if="[[_hasAtLeastOneStep]]"> - step - <span class="step-value"> - [[_stepValue]] - </span> - <template is="dom-if" if="[[_currentWallTime]]"> - ([[_currentWallTime]]) - </template> - </template> - <template is="dom-if" if="[[_hasMultipleSteps]]"> - <paper-slider - id="steps" - immediate-value="{{_stepIndex}}" - max="[[_maxStepIndex]]" - max-markers="[[_maxStepIndex]]" - snaps - step="1" - value="{{_stepIndex}}"></paper-slider> + step [[step]] + <template is="dom-if" if="[[wallTime]]"> + ([[wallTime]]) </template> </div> - - <img id="img" src=""> - + <img + id="img" + src="[[imageUrl]]" + on-error="reload"> <style> :host { display: block; @@ -67,21 +52,7 @@ future for loading older images. margin: -10px 0 10px 0; } - #image-annotation .step-value { - font-weight: bold; - } - - #steps { - height: 15px; - margin: 0 0 0 -15px; - /* 31 comes from adding a padding of 15px from both sides of the paper-slider, subtracting - * 1px so that the slider width aligns with the image (the last slider marker takes up 1px), - * and adding 2px to account for a border of 1px on both sides of the image. 30 - 1 + 2. */ - width: calc(100% + 31px); - } - img { - border: 1px solid #f5f5f5; image-rendering: -moz-crisp-edges; image-rendering: pixelated; display: block; @@ -95,106 +66,33 @@ future for loading older images. is: "tf-image-loader", properties: { colorScale: Object, + imageUrl: String, run: String, - // This is an array of Tensorboard Image&Datum objects (See backend.ts for details). The - // properties of objects in this array are - // { - // width: number, - // height: number, - // wall_time: Date, - // step: number, - // url: string, - // } - _steps: { - type: Array, - value: [], - notify: true, - }, - _stepIndex: { - type: Number, - notify: true, - }, - _hasAtLeastOneStep: { - type: Boolean, - computed: "_computeHasAtLeastOneStep(_steps)", - }, - _hasMultipleSteps: { - type: Boolean, - computed: "_computeHasMultipleSteps(_steps)", - }, - _stepValue: { - type: Number, - computed: "_computeStepValue(_stepIndex)", - }, - _currentWallTime: { - type: Number, - computed: "_computeCurrentWallTime(_stepIndex)", - }, - _maxStepIndex: { - type: Number, - computed: "_computeMaxStepIndex(_steps)", - }, - }, - observers: [ - "_onStepIndexChanged(_stepIndex)", - ], - redraw: function() { - // Other dashboards logic requires a redraw method to be defined. redraw is called at - // various places such as when the image is expanded. - this.setSeriesData(this.run, this._steps); + step: Number, + wallTime: String, }, setVisibleSeries: function(runs) { // Do nothing. }, - setSeriesData: function(run, steps) { - this.set("run", run); - this.set("_steps", steps); - this.set("_stepIndex", steps.length - 1); + setSeriesData: function(run, data) { + var last = _.last(data); + this.redraw(last); // Update the border color based on the run. - var color = this.colorScale.scale(run); - this.$$("#image-annotation").style.borderColor = color; + this.$$('#image-annotation').style.borderColor = this.colorScale.scale(run); + }, + redraw: function(imageData) { + var url = imageData.url || this.imageUrl; + this.imageUrl = ""; // Force redraw + this.imageUrl = url; - // Set the color for the slider that lets the user select a step. - // These values should all be changed from their defaults set by paper-slider. - var mixins = [ - "--paper-slider-active-color", - "--paper-slider-secondary-color", - "--paper-slider-knob-color", - "--paper-slider-knob-start-color", - "--paper-slider-knob-start-border-color", - ]; - - for (var i = 0; i < mixins.length; i++) { - this.customStyle[mixins[i]] = color; + // Update the step if the value fetched is a valid number >= 0 + // (not null, NaN, etc). + this.step = imageData.step >= 0 ? imageData.step : this.step; + if (imageData.wall_time) { + this.wallTime = imageData.wall_time.toString(); } }, - _onStepIndexChanged: function(stepIndex) { - // We manually change the image URL (instead of binding to the image's src attribute) - // because we would like to clear the image URL before setting the src to the new URL. If - // we avoid doing that, the user might be misled into believing that the new image has - // finished loading (and that it looks identical to the previous image). - if (!this._steps.length) { - return; - } - this.$.img.src = ""; - this.$.img.src = this._steps[stepIndex].url; - }, - _computeHasAtLeastOneStep: function(steps) { - return steps.length > 0; - }, - _computeHasMultipleSteps: function(steps) { - return steps.length > 1; - }, - _computeStepValue: function(stepIndex) { - return this._steps[stepIndex].step; - }, - _computeCurrentWallTime: function(stepIndex) { - return this._steps[stepIndex].wall_time.toString(); - }, - _computeMaxStepIndex: function(steps) { - return steps.length - 1; - }, }); </script> </dom-module> diff --git a/tensorflow/tools/tfprof/README.md b/tensorflow/tools/tfprof/README.md index 865a21d6a09..02eca8af6a2 100644 --- a/tensorflow/tools/tfprof/README.md +++ b/tensorflow/tools/tfprof/README.md @@ -152,7 +152,7 @@ tfprof> -min_float_ops 0 -device_regexes .* -order_by name --account_type_regexes Variable +-account_type_regexes Variable,VariableV2 -start_name_regexes .* -trim_name_regexes -show_name_regexes .* diff --git a/tensorflow/tools/tfprof/tfprof_main.cc b/tensorflow/tools/tfprof/tfprof_main.cc index 92e9510ea82..a8ed6e38132 100644 --- a/tensorflow/tools/tfprof/tfprof_main.cc +++ b/tensorflow/tools/tfprof/tfprof_main.cc @@ -75,7 +75,7 @@ int main(int argc, char** argv) { tensorflow::int64 FLAGS_min_float_ops = 0; tensorflow::string FLAGS_device_regexes = ".*"; tensorflow::string FLAGS_order_by = "name"; - tensorflow::string FLAGS_account_type_regexes = "Variable"; + tensorflow::string FLAGS_account_type_regexes = "Variable,VariableV2"; tensorflow::string FLAGS_start_name_regexes = ".*"; tensorflow::string FLAGS_trim_name_regexes = ""; tensorflow::string FLAGS_show_name_regexes = ".*"; diff --git a/tools/bazel.rc.template b/tools/bazel.rc.template index a1fb55632cd..3f087956701 100644 --- a/tools/bazel.rc.template +++ b/tools/bazel.rc.template @@ -27,3 +27,7 @@ run --spawn_strategy=standalone build --genrule_strategy=standalone test --genrule_strategy=standalone run --genrule_strategy=standalone + +build -c opt +test -c opt +run -c opt