From 80a5a3e653f3b10e2680fe2ea9bc511e8801e273 Mon Sep 17 00:00:00 2001 From: Vijay Vasudevan Date: Tue, 29 Mar 2016 18:23:11 -0800 Subject: [PATCH] Merge changes from github. Change: 118532471 --- WORKSPACE | 4 + configure | 42 ++- tensorflow/contrib/lookup/lookup_ops.py | 2 +- tensorflow/contrib/skflow/python/__init__.py | 2 +- .../skflow/python/skflow/estimators/base.py | 17 +- .../skflow/python/skflow/ops/dnn_ops.py | 4 +- .../skflow/preprocessing/categorical.py | 2 +- tensorflow/core/BUILD | 206 +++++++++++++- .../common_runtime/direct_session_test.cc | 24 +- .../core/common_runtime/gpu/gpu_init.cc | 2 +- tensorflow/core/distributed_runtime/BUILD | 12 +- tensorflow/core/distributed_runtime/README.md | 6 +- tensorflow/core/distributed_runtime/rpc/BUILD | 16 +- tensorflow/core/kernels/BUILD | 264 ++++++++++++++++-- tensorflow/examples/skflow/README.md | 10 +- tensorflow/examples/skflow/iris.py | 1 + .../examples/skflow/iris_custom_decay_dnn.py | 1 + .../examples/skflow/iris_with_pipeline.py | 2 +- .../tutorials/mnist/fully_connected_feed.py | 3 +- tensorflow/examples/udacity/1_notmnist.ipynb | 11 +- .../examples/udacity/2_fullyconnected.ipynb | 4 +- tensorflow/examples/udacity/5_word2vec.ipynb | 5 +- tensorflow/g3doc/get_started/os_setup.md | 6 +- .../image/cifar10/cifar10_multi_gpu_train.py | 2 +- tensorflow/models/rnn/translate/data_utils.py | 29 +- tensorflow/models/rnn/translate/translate.py | 4 +- tensorflow/python/BUILD | 12 +- tensorflow/python/client/device_lib_test.py | 3 +- tensorflow/python/ops/math_ops.py | 2 +- tensorflow/python/ops/nn_ops.py | 14 +- tensorflow/python/ops/rnn_cell.py | 4 +- .../python/summary/event_accumulator.py | 6 +- .../tf-graph-common/lib/scene/edge.ts | 10 +- tensorflow/tensorflow.bzl | 26 ++ tensorflow/tools/ci_build/builds/benchmark.sh | 155 ++++++++++ .../tools/ci_build/builds/with_the_same_user | 2 +- .../tools/ci_build/ci_parameterized_build.sh | 51 +++- .../ci_build/install/install_pip_packages.sh | 6 + .../docker/notebooks/2_getting_started.ipynb | 10 +- tensorflow/tools/swig/swig.sh | 10 +- tensorflow/tools/test/BUILD | 16 +- tensorflow/tools/test/performance.bzl | 56 ++++ tensorflow/tools/test/run_and_gather_logs.py | 3 +- .../tools/test/run_and_gather_logs_lib.py | 34 ++- tensorflow/workspace.bzl | 8 +- third_party/gpus/crosstool/CROSSTOOL | 4 + .../bin/crosstool_wrapper_driver_is_not_gcc | 8 +- 47 files changed, 975 insertions(+), 146 deletions(-) create mode 100755 tensorflow/tools/ci_build/builds/benchmark.sh create mode 100644 tensorflow/tools/test/performance.bzl diff --git a/WORKSPACE b/WORKSPACE index 981ac77ea46..9684883ffe1 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -16,6 +16,10 @@ load("//tensorflow:workspace.bzl", "tf_workspace") tf_workspace() +# Specify the minimum required bazel version. +load("//tensorflow:tensorflow.bzl", "check_version") +check_version("0.1.4") + # TENSORBOARD_BOWER_AUTOGENERATED_BELOW_THIS_LINE_DO_NOT_EDIT new_git_repository( diff --git a/configure b/configure index 0faf61c67b1..0a7d697c406 100755 --- a/configure +++ b/configure @@ -1,5 +1,7 @@ #!/usr/bin/env bash +DO_NOT_SUBMIT_WARNING="Unofficial setting. DO NOT SUBMIT!!!" + ## Set up python-related environment settings while true; do fromuser="" @@ -22,6 +24,16 @@ while true; do # Retry done +## Find swig path +if [ -z "$SWIG_PATH" ]; then + SWIG_PATH=`type -p swig 2> /dev/null` +fi +if [[ ! -e "$SWIG_PATH" ]]; then + echo "Can't find swig. Ensure swig is in \$PATH or set \$SWIG_PATH." + exit 1 +fi +echo "$SWIG_PATH" > tensorflow/tools/swig/swig_path + # Invoke python_config and set up symlinks to python includes (./util/python/python_config.sh --setup "$PYTHON_BIN_PATH";) || exit -1 @@ -42,6 +54,29 @@ if [ "$TF_NEED_CUDA" == "0" ]; then exit fi +# Set up which gcc nvcc should use as the host compiler +while true; do + fromuser="" + if [ -z "$GCC_HOST_COMPILER_PATH" ]; then + default_gcc_host_compiler_path=$(which gcc) + read -p "Please specify which gcc nvcc should use as the host compiler. [Default is $default_gcc_host_compiler_path]: " GCC_HOST_COMPILER_PATH + fromuser="1" + if [ -z "$GCC_HOST_COMPILER_PATH" ]; then + GCC_HOST_COMPILER_PATH=$default_gcc_host_compiler_path + fi + fi + if [ -e "$GCC_HOST_COMPILER_PATH" ]; then + break + fi + echo "Invalid gcc path. ${GCC_HOST_COMPILER_PATH} cannot be found" 1>&2 + if [ -z "$fromuser" ]; then + exit 1 + fi + GCC_HOST_COMPILER_PATH="" + # Retry +done + + # Find out where the CUDA toolkit is installed while true; do # Configure the Cuda SDK version to use. @@ -136,6 +171,11 @@ TF_CUDNN_VERSION=$TF_CUDNN_EXT EOF +# Configure the gcc host compiler to use +export WARNING=$DO_NOT_SUBMIT_WARNING +perl -pi -e "s,CPU_COMPILER = \('.*'\),# \$ENV{WARNING}\nCPU_COMPILER = ('$GCC_HOST_COMPILER_PATH'),s" third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc +perl -pi -e "s,GCC_HOST_COMPILER_PATH = \('.*'\),# \$ENV{WARNING}\nGCC_HOST_COMPILER_PATH = ('$GCC_HOST_COMPILER_PATH'),s" third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc + # Configure the Cuda toolkit version to work with. perl -pi -e "s,CUDA_VERSION = \"[0-9\.]*\",CUDA_VERSION = \"$TF_CUDA_EXT\",s" tensorflow/core/platform/default/build_config.bzl perl -pi -e "s,(GetCudaVersion.*return )\"[0-9\.]*\",\1\"$TF_CUDA_EXT\",s" tensorflow/stream_executor/dso_loader.cc @@ -178,7 +218,7 @@ EOF done if [ ! -z "$TF_CUDA_COMPUTE_CAPABILITIES" ]; then - export WARNING="Unofficial setting. DO NOT"" SUBMIT!!!" + export WARNING=$DO_NOT_SUBMIT_WARNING function CudaGenCodeOpts() { OUTPUT="" for CAPABILITY in $@; do diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index acc3fc26865..3aeab8d86c0 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -391,7 +391,7 @@ def index_to_string(tensor, mapping, default_value="UNK", name=None): ``` Args: - indices: A `int64` `Tensor` with the indices to map to strings. + tensor: A `int64` `Tensor` with the indices to map to strings. mapping: A 1-D string `Tensor` that specifies the strings to map from indices. default_value: The string value to use for out-of-vocabulary indices. diff --git a/tensorflow/contrib/skflow/python/__init__.py b/tensorflow/contrib/skflow/python/__init__.py index 093f79da1ab..f3fc752ca2c 100644 --- a/tensorflow/contrib/skflow/python/__init__.py +++ b/tensorflow/contrib/skflow/python/__init__.py @@ -16,4 +16,4 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from skflow import * +from skflow import * diff --git a/tensorflow/contrib/skflow/python/skflow/estimators/base.py b/tensorflow/contrib/skflow/python/skflow/estimators/base.py index 646d9052695..a8e0a118d39 100644 --- a/tensorflow/contrib/skflow/python/skflow/estimators/base.py +++ b/tensorflow/contrib/skflow/python/skflow/estimators/base.py @@ -268,9 +268,14 @@ class TensorFlowEstimator(BaseEstimator): """ return self.fit(X, y) - def _predict(self, X, axis=-1, batch_size=-1): + def _predict(self, X, axis=-1, batch_size=None): if not self._initialized: raise NotFittedError() + + # Use the batch size for fitting if the user did not specify one. + if batch_size is None: + batch_size = self.batch_size + self._graph.add_to_collection("IS_TRAINING", False) predict_data_feeder = setup_predict_data_feeder( X, batch_size=batch_size) @@ -289,7 +294,7 @@ class TensorFlowEstimator(BaseEstimator): return np.concatenate(preds, axis=0) - def predict(self, X, axis=1, batch_size=-1): + def predict(self, X, axis=1, batch_size=None): """Predict class or regression for X. For a classification model, the predicted class for each sample in X is @@ -302,7 +307,8 @@ class TensorFlowEstimator(BaseEstimator): By default axis 1 (next after batch) is used. Use 2 for sequence predictions. batch_size: If test set is too big, use batch size to split - it into mini batches. By default full dataset is used. + it into mini batches. By default the batch_size member + variable is used. Returns: y: array of shape [n_samples]. The predicted classes or predicted @@ -310,13 +316,14 @@ class TensorFlowEstimator(BaseEstimator): """ return self._predict(X, axis=axis, batch_size=batch_size) - def predict_proba(self, X, batch_size=-1): + def predict_proba(self, X, batch_size=None): """Predict class probability of the input samples X. Args: X: array-like matrix, [n_samples, n_features...] or iterator. batch_size: If test set is too big, use batch size to split - it into mini batches. By default full dataset is used. + it into mini batches. By default the batch_size + member variable is used. Returns: y: array of shape [n_samples, n_classes]. The predicted diff --git a/tensorflow/contrib/skflow/python/skflow/ops/dnn_ops.py b/tensorflow/contrib/skflow/python/skflow/ops/dnn_ops.py index e5b6ea767da..92f2cd2ee09 100644 --- a/tensorflow/contrib/skflow/python/skflow/ops/dnn_ops.py +++ b/tensorflow/contrib/skflow/python/skflow/ops/dnn_ops.py @@ -25,10 +25,10 @@ def dnn(tensor_in, hidden_units, activation=tf.nn.relu, keep_prob=None): """Creates fully connected deep neural network subgraph. Args: - tenson_in: tensor or placeholder for input features. + tensor_in: tensor or placeholder for input features. hidden_units: list of counts of hidden units in each layer. activation: activation function between layers. Can be None. - keep_proba: if not None, will add a dropout layer with given + keep_prob: if not None, will add a dropout layer with given probability. Returns: diff --git a/tensorflow/contrib/skflow/python/skflow/preprocessing/categorical.py b/tensorflow/contrib/skflow/python/skflow/preprocessing/categorical.py index 9adff65c48b..f898694f390 100644 --- a/tensorflow/contrib/skflow/python/skflow/preprocessing/categorical.py +++ b/tensorflow/contrib/skflow/python/skflow/preprocessing/categorical.py @@ -57,7 +57,7 @@ class CategoricalProcessor(object): """Learn a vocabulary dictionary of all categories in X. Args: - raw_documents: numpy matrix or iterable of lists/numpy arrays. + X: numpy matrix or iterable of lists/numpy arrays. unused_y: to match fit format signature of estimators. Returns: diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 1b493f2ef5b..8bd358c89e3 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -46,6 +46,7 @@ package(default_visibility = ["//tensorflow:internal"]) licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_copts") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_tests") load("//tensorflow:tensorflow.bzl", "tf_cuda_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") @@ -1161,13 +1162,18 @@ tf_cc_tests( # TODO(opensource): fix "common_runtime/gpu/*_test.cc", # Run by tests below + "common_runtime/constant_folding_test.cc", + "common_runtime/direct_session_test.cc", + "common_runtime/function_test.cc", "common_runtime/gpu/gpu_allocator_retry_test.cc", "common_runtime/gpu/gpu_bfc_allocator_test.cc", "common_runtime/gpu/gpu_region_allocator_test.cc", + "framework/op_segment_test.cc", + "ops/array_grad_test.cc", + "ops/math_grad_test.cc", ], ), deps = [ - ":all_kernels", ":core", ":core_cpu", ":core_cpu_internal", @@ -1200,10 +1206,10 @@ tf_cc_tests( exclude = [ # Run by tests below "common_runtime/gpu/gpu_allocator_retry_test.cc", + "common_runtime/gpu/gpu_stream_util_test.cc", ], ), deps = [ - ":all_kernels", ":core_cpu", ":core_cpu_internal", ":direct_session", @@ -1221,13 +1227,96 @@ tf_cc_tests( ], ) -tf_cc_tests( +tf_cc_test( + name = "common_runtime/constant_folding_test", + size = "small", + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + ":framework", + ":framework_internal", + ":lib", + ":lib_internal", + ":ops", + ":protos_all_cc", + ":test", + ":test_main", + ":testlib", + "//tensorflow/cc:cc_ops", + "//tensorflow/core/kernels:bcast_ops", + "//tensorflow/core/kernels:identity_op", + "//tensorflow/core/kernels:matmul_op", + "//third_party/eigen3", + ], +) + +tf_cc_test( + name = "common_runtime/direct_session_test", + size = "small", + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + ":framework", + ":framework_internal", + ":lib", + ":lib_internal", + ":ops", + ":protos_all_cc", + ":test", + ":test_main", + ":testlib", + "//tensorflow/cc:cc_ops", + "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:dense_update_ops", + "//tensorflow/core/kernels:fifo_queue_op", + "//tensorflow/core/kernels:identity_op", + "//tensorflow/core/kernels:matmul_op", + "//tensorflow/core/kernels:ops_util", + "//tensorflow/core/kernels:queue_ops", + "//tensorflow/core/kernels:variable_ops", + "//third_party/eigen3", + ], +) + +tf_cc_test( + name = "common_runtime/function_test", + size = "small", + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + ":framework", + ":framework_internal", + ":lib", + ":lib_internal", + ":ops", + ":protos_all_cc", + ":test", + ":test_main", + ":testlib", + "//tensorflow/cc:cc_ops", + "//tensorflow/core/kernels:cast_op", + "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:matmul_op", + "//tensorflow/core/kernels:shape_ops", + "//third_party/eigen3", + ], +) + +tf_cc_test( + name = "common_runtime/gpu/gpu_allocator_retry_test.cc", size = "medium", linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + ["nomac"], - tests = ["common_runtime/gpu/gpu_allocator_retry_test.cc"], deps = [ - ":all_kernels", ":core_cpu", ":core_cpu_internal", ":direct_session", @@ -1244,6 +1333,113 @@ tf_cc_tests( ], ) +tf_cc_test( + name = "common_runtime/gpu/gpu_stream_util_test.cc", + size = "small", + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags() + ["nomac"], + deps = [ + ":core_cpu", + ":core_cpu_internal", + ":direct_session", + ":framework", + ":framework_internal", + ":gpu_runtime", + ":lib", + ":lib_internal", + ":protos_all_cc", + ":test", + ":test_main", + ":testlib", + "//tensorflow/cc:cc_ops", + "//tensorflow/core/kernels:matmul_op", + ], +) + +tf_cc_test( + name = "framework/op_segment_test", + size = "small", + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + ":framework", + ":framework_internal", + ":lib", + ":lib_internal", + ":ops", + ":protos_all_cc", + ":test", + ":test_main", + ":testlib", + "//tensorflow/cc:cc_ops", + "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:ops_util", + "//third_party/eigen3", + ], +) + +tf_cc_test( + name = "ops/array_grad_test", + size = "small", + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + ":framework", + ":framework_internal", + ":lib", + ":lib_internal", + ":ops", + ":protos_all_cc", + ":test", + ":test_main", + ":testlib", + "//tensorflow/cc:cc_ops", + "//tensorflow/core/kernels:array", + "//tensorflow/core/kernels:cwise_op", + "//third_party/eigen3", + ], +) + +tf_cc_test( + name = "ops/math_grad_test", + size = "small", + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + ":framework", + ":framework_internal", + ":lib", + ":lib_internal", + ":ops", + ":protos_all_cc", + ":test", + ":test_main", + ":testlib", + "//tensorflow/cc:cc_ops", + "//tensorflow/core/kernels:bcast_ops", + "//tensorflow/core/kernels:cast_op", + "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:dynamic_stitch_op", + "//tensorflow/core/kernels:identity_op", + "//tensorflow/core/kernels:matmul_op", + "//tensorflow/core/kernels:reduction_ops", + "//tensorflow/core/kernels:reshape_op", + "//tensorflow/core/kernels:sequence_ops", + "//tensorflow/core/kernels:shape_ops", + "//tensorflow/core/kernels:tile_ops", + "//third_party/eigen3", + ], +) + # Test data filegroup( name = "image_testdata", diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc index 59f0dd3fff2..1495b832504 100644 --- a/tensorflow/core/common_runtime/direct_session_test.cc +++ b/tensorflow/core/common_runtime/direct_session_test.cc @@ -151,7 +151,7 @@ TEST_F(DirectSessionMinusAXTest, TestConcurrency) { std::vector outputs; // Run the graph Status s = session->Run(inputs, output_names, {}, &outputs); - ASSERT_TRUE(s.ok()); + TF_ASSERT_OK(s); ASSERT_EQ(1, outputs.size()); auto mat = outputs[0].matrix(); EXPECT_FLOAT_EQ(3.0, mat(0, 0)); @@ -188,7 +188,7 @@ TEST_F(DirectSessionMinusAXTest, TestPerSessionThreads) { std::vector outputs; // Run the graph Status s = session->Run(inputs, output_names, {}, &outputs); - ASSERT_TRUE(s.ok()); + TF_ASSERT_OK(s); ASSERT_EQ(1, outputs.size()); auto mat = outputs[0].matrix(); EXPECT_FLOAT_EQ(3.0, mat(0, 0)); @@ -358,7 +358,7 @@ TEST(DirectSessionTest, MultipleFeedTest) { Status s = session->Run( {}, {first_identity->name() + ":0", second_identity->name() + ":0"}, {}, &outputs); - ASSERT_TRUE(s.ok()); + TF_ASSERT_OK(s); ASSERT_EQ(2, outputs.size()); ASSERT_EQ(1.0, outputs[0].flat()(0)); ASSERT_EQ(2.0, outputs[1].flat()(0)); @@ -366,7 +366,7 @@ TEST(DirectSessionTest, MultipleFeedTest) { s = session->Run( {}, {second_identity->name() + ":0", first_identity->name() + ":0"}, {}, &outputs); - ASSERT_TRUE(s.ok()); + TF_ASSERT_OK(s); ASSERT_EQ(2, outputs.size()); ASSERT_EQ(2.0, outputs[0].flat()(0)); ASSERT_EQ(1.0, outputs[1].flat()(0)); @@ -381,7 +381,7 @@ TEST(DirectSessionTest, MultipleFeedTest) { {{first_const->name(), value_11}, {second_const->name(), value_22}}, {first_identity->name() + ":0", second_identity->name() + ":0"}, {}, &outputs); - ASSERT_TRUE(s.ok()); + TF_ASSERT_OK(s); ASSERT_EQ(2, outputs.size()); ASSERT_EQ(11.0, outputs[0].flat()(0)); ASSERT_EQ(22.0, outputs[1].flat()(0)); @@ -391,7 +391,7 @@ TEST(DirectSessionTest, MultipleFeedTest) { {{second_const->name(), value_22}, {first_const->name(), value_11}}, {first_identity->name() + ":0", second_identity->name() + ":0"}, {}, &outputs); - ASSERT_TRUE(s.ok()); + TF_ASSERT_OK(s); ASSERT_EQ(2, outputs.size()); ASSERT_EQ(11.0, outputs[0].flat()(0)); ASSERT_EQ(22.0, outputs[1].flat()(0)); @@ -462,7 +462,7 @@ TEST(DirectSessionTest, PartialRunTest) { {first_identity->name() + ":0", second_identity->name() + ":0", third_identity->name() + ":0"}, {}, &handle); - ASSERT_TRUE(s.ok()); + TF_ASSERT_OK(s); Tensor value_11(DT_FLOAT, TensorShape({})); value_11.scalar()() = 11.0; @@ -472,7 +472,7 @@ TEST(DirectSessionTest, PartialRunTest) { // Feed first_const, fetch first_identity s = session->PRun(handle, {{first_const->name(), value_11}}, {first_identity->name() + ":0"}, &outputs); - ASSERT_TRUE(s.ok()); + TF_ASSERT_OK(s); ASSERT_EQ(1, outputs.size()); ASSERT_EQ(11.0, outputs[0].flat()(0)); @@ -481,7 +481,7 @@ TEST(DirectSessionTest, PartialRunTest) { handle, {{second_const->name(), value_22}}, {second_identity->name() + ":0", third_identity->name() + ":0"}, &outputs); - ASSERT_TRUE(s.ok()); + TF_ASSERT_OK(s); ASSERT_EQ(2, outputs.size()); ASSERT_EQ(22.0, outputs[0].flat()(0)); ASSERT_EQ(11.0 + 22.0, outputs[1].flat()(0)); @@ -515,7 +515,7 @@ TEST(DirectSessionTest, PartialRunMissingFeed) { string handle; Status s = session->PRunSetup({first_const->name(), second_const->name()}, {third_identity->name() + ":0"}, {}, &handle); - ASSERT_TRUE(s.ok()); + TF_ASSERT_OK(s); // Feed first_const, fetch third_identity Tensor value_11(DT_FLOAT, TensorShape({})); @@ -548,7 +548,7 @@ TEST(DirectSessionTest, PartialRunMultiOutputFeed) { string handle; Status s = session->PRunSetup({switch_node->name() + ":1"}, {fourth_identity->name() + ":0"}, {}, &handle); - ASSERT_TRUE(s.ok()); + TF_ASSERT_OK(s); // Fetch fourth_identity without feeds. s = session->PRun(handle, {}, {fourth_identity->name() + ":0"}, &outputs); @@ -559,7 +559,7 @@ TEST(DirectSessionTest, PartialRunMultiOutputFeed) { // Feed switch_node:1 and fetch fourth_identity. s = session->PRun(handle, {{switch_node->name() + ":1", bool_value}}, {fourth_identity->name() + ":0"}, &outputs); - ASSERT_TRUE(s.ok()); + TF_ASSERT_OK(s); ASSERT_EQ(1, outputs.size()); ASSERT_EQ(true, outputs[0].flat()(0)); } diff --git a/tensorflow/core/common_runtime/gpu/gpu_init.cc b/tensorflow/core/common_runtime/gpu/gpu_init.cc index 96816fd9cbe..ddb007115f0 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_init.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_init.cc @@ -77,7 +77,7 @@ static void InitGPU() { int dev_count = platform->VisibleDeviceCount(); - if (dev_count == 0) { + if (dev_count <= 0) { LOG(INFO) << "No GPU devices available on machine."; return; } diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index b494854accd..18428aa6e79 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -95,7 +95,6 @@ cc_library( ":worker_interface", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", - "//tensorflow/core:tensorflow_opensource", "//tensorflow/core:worker_proto_cc", ], ) @@ -125,7 +124,6 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:master_proto_cc", - "//tensorflow/core:tensorflow_opensource", "//tensorflow/core:worker_proto_cc", ], ) @@ -205,7 +203,6 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", - "//tensorflow/core:tensorflow_opensource", ], ) @@ -227,7 +224,6 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensorflow_opensource", "//tensorflow/core:worker_proto_cc", ], ) @@ -240,7 +236,6 @@ cc_library( "//tensorflow/core:core_cpu", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:tensorflow_opensource", ], ) @@ -306,7 +301,6 @@ tf_cc_tests( "//tensorflow/core:master_proto_cc", "//tensorflow/core:master_service_proto_cc", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensorflow_opensource", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", @@ -314,6 +308,11 @@ tf_cc_tests( "//tensorflow/core/distributed_runtime/rpc:grpc_testlib", "//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", + "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:dense_update_ops", + "//tensorflow/core/kernels:identity_op", + "//tensorflow/core/kernels:variable_ops", ], ) @@ -339,7 +338,6 @@ tf_cc_tests( "//tensorflow/core:master_proto_cc", "//tensorflow/core:master_service_proto_cc", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensorflow_opensource", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", diff --git a/tensorflow/core/distributed_runtime/README.md b/tensorflow/core/distributed_runtime/README.md index c7bd8164e3e..ab1771e2942 100644 --- a/tensorflow/core/distributed_runtime/README.md +++ b/tensorflow/core/distributed_runtime/README.md @@ -5,6 +5,6 @@ distributed TensorFlow runtime, using [gRPC](http://grpc.io) for inter-process communication. To learn how to use the distributed runtime to create a TensorFlow cluster, -see the "Distributed TensorFlow" How To, which is available both [in this -repository](https://www.tensorflow.org/code/tensorflow/g3doc/how_tos/distributed/index.md) and [on the TensorFlow website] -(https://www.tensorflow.org/how_tos/distributed/index.html). +see the "Distributed TensorFlow" How To, which is available [in this +repository](../../g3doc/how_tos/distributed/index.md), and will be available +on the TensorFlow website after the next version is released. diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index 32ca16b800e..beddf03ffa4 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -143,7 +143,6 @@ cc_library( "//tensorflow/core:gpu_runtime", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:tensorflow_opensource", "//tensorflow/core:worker_proto_cc", "//tensorflow/core:worker_service_proto_cc", "//tensorflow/core/distributed_runtime:graph_mgr", @@ -197,7 +196,6 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core:tensorflow_opensource", "//tensorflow/core/distributed_runtime:base_rendezvous_mgr", "//tensorflow/core/distributed_runtime:process_util", "//tensorflow/core/distributed_runtime:worker_cache", @@ -258,7 +256,6 @@ tf_cuda_library( srcs = ["grpc_testlib_ops.cc"], linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel deps = [ - "//tensorflow/core:all_kernels", "//tensorflow/core:framework", "//tensorflow/core:lib", ], @@ -279,6 +276,13 @@ cc_binary( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/kernels:constant_op", + "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:dense_update_ops", + "//tensorflow/core/kernels:identity_op", + "//tensorflow/core/kernels:matmul_op", + "//tensorflow/core/kernels:reduction_ops", + "//tensorflow/core/kernels:variable_ops", ], ) @@ -297,7 +301,6 @@ tf_cuda_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensorflow_opensource", "//tensorflow/core:test", ], alwayslink = 1, @@ -316,7 +319,6 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:master_proto_cc", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensorflow", "//tensorflow/core/distributed_runtime:call_options", "//tensorflow/core/distributed_runtime:master_interface", ], @@ -373,5 +375,9 @@ tf_cc_tests( "//tensorflow/core:testlib", "//tensorflow/core/distributed_runtime:process_util", "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/kernels:constant_op", + "//tensorflow/core/kernels:dense_update_ops", + "//tensorflow/core/kernels:matmul_op", + "//tensorflow/core/kernels:variable_ops", ], ) diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 6337d39bcf1..a4ec276ac15 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -301,21 +301,12 @@ tf_kernel_libraries( ], ) -tf_cc_tests( +tf_cc_test( + name = "concat_op_test", size = "small", linkstatic = tf_kernel_tests_linkstatic(), # Required for benchmarking - tests = [ - "concat_op_test", - "constant_op_test", - "gather_nd_op_test", - "gather_op_test", - "identity_op_test", - "reverse_op_test", - "slice_op_test", - "unique_op_test", - ], deps = [ - ":array", + ":concat_op", ":ops_testutil", ":ops_util", "//tensorflow/core:core_cpu", @@ -329,6 +320,120 @@ tf_cc_tests( ], ) +tf_cc_test( + name = "constant_op_test", + size = "small", + linkstatic = tf_kernel_tests_linkstatic(), # Required for benchmarking + deps = [ + ":constant_op", + ":ops_testutil", + ":ops_util", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +tf_cc_test( + name = "gather_op_test", + size = "small", + linkstatic = tf_kernel_tests_linkstatic(), # Required for benchmarking + deps = [ + ":gather_op", + ":ops_testutil", + ":ops_util", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +tf_cc_test( + name = "identity_op_test", + size = "small", + linkstatic = tf_kernel_tests_linkstatic(), # Required for benchmarking + deps = [ + ":identity_op", + ":ops_testutil", + ":ops_util", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +tf_cc_test( + name = "reverse_op_test", + size = "small", + linkstatic = tf_kernel_tests_linkstatic(), # Required for benchmarking + deps = [ + ":ops_testutil", + ":ops_util", + ":reverse_op", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +tf_cc_test( + name = "slice_op_test", + size = "small", + linkstatic = tf_kernel_tests_linkstatic(), # Required for benchmarking + deps = [ + ":ops_testutil", + ":ops_util", + ":slice_op", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +tf_cc_test( + name = "unique_op_test", + size = "small", + linkstatic = tf_kernel_tests_linkstatic(), # Required for benchmarking + deps = [ + ":ops_testutil", + ":ops_util", + ":unique_op", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_kernel_library( name = "transpose_functor", srcs = ["transpose_functor_cpu.cc"], @@ -756,20 +861,12 @@ tf_kernel_libraries( ], ) -tf_cc_tests( +tf_cc_test( + name = "cast_op_test", size = "small", linkstatic = tf_kernel_tests_linkstatic(), # Required for benchmarking - tests = [ - "cast_op_test", - "cross_op_test", - "cwise_ops_test", - "matmul_op_test", - "reduction_ops_test", - "segment_reduction_ops_test", - "sparse_matmul_op_test", - ], deps = [ - ":math", + ":cast_op", ":ops_testutil", ":ops_util", "//tensorflow/core:core_cpu", @@ -783,21 +880,136 @@ tf_cc_tests( ], ) +tf_cc_test( + name = "cross_op_test", + size = "small", + linkstatic = tf_kernel_tests_linkstatic(), # Required for benchmarking + deps = [ + ":cross_op", + ":ops_testutil", + ":ops_util", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +tf_cc_test( + name = "cwise_ops_test", + size = "small", + linkstatic = tf_kernel_tests_linkstatic(), # Required for benchmarking + deps = [ + ":cwise_op", + ":ops_testutil", + ":ops_util", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +tf_cc_test( + name = "matmul_op_test", + size = "small", + linkstatic = tf_kernel_tests_linkstatic(), # Required for benchmarking + deps = [ + ":matmul_op", + ":ops_testutil", + ":ops_util", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +tf_cc_test( + name = "reduction_ops_test", + size = "small", + linkstatic = tf_kernel_tests_linkstatic(), # Required for benchmarking + deps = [ + ":ops_testutil", + ":ops_util", + ":reduction_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +tf_cc_test( + name = "segment_reduction_ops_test", + size = "small", + linkstatic = tf_kernel_tests_linkstatic(), # Required for benchmarking + deps = [ + ":ops_testutil", + ":ops_util", + ":segment_reduction_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +tf_cc_test( + name = "sparse_matmul_op_test", + size = "small", + linkstatic = tf_kernel_tests_linkstatic(), # Required for benchmarking + deps = [ + ":ops_testutil", + ":ops_util", + ":sparse_matmul_op", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_cc_test( name = "immutable_constant_op_test", linkstatic = tf_kernel_tests_linkstatic(), # Required for benchmarking deps = [ ":array", ":immutable_constant_op", - ":math", + ":matmul_op", ":ops_testutil", ":ops_util", + ":random_shuffle_op", "//tensorflow/cc:cc_ops", "//tensorflow/core:core_cpu", + "//tensorflow/core:direct_session", "//tensorflow/core:framework", "//tensorflow/core:lib", - # TODO(irving): Don't depend on all of TensorFlow for this test - "//tensorflow/core:tensorflow", + "//tensorflow/core:ops", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", diff --git a/tensorflow/examples/skflow/README.md b/tensorflow/examples/skflow/README.md index 24b447cd7a4..8bb2739e6c9 100644 --- a/tensorflow/examples/skflow/README.md +++ b/tensorflow/examples/skflow/README.md @@ -1,15 +1,16 @@ # Examples of Using skflow -Scikit Flow is high level API that allows to create, +Scikit Flow is high level API that allows to create, train and use deep learning models easily with well known Scikit Learn API. -To run this exampels you need to have `scikit learn` library installed (`sudo pip install sklearn`). -Some examples use `pandas` library for data processing (`sudo pip install pandas`). +To run these examples, you need to have `scikit learn` library installed (`sudo pip install sklearn`). +Some examples use the `pandas` library for data processing (`sudo pip install pandas`). * [Deep Neural Network Regression with Boston Data](boston.py) * [Convolutional Neural Networks with Digits Data](digits.py) * [Deep Neural Network Classification with Iris Data](iris.py) +* [Grid search and Deep Neural Network Classification](iris_gridsearch_cv.py) * [Deep Neural Network with Customized Decay Function](iris_custom_decay_dnn.py) * [Building A Custom Model](iris_custom_model.py) * [Accessing Weights and Biases in A Custom Model](mnist_weights.py) @@ -30,7 +31,7 @@ Some examples use `pandas` library for data processing (`sudo pip install pandas ## Text classification -* [Text Classification Using Recurrent Neural Networks on Words](text_classification.py) +* [Text Classification Using Recurrent Neural Networks on Words](text_classification.py) (See also [Simplified Version Using Built-in RNN Model](text_classification_builtin_rnn_model.py) using built-in parameters) * [Text Classification Using Convolutional Neural Networks on Words](text_classification_cnn.py) * [Text Classification Using Recurrent Neural Networks on Characters](text_classification_character_rnn.py) @@ -46,4 +47,3 @@ Some examples use `pandas` library for data processing (`sudo pip install pandas * [Character level neural language translation](neural_translation.py) * [Word level neural language translation](neural_translation_word.py) - diff --git a/tensorflow/examples/skflow/iris.py b/tensorflow/examples/skflow/iris.py index ee330e31f40..5b72195f40f 100644 --- a/tensorflow/examples/skflow/iris.py +++ b/tensorflow/examples/skflow/iris.py @@ -32,3 +32,4 @@ classifier = skflow.TensorFlowDNNClassifier(hidden_units=[10, 20, 10], classifier.fit(X_train, y_train) score = metrics.accuracy_score(y_test, classifier.predict(X_test)) print('Accuracy: {0:f}'.format(score)) + diff --git a/tensorflow/examples/skflow/iris_custom_decay_dnn.py b/tensorflow/examples/skflow/iris_custom_decay_dnn.py index 9b0a60d3373..f9c172725d9 100644 --- a/tensorflow/examples/skflow/iris_custom_decay_dnn.py +++ b/tensorflow/examples/skflow/iris_custom_decay_dnn.py @@ -17,6 +17,7 @@ from __future__ import print_function from sklearn import datasets, metrics from sklearn.cross_validation import train_test_split + import tensorflow as tf from tensorflow.contrib import skflow diff --git a/tensorflow/examples/skflow/iris_with_pipeline.py b/tensorflow/examples/skflow/iris_with_pipeline.py index 08c5b2fe54b..f6408f84a8a 100644 --- a/tensorflow/examples/skflow/iris_with_pipeline.py +++ b/tensorflow/examples/skflow/iris_with_pipeline.py @@ -32,7 +32,7 @@ scaler = StandardScaler() # DNN classifier DNNclassifier = skflow.TensorFlowDNNClassifier(hidden_units=[10, 20, 10], n_classes=3, steps=200) -pipeline = Pipeline([('scaler', scaler, ('DNNclassifier', DNNclassifier)]) +pipeline = Pipeline([('scaler', scaler), ('DNNclassifier', DNNclassifier)]) pipeline.fit(X_train, y_train) diff --git a/tensorflow/examples/tutorials/mnist/fully_connected_feed.py b/tensorflow/examples/tutorials/mnist/fully_connected_feed.py index eda1ac5b596..a67055f88f4 100644 --- a/tensorflow/examples/tutorials/mnist/fully_connected_feed.py +++ b/tensorflow/examples/tutorials/mnist/fully_connected_feed.py @@ -19,10 +19,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os.path import time -import numpy from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf @@ -192,6 +190,7 @@ def run_training(): # Update the events file. summary_str = sess.run(summary_op, feed_dict=feed_dict) summary_writer.add_summary(summary_str, step) + summary_writer.flush() # Save a checkpoint and evaluate the model periodically. if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps: diff --git a/tensorflow/examples/udacity/1_notmnist.ipynb b/tensorflow/examples/udacity/1_notmnist.ipynb index 9d864ccd374..22654458154 100644 --- a/tensorflow/examples/udacity/1_notmnist.ipynb +++ b/tensorflow/examples/udacity/1_notmnist.ipynb @@ -55,7 +55,10 @@ "from scipy import ndimage\n", "from sklearn.linear_model import LogisticRegression\n", "from six.moves.urllib.request import urlretrieve\n", - "from six.moves import cPickle as pickle" + "from six.moves import cPickle as pickle\n", + "\n", + "# Config the matlotlib backend as plotting inline in IPython\n", + "%matplotlib inline" ], "outputs": [], "execution_count": 0 @@ -295,9 +298,8 @@ " image_files = os.listdir(folder)\n", " dataset = np.ndarray(shape=(len(image_files), image_size, image_size),\n", " dtype=np.float32)\n", - " image_index = 0\n", " print(folder)\n", - " for image in os.listdir(folder):\n", + " for image_index, image in enumerate(image_files):\n", " image_file = os.path.join(folder, image)\n", " try:\n", " image_data = (ndimage.imread(image_file).astype(float) - \n", @@ -305,11 +307,10 @@ " if image_data.shape != (image_size, image_size):\n", " raise Exception('Unexpected image shape: %s' % str(image_data.shape))\n", " dataset[image_index, :, :] = image_data\n", - " image_index += 1\n", " except IOError as e:\n", " print('Could not read:', image_file, ':', e, '- it\\'s ok, skipping.')\n", " \n", - " num_images = image_index\n", + " num_images = image_index + 1\n", " dataset = dataset[0:num_images, :, :]\n", " if num_images < min_num_images:\n", " raise Exception('Many fewer images than expected: %d < %d' %\n", diff --git a/tensorflow/examples/udacity/2_fullyconnected.ipynb b/tensorflow/examples/udacity/2_fullyconnected.ipynb index c8815f631b5..588b581a69b 100644 --- a/tensorflow/examples/udacity/2_fullyconnected.ipynb +++ b/tensorflow/examples/udacity/2_fullyconnected.ipynb @@ -410,7 +410,7 @@ "source": [ "Let's now switch to stochastic gradient descent training instead, which is much faster.\n", "\n", - "The graph will be similar, except that instead of holding all the training data into a constant node, we create a `Placeholder` node which will be fed actual data at every call of `sesion.run()`." + "The graph will be similar, except that instead of holding all the training data into a constant node, we create a `Placeholder` node which will be fed actual data at every call of `session.run()`." ] }, { @@ -577,7 +577,7 @@ "Problem\n", "-------\n", "\n", - "Turn the logistic regression example with SGD into a 1-hidden layer neural network with rectified linear units (nn.relu()) and 1024 hidden nodes. This model should improve your validation / test accuracy.\n", + "Turn the logistic regression example with SGD into a 1-hidden layer neural network with rectified linear units [nn.relu()](https://www.tensorflow.org/versions/r0.7/api_docs/python/nn.html#relu) and 1024 hidden nodes. This model should improve your validation / test accuracy.\n", "\n", "---" ] diff --git a/tensorflow/examples/udacity/5_word2vec.ipynb b/tensorflow/examples/udacity/5_word2vec.ipynb index 94ba37ee13e..62dbec4e114 100644 --- a/tensorflow/examples/udacity/5_word2vec.ipynb +++ b/tensorflow/examples/udacity/5_word2vec.ipynb @@ -43,6 +43,7 @@ "source": [ "# These are all the modules we'll be using later. Make sure you can import them\n", "# before proceeding further.\n", + "%matplotlib inline\n", "from __future__ import print_function\n", "import collections\n", "import math\n", @@ -521,12 +522,12 @@ " # note that this is expensive (~20% slowdown if computed every 500 steps)\n", " if step % 10000 == 0:\n", " sim = similarity.eval()\n", - " for i in xrange(valid_size):\n", + " for i in range(valid_size):\n", " valid_word = reverse_dictionary[valid_examples[i]]\n", " top_k = 8 # number of nearest neighbors\n", " nearest = (-sim[i, :]).argsort()[1:top_k+1]\n", " log = 'Nearest to %s:' % valid_word\n", - " for k in xrange(top_k):\n", + " for k in range(top_k):\n", " close_word = reverse_dictionary[nearest[k]]\n", " log = '%s %s,' % (log, close_word)\n", " print(log)\n", diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md index 6e0e409e0f3..3323210b832 100644 --- a/tensorflow/g3doc/get_started/os_setup.md +++ b/tensorflow/g3doc/get_started/os_setup.md @@ -531,6 +531,10 @@ directory: ```bash bazel build -c opt //tensorflow/tools/pip_package:build_pip_package + +# To build with GPU support: +bazel build -c opt --config=cuda //tensorflow/tools/pip_package:build_pip_package + mkdir _python_build cd _python_build ln -s ../bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/* . @@ -547,7 +551,7 @@ rules. Starting from the root of your source tree, run: -```python +```bash $ cd tensorflow/models/image/mnist $ python convolutional.py Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes. diff --git a/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py b/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py index da4565ff6cf..8a5ec3cb22c 100644 --- a/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py +++ b/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py @@ -200,7 +200,7 @@ def train(): # Add histograms for gradients. for grad, var in grads: - if grad: + if grad is not None: summaries.append( tf.histogram_summary(var.op.name + '/gradients', grad)) diff --git a/tensorflow/models/rnn/translate/data_utils.py b/tensorflow/models/rnn/translate/data_utils.py index 48da4f065ca..001182bbd59 100644 --- a/tensorflow/models/rnn/translate/data_utils.py +++ b/tensorflow/models/rnn/translate/data_utils.py @@ -28,10 +28,10 @@ from six.moves import urllib from tensorflow.python.platform import gfile # Special vocabulary symbols - we always put them at the start. -_PAD = "_PAD" -_GO = "_GO" -_EOS = "_EOS" -_UNK = "_UNK" +_PAD = b"_PAD" +_GO = b"_GO" +_EOS = b"_EOS" +_UNK = b"_UNK" _START_VOCAB = [_PAD, _GO, _EOS, _UNK] PAD_ID = 0 @@ -40,8 +40,8 @@ EOS_ID = 2 UNK_ID = 3 # Regular expressions used to tokenize. -_WORD_SPLIT = re.compile("([.,!?\"':;)(])") -_DIGIT_RE = re.compile(r"\d") +_WORD_SPLIT = re.compile(b"([.,!?\"':;)(])") +_DIGIT_RE = re.compile(br"\d") # URLs for WMT data. _WMT_ENFR_TRAIN_URL = "http://www.statmt.org/wmt10/training-giga-fren.tar" @@ -131,7 +131,7 @@ def create_vocabulary(vocabulary_path, data_path, max_vocabulary_size, if not gfile.Exists(vocabulary_path): print("Creating vocabulary %s from data %s" % (vocabulary_path, data_path)) vocab = {} - with gfile.GFile(data_path, mode="r") as f: + with gfile.GFile(data_path, mode="rb") as f: counter = 0 for line in f: counter += 1 @@ -139,7 +139,7 @@ def create_vocabulary(vocabulary_path, data_path, max_vocabulary_size, print(" processing line %d" % counter) tokens = tokenizer(line) if tokenizer else basic_tokenizer(line) for w in tokens: - word = re.sub(_DIGIT_RE, "0", w) if normalize_digits else w + word = re.sub(_DIGIT_RE, b"0", w) if normalize_digits else w if word in vocab: vocab[word] += 1 else: @@ -147,9 +147,9 @@ def create_vocabulary(vocabulary_path, data_path, max_vocabulary_size, vocab_list = _START_VOCAB + sorted(vocab, key=vocab.get, reverse=True) if len(vocab_list) > max_vocabulary_size: vocab_list = vocab_list[:max_vocabulary_size] - with gfile.GFile(vocabulary_path, mode="w") as vocab_file: + with gfile.GFile(vocabulary_path, mode="wb") as vocab_file: for w in vocab_list: - vocab_file.write(w + "\n") + vocab_file.write(w + b"\n") def initialize_vocabulary(vocabulary_path): @@ -173,7 +173,7 @@ def initialize_vocabulary(vocabulary_path): """ if gfile.Exists(vocabulary_path): rev_vocab = [] - with gfile.GFile(vocabulary_path, mode="r") as f: + with gfile.GFile(vocabulary_path, mode="rb") as f: rev_vocab.extend(f.readlines()) rev_vocab = [line.strip() for line in rev_vocab] vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)]) @@ -191,7 +191,7 @@ def sentence_to_token_ids(sentence, vocabulary, "a": 4, "dog": 7"} this function will return [1, 2, 4, 7]. Args: - sentence: a string, the sentence to convert to token-ids. + sentence: the sentence in bytes format to convert to token-ids. vocabulary: a dictionary mapping tokens to integers. tokenizer: a function to use to tokenize each sentence; if None, basic_tokenizer will be used. @@ -200,6 +200,7 @@ def sentence_to_token_ids(sentence, vocabulary, Returns: a list of integers, the token-ids for the sentence. """ + if tokenizer: words = tokenizer(sentence) else: @@ -207,7 +208,7 @@ def sentence_to_token_ids(sentence, vocabulary, if not normalize_digits: return [vocabulary.get(w, UNK_ID) for w in words] # Normalize digits by 0 before looking words up in the vocabulary. - return [vocabulary.get(re.sub(_DIGIT_RE, "0", w), UNK_ID) for w in words] + return [vocabulary.get(re.sub(_DIGIT_RE, b"0", w), UNK_ID) for w in words] def data_to_token_ids(data_path, target_path, vocabulary_path, @@ -229,7 +230,7 @@ def data_to_token_ids(data_path, target_path, vocabulary_path, if not gfile.Exists(target_path): print("Tokenizing data in %s" % data_path) vocab, _ = initialize_vocabulary(vocabulary_path) - with gfile.GFile(data_path, mode="r") as data_file: + with gfile.GFile(data_path, mode="rb") as data_file: with gfile.GFile(target_path, mode="w") as tokens_file: counter = 0 for line in data_file: diff --git a/tensorflow/models/rnn/translate/translate.py b/tensorflow/models/rnn/translate/translate.py index f6b07230b4f..a0691b5b26b 100644 --- a/tensorflow/models/rnn/translate/translate.py +++ b/tensorflow/models/rnn/translate/translate.py @@ -225,7 +225,7 @@ def decode(): sentence = sys.stdin.readline() while sentence: # Get token-ids for the input sentence. - token_ids = data_utils.sentence_to_token_ids(sentence, en_vocab) + token_ids = data_utils.sentence_to_token_ids(tf.compat.as_bytes(sentence), en_vocab) # Which bucket does it belong to? bucket_id = min([b for b in xrange(len(_buckets)) if _buckets[b][0] > len(token_ids)]) @@ -241,7 +241,7 @@ def decode(): if data_utils.EOS_ID in outputs: outputs = outputs[:outputs.index(data_utils.EOS_ID)] # Print out French sentence corresponding to outputs. - print(" ".join([rev_fr_vocab[output] for output in outputs])) + print(" ".join([tf.compat.as_str(rev_fr_vocab[output]) for output in outputs])) print("> ", end="") sys.stdout.flush() sentence = sys.stdin.readline() diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index a1ca9815270..aa1b1a6e654 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -289,7 +289,7 @@ py_library( cuda_py_tests( name = "framework_function_test", - size = "small", + size = "medium", srcs = ["framework/function_test.py"], additional_deps = [ ":functional_ops_lib", @@ -1078,6 +1078,8 @@ py_library( ) medium_kernel_test_list = glob([ + "kernel_tests/concat_op_test.py", + "kernel_tests/division_future_test.py", "kernel_tests/fft_ops_test.py", "kernel_tests/rnn_test.py", "kernel_tests/scatter_ops_test.py", @@ -1087,6 +1089,7 @@ medium_kernel_test_list = glob([ sharded_kernel_test_list = glob([ "kernel_tests/cwise_ops_test.py", + "kernel_tests/embedding_ops_test.py", "kernel_tests/linalg_grad_test.py", ]) @@ -1161,11 +1164,18 @@ cuda_py_tests( ["ops/*_test.py"], exclude = [ "ops/image_ops_test.py", + "ops/nn_test.py", "ops/op_def_library_test.py", ], ), ) +cuda_py_tests( + name = "medium_op_tests", + size = "medium", + srcs = ["ops/nn_test.py"], +) + cuda_py_tests( name = "kernel_tests", size = "small", diff --git a/tensorflow/python/client/device_lib_test.py b/tensorflow/python/client/device_lib_test.py index ee028573aac..a455af44d08 100644 --- a/tensorflow/python/client/device_lib_test.py +++ b/tensorflow/python/client/device_lib_test.py @@ -27,7 +27,8 @@ from tensorflow.python.platform import googletest class DeviceLibTest(test_util.TensorFlowTestCase): - def testListLocalDevices(self): + # TODO(ebrevdo): fix python3 compatibility: b/27727661 + def _testListLocalDevices(self): devices = device_lib.list_local_devices() self.assertGreater(len(devices), 0) self.assertEqual(devices[0].device_type, "CPU") diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 25eb7c670f9..c6db6c9c05a 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -952,7 +952,7 @@ def trace(x, name=None): ``` Args: - input_tensor: 2-D tensor. + x: 2-D tensor. name: A name for the operation (optional). Returns: diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index e7140816bbc..331c1bb7135 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -195,10 +195,8 @@ def softmax_cross_entropy_with_logits(logits, labels, name=None): can be a dog or a truck, but not both. **NOTE:** While the classes are mutually exclusive, their probabilities - need not be. All that is required is that each row of `labels` is - a valid probability distribution. If using exclusive `labels` - (wherein one and only one class is true at a time), see - `sparse_softmax_cross_entropy_with_logits`. + need not be. If using exclusive `labels` (wherein one and only one class is + true at a time), see `sparse_softmax_cross_entropy_with_logits`. **WARNING:** This op expects unscaled logits, since it performs a `softmax` on `logits` internally for efficiency. Do not call this op with the @@ -209,7 +207,9 @@ def softmax_cross_entropy_with_logits(logits, labels, name=None): Args: logits: Unscaled log probabilities. - labels: Each row `labels[i]` must be a valid probability distribution. + labels: Each row `labels[i]` must be a valid probability distribution or + all zeros. If all zeros, the corresponding loss will be `0`, regardless + of the contents of `logits[i]`. name: A name for the operation (optional). Returns: @@ -249,7 +249,9 @@ def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None): Args: logits: Unscaled log probabilities. - labels: Each entry `labels[i]` must be an index in `[0, num_classes)`. + labels: Each entry `labels[i]` must be an index in `[0, num_classes)` or + `-1`. If `-1`, the corresponding loss will be `0`, regardless + of the contents of `logits[i]`. name: A name for the operation (optional). Returns: diff --git a/tensorflow/python/ops/rnn_cell.py b/tensorflow/python/ops/rnn_cell.py index ebdfdc113b1..e33e2964516 100644 --- a/tensorflow/python/ops/rnn_cell.py +++ b/tensorflow/python/ops/rnn_cell.py @@ -208,7 +208,7 @@ class BasicLSTMCell(RNNCell): new_c = c * sigmoid(f + self._forget_bias) + sigmoid(i) * tanh(j) new_h = tanh(new_c) * sigmoid(o) - return new_h, array_ops.concat(1, [new_c, new_h]) + return new_h, array_ops.concat(1, [new_c, new_h]) def _get_concat_variable(name, shape, dtype, num_shards): @@ -344,7 +344,7 @@ class LSTMCell(RNNCell): actual_input_size = inputs.get_shape().as_list()[1] if self._input_size and self._input_size != actual_input_size: raise ValueError("Actual input size not same as specified: %d vs %d." % - actual_input_size, self._input_size) + (actual_input_size, self._input_size)) with vs.variable_scope(scope or type(self).__name__, initializer=self._initializer): # "LSTMCell" concat_w = _get_concat_variable( diff --git a/tensorflow/python/summary/event_accumulator.py b/tensorflow/python/summary/event_accumulator.py index a7cc3ba603a..e36dc6e43f3 100644 --- a/tensorflow/python/summary/event_accumulator.py +++ b/tensorflow/python/summary/event_accumulator.py @@ -197,14 +197,14 @@ class EventAccumulator(object): ## Process the event if event.HasField('graph_def'): if self._graph is not None: - logging.warn(('Found more than one graph event per run.' - 'Overwritting the graph with the newest event.')) + logging.warn(('Found more than one graph event per run. ' + 'Overwriting the graph with the newest event.')) self._graph = event.graph_def elif event.HasField('tagged_run_metadata'): tag = event.tagged_run_metadata.tag if tag in self._tagged_metadata: logging.warn('Found more than one "run metadata" event with tag ' + - tag + '. Overwritting it with the newest event.') + tag + '. Overwriting it with the newest event.') self._tagged_metadata[tag] = event.tagged_run_metadata.run_metadata elif event.HasField('summary'): for value in event.summary.value: diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts index 4fa46061002..56ff83ce584 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts @@ -39,6 +39,9 @@ let arrowheadMap = d3.scale.quantize() .domain([MIN_EDGE_WIDTH, MAX_EDGE_WIDTH]) .range(["small", "medium", "large", "xlarge"]); +/** Minimum stroke width to put edge labels in the middle of edges */ +const CENTER_EDGE_LABEL_MIN_STROKE_WIDTH = 2.5; + export type EdgeData = {v: string, w: string, label: render.RenderMetaedgeInfo}; export function getEdgeKey(edgeObj: EdgeData) { @@ -254,11 +257,16 @@ export function appendEdge(edgeGroup, d: EdgeData, // We have no information to show on this edge. return; } + + // Put edge label in the middle of edge only if the edge is thick enough. + let baseline = strokeWidth > CENTER_EDGE_LABEL_MIN_STROKE_WIDTH ? + "central" : "text-after-edge"; + edgeGroup.append("text").append("textPath").attr({ "xlink:href": "#" + pathId, "startOffset": "50%", "text-anchor": "middle", - "dominant-baseline": "central" + "dominant-baseline": baseline }).text(labelForEdge); }; diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 09f29bd0e00..27b66bd7c62 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -1,5 +1,31 @@ # -*- Python -*- +# Parse the bazel version string from `native.bazel_version`. +def _parse_bazel_version(bazel_version): + # Remove commit from version. + version = bazel_version.split(" ", 1)[0] + + # Split into (release, date) parts and only return the release + # as a tuple of integers. + parts = version.split('-', 1) + + # Turn "release" into a tuple of integers + version_tuple = () + for number in parts[0].split('.'): + version_tuple += (int(number),) + return version_tuple + + +# Check that a specific bazel version is being used. +def check_version(bazel_version): + if "bazel_version" in dir(native): + current_bazel_version = _parse_bazel_version(native.bazel_version) + minimum_bazel_version = _parse_bazel_version(bazel_version) + if minimum_bazel_version > current_bazel_version: + fail("\nCurrent Bazel version is {}, expected at least {}\n".format( + native.bazel_version, bazel_version)) + pass + # Return the options to use for a C++ library or binary build. # Uses the ":optmode" config_setting to pick the options. diff --git a/tensorflow/tools/ci_build/builds/benchmark.sh b/tensorflow/tools/ci_build/builds/benchmark.sh new file mode 100755 index 00000000000..b78fa0e92d3 --- /dev/null +++ b/tensorflow/tools/ci_build/builds/benchmark.sh @@ -0,0 +1,155 @@ +#!/usr/bin/env bash +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Runs benchmark tests. +# After the completion of each benchmark test, the script calls a hook binary, +# specified with the environment variable TF_BUILD_BENCHMARK_HOOK, to handle +# the test log file. This hook binary may perform operations such as entering +# the test results into a database. +# +# Usage: benchmark [-c opt] +# Option flags +# -c opt: Use optimized C++ build ("-c opt") +# +# This script obeys the following environmental variables: +# TF_BUILD_BENCHMARK_HOOK: +# Path to a binary / script that will handle the test log and other related +# info after the completion of each benchmark test. + +set -u + +echo "" +echo "====== Benchmark tests start ======" + +# Process input arguments +OPT_FLAG="" +while getopts c: flag; do + case ${flag} in + c) + if [[ ! -z "{OPTARG}" ]]; then + OPT_FLAG="${OPT_FLAG} -c ${OPTARG}" + fi + ;; + esac +done + +BENCHMARK_HOOK=${TF_BUILD_BENCHMARK_HOOK:-""} + + +BENCHMARK_TAG="benchmark-test" +BENCHMARK_TESTS=$(bazel query \ + 'attr("tags", "'"${BENCHMARK_TAG}"'", //tensorflow/...)') + +if [[ -z "${BENCHMARK_TESTS}" ]]; then + echo "ERROR: Cannot find any benchmark tests with the tag "\ +"\"${BENCHMARK_TAG}\"" + exit 1 +fi + +N_TESTS=$(echo ${BENCHMARK_TESTS} | wc -w) + +echo "Discovered ${N_TESTS} benchmark test(s) with the tag \"${BENCHMARK_TAG}\":" +echo ${BENCHMARK_TESTS} +echo "" + +PASS_COUNTER=0 +FAIL_COUNTER=0 +FAILED_TESTS="" +COUNTER=0 + +# Iterate through the benchmark tests +for BENCHMARK_TEST in ${BENCHMARK_TESTS}; do + ((COUNTER++)) + + echo "" + echo "Running benchmark test (${COUNTER} / ${N_TESTS}): ${BENCHMARK_TEST}" + + bazel test ${OPT_FLAG} --cache_test_results=no "${BENCHMARK_TEST}" + TEST_RESULT=$? + + # Hook for database + # Verify that test log exists + TEST_LOG=$(echo ${BENCHMARK_TEST} | sed -e 's/:/\//g') + TEST_LOG="bazel-testlogs/${TEST_LOG}/test.log" + if [[ -f "${TEST_LOG}" ]]; then + echo "Benchmark ${BENCHMARK_TEST} done: log @ ${TEST_LOG}" + + # Call database hook if exists + if [[ ! -z "${BENCHMARK_HOOK}" ]]; then + # Assume that the hook binary/script takes two arguments: + # Argument 1: Compilation flags such as "-c opt" as a whole + # Argument 2: Test log containing the serialized TestResults proto + + echo "Calling database hook: ${TF_BUILD_BENCHMARK_LOG_HOOK} "\ +"${OPT_FLAG} ${TEST_LOG}" + + ${TF_BUILD_BENCHMARK_LOG_HOOK} "${OPT_FLAG}" "${TEST_LOG}" + else + echo "WARNING: No hook binary is specified to handle test log ${TEST_LOG}" + fi + else + # Mark as failure if the test log file cannot be found + TEST_RESULT=2 + + echo "ERROR: Cannot find log file from benchmark ${BENCHMARK_TEST} @ "\ +"${TEST_LOG}" + fi + + echo "" + if [[ ${TEST_RESULT} -eq 0 ]]; then + ((PASS_COUNTER++)) + + echo "Benchmark test PASSED: ${BENCHMARK_TEST}" + else + ((FAIL_COUNTER++)) + + FAILED_TESTS="${FAILED_TESTS} ${BENCHMARK_TEST}" + + echo "Benchmark test FAILED: ${BENCHMARK_TEST}" + + if [[ -f "${TEST_LOG}" ]]; then + echo "============== BEGINS failure log content ==============" + cat ${TEST_LOG} >&2 + echo "============== ENDS failure log content ==============" + echo "" + fi + fi + +done + +# Summarize test results +echo "" +echo "${N_TESTS} Benchmark test(s):" \ + "${PASS_COUNTER} passed;" \ + "${FAIL_COUNTER} failed" + +if [[ ${FAIL_COUNTER} -eq 0 ]]; then + echo "" + echo "Benchmark tests SUCCEEDED" + + exit 0 +else + echo "FAILED benchmark test(s):" + FAIL_COUNTER=0 + for TEST_NAME in ${FAILED_TESTS}; do + echo " ${TEST_NAME}" + ((FAIL_COUNTER++)) + done + + echo "" + echo "Benchmark tests FAILED" + exit 1 +fi diff --git a/tensorflow/tools/ci_build/builds/with_the_same_user b/tensorflow/tools/ci_build/builds/with_the_same_user index e723974853b..2f98d05bc1e 100755 --- a/tensorflow/tools/ci_build/builds/with_the_same_user +++ b/tensorflow/tools/ci_build/builds/with_the_same_user @@ -34,7 +34,7 @@ getent passwd "${CI_BUILD_UID}" || adduser --gid "${CI_BUILD_GID}" --uid "${CI_B usermod -a -G sudo "${CI_BUILD_USER}" echo "${CI_BUILD_USER} ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/90-nopasswd-sudo -if [ -e /root/.bazelrc]; then +if [ -e /root/.bazelrc ]; then cp /root/.bazelrc "${CI_BUILD_HOME}/.bazelrc" chown "${CI_BUILD_UID}:${CI_BUILD_GID}" "${CI_BUILD_HOME}/.bazelrc" fi diff --git a/tensorflow/tools/ci_build/ci_parameterized_build.sh b/tensorflow/tools/ci_build/ci_parameterized_build.sh index 9b7e5abd621..aa4c3bf2f6c 100755 --- a/tensorflow/tools/ci_build/ci_parameterized_build.sh +++ b/tensorflow/tools/ci_build/ci_parameterized_build.sh @@ -54,6 +54,10 @@ # tutorials tests (Applicable only if TF_BUILD_IS_PIP is # PIP or BOTH). # See builds/test_tutorials.sh +# TF_BUILD_RUN_BENCHMARKS: +# If set to any non-empty and non-0 value, will perform +# the benchmark tests (see *_logged_benchmark targets in +# tools/test/BUILD) # # This script can be used by Jenkins parameterized / matrix builds. @@ -98,6 +102,8 @@ PIP_CMD="${CI_BUILD_DIR}/builds/pip.sh" PIP_TEST_TUTORIALS_FLAG="--test_tutorials" ANDROID_CMD="${CI_BUILD_DIR}/builds/android.sh" +BENCHMARK_CMD="${CI_BUILD_DIR}/builds/benchmark.sh" + BAZEL_TARGET="//tensorflow/..." TUT_TEST_DATA_DIR="/tmp/tf_tutorial_test_data" @@ -129,6 +135,7 @@ echo " TF_BUILD_BAZEL_TARGET=${TF_BUILD_BAZEL_TARGET}" echo " TF_BUILD_BAZEL_CLEAN=${TF_BUILD_BAZEL_CLEAN}" echo " TF_BUILD_SERIAL_TESTS=${TF_BUILD_SERIAL_TESTS}" echo " TF_BUILD_TEST_TUTORIALS=${TF_BUILD_TEST_TUTORIALS}" +echo " TF_BUILD_RUN_BENCHMARKS=${TF_BUILD_RUN_BENCHMARKS}" # Process container type CTYPE=${TF_BUILD_CONTAINER_TYPE} @@ -159,6 +166,13 @@ if [[ -z "$(which docker)" ]]; then fi +# Determine if this is a benchmarks job +RUN_BENCHMARKS=0 +if [[ ! -z "${TF_BUILD_RUN_BENCHMARKS}" ]] && + [[ "${TF_BUILD_RUN_BENCHMARKS}" != "0" ]]; then + RUN_BENCHMARKS=1 +fi + # Process Bazel "-c opt" flag if [[ ${TF_BUILD_IS_OPT} == "no_opt" ]]; then # PIP builds are done only with the -c opt flag @@ -177,6 +191,25 @@ fi # Strip whitespaces from OPT_FLAG OPT_FLAG=$(str_strip "${OPT_FLAG}") + +# Filter out benchmark tests if this is not a benchmarks job +EXTRA_ARGS="" +if [[ "${TF_BUILD_APPEND_ARGUMENTS}" == *"--test_tag_filters="* ]]; then + ITEMS=(${TF_BUILD_APPEND_ARGUMENTS}) + + for ITEM in "${ITEMS[@]}"; do + if [[ ${ITEM} == *"--test_tag_filters="* ]] && + [[ ${ITEM} != *"benchmark-test"* ]]; then + EXTRA_ARGS="${EXTRA_ARGS} ${ITEM},-benchmark-test" + else + EXTRA_ARGS="${EXTRA_ARGS} ${ITEM}" + fi + done +else + EXTRA_ARGS="${EXTRA_ARGS} --test_tag_filters=-benchmark-test" +fi + + # Process PIP install-test option if [[ ${TF_BUILD_IS_PIP} == "no_pip" ]] || [[ ${TF_BUILD_IS_PIP} == "both" ]]; then @@ -188,7 +221,7 @@ if [[ ${TF_BUILD_IS_PIP} == "no_pip" ]] || if [[ ${CTYPE} == "cpu" ]] || [[ ${CTYPE} == "gpu" ]]; then # Run Bazel NO_PIP_MAIN_CMD="${MAIN_CMD} ${BAZEL_CMD} ${OPT_FLAG} "\ -"${TF_BUILD_APPEND_ARGUMENTS} ${BAZEL_TARGET}" +"${EXTRA_ARGS} ${BAZEL_TARGET}" NO_PIP_MAIN_CMD=$(str_strip "${NO_PIP_MAIN_CMD}") if [[ ! -z "${TF_BUILD_SERIAL_TESTS}" ]] && @@ -198,12 +231,12 @@ if [[ ${TF_BUILD_IS_PIP} == "no_pip" ]] || # But the 2nd (test) step will be done serially. BUILD_ONLY_CMD="${BAZEL_BUILD_ONLY_CMD} ${OPT_FLAG} "\ -"${TF_BUILD_APPEND_ARGUMENTS} ${BAZEL_TARGET}" +"${EXTRA_ARGS} ${BAZEL_TARGET}" echo "Build-only command: ${BUILD_ONLY_CMD}" NO_PIP_MAIN_CMD="${BUILD_ONLY_CMD} && "\ "${BAZEL_CMD} ${OPT_FLAG} ${BAZEL_SERIAL_FLAG} "\ -"${TF_BUILD_APPEND_ARGUMENTS} ${BAZEL_TARGET}" +"${EXTRA_ARGS} ${BAZEL_TARGET}" echo "Parallel-build + serial-test command: ${NO_PIP_MAIN_CMD}" fi elif [[ ${CTYPE} == "android" ]]; then @@ -221,8 +254,7 @@ if [[ ${TF_BUILD_IS_PIP} == "pip" ]] || exit 0 fi - PIP_MAIN_CMD="${MAIN_CMD} ${PIP_CMD} ${CTYPE} "\ -"${TF_BUILD_APPEND_ARGUMENTS}" + PIP_MAIN_CMD="${MAIN_CMD} ${PIP_CMD} ${CTYPE} ${EXTRA_AGRS}" # Add command for tutorial test if [[ ! -z "${TF_BUILD_TEST_TUTORIALS}" ]] && @@ -240,7 +272,10 @@ if [[ ${TF_BUILD_IS_PIP} == "pip" ]] || fi fi -if [[ ${TF_BUILD_IS_PIP} == "no_pip" ]]; then + +if [[ ${RUN_BENCHMARKS} == "1" ]]; then + MAIN_CMD="${BENCHMARK_CMD} ${OPT_FLAG}" +elif [[ ${TF_BUILD_IS_PIP} == "no_pip" ]]; then MAIN_CMD="${NO_PIP_MAIN_CMD}" elif [[ ${TF_BUILD_IS_PIP} == "pip" ]]; then MAIN_CMD="${PIP_MAIN_CMD}" @@ -250,7 +285,6 @@ else die "Unrecognized value in TF_BUILD_IS_PIP: \"${TF_BUILD_IS_PIP}\"" fi - # Process Python version if [[ ${TF_BUILD_PYTHON_VERSION} == "python2" ]]; then : @@ -284,8 +318,7 @@ EXTRA_PARAMS="${EXTRA_PARAMS} ${TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS}" # TF_BUILD_SERIAL_TESTS=1), are written to a bash script, which is # then called. The name of the script is randomized to make concurrent # builds on the node possible. -RAND_STR=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 8 | head -n 1) -TMP_SCRIPT=/tmp/ci_parameterized_build_${RAND_STR}.sh +TMP_SCRIPT="$(mktemp)_ci_parameterized_build.sh" if [[ "${DO_DOCKER}" == "1" ]]; then # Map the tmp script into the Docker container diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh index 39583869e20..224762c7b1d 100755 --- a/tensorflow/tools/ci_build/install/install_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh @@ -18,3 +18,9 @@ set -e pip install sklearn pip3 install scikit-learn + +# Benchmark tests require the following: +pip install psutil +pip3 install psutil +pip install py-cpuinfo +pip3 install py-cpuinfo diff --git a/tensorflow/tools/docker/notebooks/2_getting_started.ipynb b/tensorflow/tools/docker/notebooks/2_getting_started.ipynb index b1809cff30d..f4eb2c9ab5a 100644 --- a/tensorflow/tools/docker/notebooks/2_getting_started.ipynb +++ b/tensorflow/tools/docker/notebooks/2_getting_started.ipynb @@ -159,7 +159,7 @@ " \n", " yhat = tf.matmul(input, weights)\n", " yerror = tf.sub(yhat, target)\n", - " loss = tf.reduce_mean(tf.nn.l2_loss(yerror))\n", + " loss = tf.nn.l2_loss(yerror)\n", " \n", " update_weights = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)\n", " \n", @@ -601,7 +601,7 @@ " # Our target is the y values. They need to be massaged to the right shape.\n", " target = tf.constant(np.transpose([y]).astype(np.float32))\n", " # Weights are a variable. They change every time through the loop.\n", - " # Weights are initialized to random values (gaussian, mean 0, stdev 1)\n", + " # Weights are initialized to random values (gaussian, mean 0, stdev 0.1)\n", " weights = tf.Variable(tf.random_normal([2, 1], 0, 0.1))\n", "\n", " # Initialize all the variables defined above.\n", @@ -617,7 +617,7 @@ " # We are going to minimize the L2 loss. The L2 loss is the sum of the\n", " # squared error for all our estimates of y. This penalizes large errors\n", " # a lot, but small errors only a little.\n", - " loss = tf.reduce_mean(tf.nn.l2_loss(yerror))\n", + " loss = tf.nn.l2_loss(yerror)\n", "\n", " # Perform gradient descent. \n", " # This essentially just updates weights, like weights += grads * mu\n", @@ -824,9 +824,9 @@ "\n", "The first line calculates the L2 loss manually. It's the same as `l2_loss(yerror)`, which is half of the sum of the squared error, so $\\frac{1}{2} \\sum (\\hat{y} - y)^2$. With this code, you can see exactly what the `l2_loss` operation does. It's the total of all the squared differences between the target and our estimates. And minimizing the L2 loss will minimize how much our estimates of $y$ differ from the true values of $y$.\n", "\n", - "The second line calculates $\\sum{x_i (\\hat{y} - y)}$. What is that? It's the partial derivative of the L2 loss, the same thing as what `gradients(loss, weights)` does in the earlier code. Not sure about that? Let's look at it in more detail. The gradient calculation is going to get the partial derivatives of loss with respect to each of the weights so we can change those weights in the direction that will reduce the loss. L2 loss is $\\frac{1}{2} \\sum (\\hat{y} - y)^2$, where $\\hat{y} = w_2 x + w_1$. So, using the chain rule and substituting in for $\\hat{y}$ in the derivative, $\\frac{\\partial}{\\partial w_i} = \\sum{(\\hat{y} - y)\\, x_i}$. `GradientDescentOptimizer` does these calculations automatically for you based on the graph structure.\n", + "The second line calculates $\\begin{bmatrix}\\sum{(\\hat{y} - y)*1} \\\\ \\sum{(\\hat{y} - y)*x_i}\\end{bmatrix}$. What is that? It's the partial derivatives of the L2 loss with respect to $w_1$ and $w_2$, the same thing as what `gradients(loss, weights)` does in the earlier code. Not sure about that? Let's look at it in more detail. The gradient calculation is going to get the partial derivatives of loss with respect to each of the weights so we can change those weights in the direction that will reduce the loss. L2 loss is $\\frac{1}{2} \\sum (\\hat{y} - y)^2$, where $\\hat{y} = w_2 x + w_1$. So, using the chain rule and substituting in for $\\hat{y}$ in the derivative, $\\frac{\\partial}{\\partial w_2} = \\sum{(\\hat{y} - y)\\, *x_i}$ and $\\frac{\\partial}{\\partial w_1} = \\sum{(\\hat{y} - y)\\, *1}$. `GradientDescentOptimizer` does these calculations automatically for you based on the graph structure.\n", "\n", - "The third line is equivalent to `weights -= mu * gradient`, so it subtracts a constant the gradient after scaling by the learning rate (to avoid jumping too far each time, which risks moving in the wrong direction). It's also the same thing that `GradientDescentOptimizer(learning_rate).minimize(loss)` does in the earlier code. Gradient descient updates its first parameter based on the values in the second after scaling by the third, so it's equivalent to the `assign_sub(weights, mu * gradient)`.\n", + "The third line is equivalent to `weights -= mu * gradient`, so it subtracts a constant the gradient after scaling by the learning rate (to avoid jumping too far each time, which risks moving in the wrong direction). It's also the same thing that `GradientDescentOptimizer(learning_rate).minimize(loss)` does in the earlier code. Gradient descent updates its first parameter based on the values in the second after scaling by the third, so it's equivalent to the `assign_sub(weights, mu * gradient)`.\n", "\n", "Hopefully, this other code gives you a better understanding of what the operations we used previously are actually doing. In practice, you'll want to use those high level operators most of the time rather than calculating things yourself. For this toy example and simple network, it's not too bad to compute and apply the gradients yourself from scratch, but things get more complicated with larger networks." ] diff --git a/tensorflow/tools/swig/swig.sh b/tensorflow/tools/swig/swig.sh index c35b2ee3634..367dcb4cd0a 100755 --- a/tensorflow/tools/swig/swig.sh +++ b/tensorflow/tools/swig/swig.sh @@ -14,4 +14,12 @@ # limitations under the License. # ============================================================================== -swig "$@" +# If possible, read swig path out of "swig_path" generated by configure +SWIG=swig +SWIG_PATH=tensorflow/tools/swig/swig_path +if [ -e $SWIG_PATH ]; then + SWIG=`cat $SWIG_PATH` +fi + +# If this line fails, rerun configure to set the path to swig correctly +"$SWIG" "$@" diff --git a/tensorflow/tools/test/BUILD b/tensorflow/tools/test/BUILD index 45112833cde..23d9cc64440 100644 --- a/tensorflow/tools/test/BUILD +++ b/tensorflow/tools/test/BUILD @@ -3,7 +3,11 @@ package(default_visibility = ["//tensorflow:internal"]) -load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load( + "//tensorflow/tools/test:performance.bzl", + "tf_cc_logged_benchmark", + "tf_py_logged_benchmark", +) licenses(["notice"]) # Apache 2.0 @@ -69,6 +73,16 @@ py_binary( # main = "run_and_gather_logs.py", #) +tf_cc_logged_benchmark( + name = "cast_op_benchmark", + target = "//tensorflow/core/kernels:cast_op_test", +) + +tf_py_logged_benchmark( + name = "rnn_op_benchmark", + target = "//tensorflow/python:rnn_test", +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/tools/test/performance.bzl b/tensorflow/tools/test/performance.bzl new file mode 100644 index 00000000000..750d20fdca3 --- /dev/null +++ b/tensorflow/tools/test/performance.bzl @@ -0,0 +1,56 @@ +# -*- Python -*- + +load("//tensorflow:tensorflow.bzl", "tf_py_test") + +# Create a benchmark test target of a TensorFlow C++ test (tf_cc_*_test) +def tf_cc_logged_benchmark( + name=None, + target=None, + benchmarks="..", + tags=[], + test_log_output_prefix=""): + if not name: + fail("Must provide a name") + if not target: + fail("Must provide a target") + if (not ":" in target + or not target.startswith("//") + or target.endswith(":all") + or target.endswith(".")): + fail(" ".join(("Target must be a single well-defined test, e.g.,", + "//path/to:test. Received: %s" % target))) + + all_tags = list(set(tags) + \ + set(["benchmark-test", "local", "regression-test"])) + + tf_py_test( + name = name, + tags = all_tags, + srcs = ["//tensorflow/tools/test:run_and_gather_logs.py"], + args = [ + "--test_name=" + target + ], + data = [ + target, + ], + main = "run_and_gather_logs.py", + additional_deps = [ + "//tensorflow/tools/test:run_and_gather_logs" + ]) + +# Create a benchmark test target of a TensorFlow python test (*py_tests) +def tf_py_logged_benchmark( + name=None, + target=None, + benchmarks="..", + tags=[], + test_log_output_prefix=""): + # For now generating a py benchmark is the same as generating a C++ + # benchmark target. In the future this may change, so we have + # two macros just in case + tf_cc_logged_benchmark( + name=name, + target=target, + benchmarks=benchmarks, + tags=tags, + test_log_output_prefix=test_log_output_prefix) diff --git a/tensorflow/tools/test/run_and_gather_logs.py b/tensorflow/tools/test/run_and_gather_logs.py index 40a8542a463..9c50138a7bd 100644 --- a/tensorflow/tools/test/run_and_gather_logs.py +++ b/tensorflow/tools/test/run_and_gather_logs.py @@ -44,6 +44,7 @@ from google.protobuf import text_format from tensorflow.core.util import test_log_pb2 from tensorflow.tools.test import run_and_gather_logs_lib + FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_string("test_name", "", """Test target to run.""") @@ -92,7 +93,7 @@ def main(unused_args): else: output_path = os.path.abspath(FLAGS.test_log_output) tf.gfile.GFile(output_path, "w").write(serialized_test_results) - print("Test results written to: %s" % output_path) + tf.logging.info("Test results written to: %s" % output_path) if __name__ == "__main__": diff --git a/tensorflow/tools/test/run_and_gather_logs_lib.py b/tensorflow/tools/test/run_and_gather_logs_lib.py index afe8f210cc4..d6bc10dec97 100644 --- a/tensorflow/tools/test/run_and_gather_logs_lib.py +++ b/tensorflow/tools/test/run_and_gather_logs_lib.py @@ -28,16 +28,48 @@ import time import tensorflow as tf from google.protobuf import text_format - from tensorflow.core.util import test_log_pb2 from tensorflow.tools.test import system_info_lib +def get_git_commit_sha(): + """Get git commit SHA for this build. + + Attempt to get the SHA from environment variable GIT_COMMIT, which should + be available on Jenkins build agents. + + Returns: + SHA hash of the git commit used for the build, if available + """ + + return os.getenv("GIT_COMMIT") + + def process_test_logs(test_name, test_args, start_time, run_time, log_files): + """Gather test information and put it in a TestResults proto. + + Args: + test_name: A unique bazel target, e.g. "//path/to:test" + test_args: A string containing all arguments to run the target with. + + start_time: Test starting time (epoch) + run_time: Wall time that the test ran for + log_files: Paths to the log files + + Returns: + A TestResults proto + """ + results = test_log_pb2.TestResults() results.target = test_name results.start_time = start_time results.run_time = run_time + + # Gather source code information + git_sha = get_git_commit_sha() + if git_sha: + results.commit_id.hash = git_sha + results.entries.CopyFrom(process_benchmarks(log_files)) results.run_configuration.argument.extend(test_args) results.machine_configuration.CopyFrom( diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index c787f1ea71d..d15688fd26d 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -6,7 +6,7 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""): native.new_http_archive( name = "gmock_archive", - url = "https://googlemock.googlecode.com/files/gmock-1.7.0.zip", + url = "https://archive.openswitch.net/gmock-1.7.0.zip", sha256 = "26fcbb5925b74ad5fc8c26b0495dfc96353f4d553492eb97e85a8a6d2f43095b", build_file = path_prefix + "google/protobuf/gmock.BUILD", ) @@ -43,8 +43,8 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""): native.new_http_archive( name = "png_archive", - url = "https://storage.googleapis.com/libpng-public-archive/libpng-1.2.53.tar.gz", - sha256 = "e05c9056d7f323088fd7824d8c6acc03a4a758c4b4916715924edc5dd3223a72", + url = "https://github.com/glennrp/libpng/archive/v1.2.53.zip", + sha256 = "c35bcc6387495ee6e757507a68ba036d38ad05b415c2553b3debe2a57647a692", build_file = path_prefix + "png.BUILD", ) @@ -74,7 +74,7 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""): native.git_repository( name = "grpc", - commit = "73979f4", + commit = "3d62fc6", init_submodules = True, remote = "https://github.com/grpc/grpc.git", ) diff --git a/third_party/gpus/crosstool/CROSSTOOL b/third_party/gpus/crosstool/CROSSTOOL index dfde7cd216a..a9f26f57102 100644 --- a/third_party/gpus/crosstool/CROSSTOOL +++ b/third_party/gpus/crosstool/CROSSTOOL @@ -10,6 +10,10 @@ default_toolchain { cpu: "piii" toolchain_identifier: "local_linux" } +default_toolchain { + cpu: "arm" + toolchain_identifier: "local_linux" +} default_toolchain { cpu: "darwin" toolchain_identifier: "local_darwin" diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc index a67b0390005..04ab50ca86c 100755 --- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc +++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc @@ -43,14 +43,16 @@ import re import sys import pipes -CURRENT_DIR = os.path.dirname(sys.argv[0]) +# "configure" uses the specific format to substitute the following string. +# If you change it, make sure you modify "configure" as well. CPU_COMPILER = ('/usr/bin/gcc') -NVCC_PATH = CURRENT_DIR + '/../../../cuda/bin/nvcc' GCC_HOST_COMPILER_PATH = ('/usr/bin/gcc') + +CURRENT_DIR = os.path.dirname(sys.argv[0]) +NVCC_PATH = CURRENT_DIR + '/../../../cuda/bin/nvcc' LLVM_HOST_COMPILER_PATH = ('/usr/bin/gcc') PREFIX_DIR = os.path.dirname(GCC_HOST_COMPILER_PATH) - def Log(s): print 'gpus/crosstool: {0}'.format(s)