diff --git a/configure b/configure index 6360641be2c..48a4594da63 100755 --- a/configure +++ b/configure @@ -56,7 +56,7 @@ rm -f .tf_configure.bazelrc touch .tf_configure.bazelrc touch .bazelrc sed_hyphen_i "/tf_configure/d" .bazelrc -echo "import .tf_configure.bazelrc" >> .bazelrc +echo "import %workspace%/.tf_configure.bazelrc" >> .bazelrc # Delete any leftover BUILD files from the Makefile build, which would interfere # with Bazel parsing. @@ -284,6 +284,7 @@ export TF_NEED_CUDA write_action_env_to_bazelrc "TF_NEED_CUDA" "$TF_NEED_CUDA" export TF_NEED_OPENCL +write_action_env_to_bazelrc "TF_NEED_OPENCL" "$TF_NEED_OPENCL" if [ "$TF_NEED_CUDA" == "1" ]; then while [[ "$TF_CUDA_CLANG" == "" ]]; do @@ -547,6 +548,7 @@ while true; do fi if [ -e "$HOST_CXX_COMPILER" ]; then export HOST_CXX_COMPILER + write_action_env_to_bazelrc "HOST_CXX_COMPILER" "$HOST_CXX_COMPILER" break fi echo "Invalid C++ compiler path. ${HOST_CXX_COMPILER} cannot be found" 1>&2 @@ -570,6 +572,7 @@ while true; do fi if [ -e "$HOST_C_COMPILER" ]; then export HOST_C_COMPILER + write_action_env_to_bazelrc "HOST_C_COMPILER" "$HOST_C_COMPILER" break fi echo "Invalid C compiler path. ${HOST_C_COMPILER} cannot be found" 1>&2 @@ -600,6 +603,7 @@ while true; do if [ -e "${COMPUTECPP_TOOLKIT_PATH}/${SYCL_RT_LIB_PATH}" ]; then export COMPUTECPP_TOOLKIT_PATH + write_action_env_to_bazelrc "COMPUTECPP_TOOLKIT_PATH" "$COMPUTECPP_TOOLKIT_PATH" break fi echo "Invalid SYCL $TF_OPENCL_VERSION library path. ${COMPUTECPP_TOOLKIT_PATH}/${SYCL_RT_LIB_PATH} cannot be found" diff --git a/tensorflow/BUILD b/tensorflow/BUILD index e437987112b..b98be57ec08 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -351,9 +351,24 @@ filegroup( # ------------------------------------------- cc_binary( name = "libtensorflow.so", + linkopts = select({ + "//tensorflow:darwin": [ + "-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file + "//tensorflow/c:exported_symbols.lds", + ], + "//tensorflow:windows": [], + "//conditions:default": [ + "-z defs", + "-s", + "-Wl,--version-script", # This line must be directly followed by the version_script.lds file + "//tensorflow/c:version_script.lds", + ], + }), linkshared = 1, deps = [ "//tensorflow/c:c_api", + "//tensorflow/c:exported_symbols.lds", + "//tensorflow/c:version_script.lds", "//tensorflow/core:tensorflow", ], ) diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 0019dfeeb13..6e39deee636 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -45,6 +45,14 @@ tf_cuda_library( }), ) +exports_files( + [ + "version_script.lds", + "exported_symbols.lds", + ], + visibility = ["//visibility:public"], +) + tf_cuda_library( name = "tf_status_helper", srcs = ["tf_status_helper.cc"], diff --git a/tensorflow/c/exported_symbols.lds b/tensorflow/c/exported_symbols.lds new file mode 100644 index 00000000000..a14bdaa48be --- /dev/null +++ b/tensorflow/c/exported_symbols.lds @@ -0,0 +1 @@ +_TF_* diff --git a/tensorflow/c/version_script.lds b/tensorflow/c/version_script.lds new file mode 100644 index 00000000000..455bd7362bb --- /dev/null +++ b/tensorflow/c/version_script.lds @@ -0,0 +1,9 @@ +VERS_1.0 { + # Export symbols in c_api.h. + global: + TF_*; + + # Hide everything else. + local: + *; +}; diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index 85ef9560bbf..59a45538a72 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -52,7 +52,8 @@ const char kUsageHeader[] = "header file that gives access to the functionality in the object file.\n" "A typical invocation looks like this:\n" "\n" - " $ tfcompile --graph=mygraph.pb --config=myfile.pbtxt\n" + " $ tfcompile --graph=mygraph.pb --config=myfile.pbtxt " + "--cpp_class=\"mynamespace::MyComputation\"\n" "\n"; Status ReadProtoFile(const string& kind, const string& fname, @@ -73,6 +74,9 @@ void ParseTensorId(const string& name, TensorId* id) { Status Main(const MainFlags& flags) { // Process config. Config config; + if (flags.config.empty()) { + return errors::InvalidArgument("Must specify --config"); + } TF_RETURN_IF_ERROR(ReadProtoFile("config", flags.config, &config)); TF_RETURN_IF_ERROR(ValidateConfig(config)); if (flags.dump_fetch_nodes) { @@ -85,6 +89,9 @@ Status Main(const MainFlags& flags) { } // Read and initialize the graph. + if (flags.graph.empty()) { + return errors::InvalidArgument("Must specify --graph"); + } GraphDef graph_def; TF_RETURN_IF_ERROR(ReadProtoFile("graph", flags.graph, &graph_def)); std::unique_ptr graph; @@ -101,6 +108,9 @@ Status Main(const MainFlags& flags) { TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_object, StringPiece(obj.data(), obj.size()))); HeaderOpts header_opts; + if (flags.cpp_class.empty()) { + return errors::InvalidArgument("Must specify --cpp_class"); + } TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &header_opts.class_name, &header_opts.namespaces)); string header; @@ -131,12 +141,16 @@ int main(int argc, char** argv) { QCHECK(parsed_flags_ok) << "\n" << usage; tensorflow::port::InitMain(usage.c_str(), &argc, &argv); - QCHECK(argc == 1 && !flags.config.empty() && - (flags.dump_fetch_nodes || - (!flags.graph.empty() && !flags.entry_point.empty()))) - << "\n" - << usage; - - TF_QCHECK_OK(tensorflow::tfcompile::Main(flags)); + QCHECK(argc == 1) << "\nERROR: This command does not take any arguments " + "other than flags\n\n" + << usage; + tensorflow::Status status = tensorflow::tfcompile::Main(flags); + if (status.code() == tensorflow::error::INVALID_ARGUMENT) { + std::cerr << "INVALID ARGUMENTS: " << status.error_message() << "\n\n" + << usage; + return 1; + } else { + TF_QCHECK_OK(status); + } return 0; } diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 03e255e6b84..0592e3d4b19 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -305,6 +305,20 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "spacetobatch_op_test", + size = "medium", + srcs = ["spacetobatch_op_test.py"], + shard_count = 3, + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "ternary_ops_test", size = "small", diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 9efdaee7ab6..7221a0a3c74 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -107,6 +107,12 @@ class BinaryOpsTest(XLATestCase): np.array([5, 6, 7, 8], dtype=dtype), expected=np.array([-75, -48, -21, 0], dtype=dtype)) + self._testBinary( + gen_nn_ops._elu_grad, + np.array([1, 2, 3, 4, 5, 6], dtype=dtype), + np.array([-.6, -.4, -.2, 0, .2, .4], dtype=dtype), + expected=np.array([0.4, 1.2, 2.4, 4, 5, 6], dtype=dtype)) + self._testBinary( gen_nn_ops._relu_grad, np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype), diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index a0cd905f173..7d91594db00 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -218,12 +218,11 @@ class OpTest : public ::testing::Test { static constexpr int kDefaultMaxRank = 5; static constexpr int64 kDefaultMaxDimensionSize = 20LL; - // Returns a random dimension size. + // Returns a random dimension size, in the range [min, max). int64 RandomDim(int64 min = 0, int64 max = kDefaultMaxDimensionSize); // Returns a random shape. The tensor has rank in the range [min_rank, - // max_rank). - // Each dimension has size [0, kDefaultMaxDimensionSize]. + // max_rank). Each dimension has size [min_size, max_size). std::vector RandomDims(int min_rank = 0, int max_rank = kDefaultMaxRank, int64 min_size = 0, @@ -668,6 +667,9 @@ void OpTest::ExpectTfAndXlaOutputsAreClose(const OpTestBuilder& builder, VLOG(1) << "Expected graph failed with status: " << s << ". Skipping test"; return; } + for (const Tensor& expected : expected_outputs) { + VLOG(1) << "Expected: " << expected.DebugString(); + } VLOG(1) << "Running test graph"; TF_ASSERT_OK(session_->Run(test_feeds, test_fetches, {}, &test_outputs)); @@ -877,6 +879,79 @@ TEST_F(OpTest, BatchMatMul) { }); } +TEST_F(OpTest, BatchToSpace) { + Repeatedly([this]() { + const int num_block_dims = 2; + std::vector block_dims = + RandomDims(num_block_dims, num_block_dims, 0, 5); + int64 block_size = RandomDim(0, 4); + + std::vector input_dims(1 + num_block_dims + 1); + input_dims[0] = RandomDim(); + for (int i = 0; i < num_block_dims; ++i) { + input_dims[0] *= block_size; + input_dims[1 + i] = block_dims[i]; + } + input_dims[1 + num_block_dims] = RandomDim(); + + std::vector crop_vals; + std::uniform_int_distribution distribution(0, 4); + for (int i = 0; i < num_block_dims; ++i) { + // Chooses crop values; does not always choose legal values. + crop_vals.push_back(distribution(generator())); + crop_vals.push_back(distribution(generator())); + } + Tensor crops; + CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals), + TensorShape({num_block_dims, 2}))); + + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchToSpace") + .Input(RandomTensor(DT_FLOAT, input_dims)) + .Input(crops) + .Attr("T", DT_FLOAT) + .Attr("block_size", block_size)); + }); +} + +TEST_F(OpTest, BatchToSpaceND) { + Repeatedly([this]() { + std::vector block_dims = RandomDims(1, 3, 0, 5); + int num_block_dims = block_dims.size(); + std::vector remaining_dims = RandomDims(0, 3); + std::vector block_multipliers = + RandomDims(block_dims.size(), block_dims.size(), 0, 4); + + std::vector input_dims(1 + num_block_dims + remaining_dims.size()); + input_dims[0] = RandomDim(); + for (int i = 0; i < num_block_dims; ++i) { + input_dims[0] *= block_dims[i]; + } + std::copy(block_multipliers.begin(), block_multipliers.end(), + input_dims.begin() + 1); + std::copy(remaining_dims.begin(), remaining_dims.end(), + input_dims.begin() + 1 + num_block_dims); + + std::vector crop_vals; + std::uniform_int_distribution distribution(0, 3); + for (int i = 0; i < num_block_dims; ++i) { + // Chooses crop values; does not always choose legal values. + crop_vals.push_back(distribution(generator())); + crop_vals.push_back(distribution(generator())); + } + Tensor crops; + CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals), + TensorShape({num_block_dims, 2}))); + + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("BatchToSpaceND") + .Input(RandomTensor(DT_FLOAT, input_dims)) + .Input(test::AsTensor( + std::vector(block_dims.begin(), block_dims.end()))) + .Input(crops) + .Attr("T", DT_FLOAT)); + }); +} + TEST_F(OpTest, BiasAdd) { Repeatedly([this]() { auto x = RandomTensor(DT_FLOAT, RandomDims(2, kDefaultMaxRank)); @@ -1214,6 +1289,23 @@ TEST_F(OpTest, DynamicStitch) { }); } +TEST_F(OpTest, Elu) { + Repeatedly([this]() { + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Elu").Input(RandomTensor(DT_FLOAT)).Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, EluGrad) { + Repeatedly([this]() { + auto dims = RandomDims(); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("EluGrad") + .Input(RandomTensor(DT_FLOAT, dims)) + .Input(RandomTensor(DT_FLOAT, dims)) + .Attr("T", DT_FLOAT)); + }); +} + TEST_F(OpTest, Equal) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); @@ -2019,6 +2111,87 @@ TEST_F(OpTest, SoftplusGrad) { }); } +TEST_F(OpTest, SpaceToBatch) { + Repeatedly([this]() { + std::vector block_dims = RandomDims(4, 4, 0, 5); + const int num_block_dims = 2; + int64 block_size = RandomDim(0, 4); + + std::vector input_dims(1 + num_block_dims + 1); + input_dims[0] = RandomDim(); + for (int i = 0; i < num_block_dims; ++i) { + input_dims[1 + i] = block_dims[i] * block_size; + } + input_dims[1 + num_block_dims] = RandomDim(); + + std::vector padding_vals; + std::uniform_int_distribution distribution(0, 7); + for (int i = 0; i < num_block_dims; ++i) { + int64 pad_before; + int64 pad_after; + do { + pad_before = distribution(generator()); + pad_after = distribution(generator()); + } while (pad_before + pad_after > input_dims[1 + i]); + input_dims[1 + i] -= pad_before + pad_after; + padding_vals.push_back(pad_before); + padding_vals.push_back(pad_after); + } + Tensor paddings; + CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals), + TensorShape({num_block_dims, 2}))); + + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SpaceToBatch") + .Input(RandomTensor(DT_FLOAT, input_dims)) + .Input(paddings) + .Attr("T", DT_FLOAT) + .Attr("block_size", block_size)); + }); +} + +TEST_F(OpTest, SpaceToBatchND) { + Repeatedly([this]() { + std::vector block_dims = RandomDims(1, 3, 0, 5); + int num_block_dims = block_dims.size(); + std::vector remaining_dims = RandomDims(0, 3); + std::vector block_multipliers = + RandomDims(block_dims.size(), block_dims.size(), 0, 4); + + std::vector input_dims(1 + num_block_dims + remaining_dims.size()); + input_dims[0] = RandomDim(); + for (int i = 0; i < num_block_dims; ++i) { + input_dims[1 + i] = block_dims[i] * block_multipliers[i]; + } + std::copy(remaining_dims.begin(), remaining_dims.end(), + input_dims.begin() + 1 + num_block_dims); + + std::vector padding_vals; + std::uniform_int_distribution distribution(0, 7); + for (int i = 0; i < num_block_dims; ++i) { + int64 pad_before; + int64 pad_after; + do { + pad_before = distribution(generator()); + pad_after = distribution(generator()); + } while (pad_before + pad_after > input_dims[1 + i]); + input_dims[1 + i] -= pad_before + pad_after; + padding_vals.push_back(pad_before); + padding_vals.push_back(pad_after); + } + Tensor paddings; + CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals), + TensorShape({num_block_dims, 2}))); + + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("SpaceToBatchND") + .Input(RandomTensor(DT_FLOAT, input_dims)) + .Input(test::AsTensor( + std::vector(block_dims.begin(), block_dims.end()))) + .Input(paddings) + .Attr("T", DT_FLOAT)); + }); +} + TEST_F(OpTest, SparseMatMul) { Repeatedly([this]() { int64 x = RandomDim(); diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py new file mode 100644 index 00000000000..9c3b86c84b2 --- /dev/null +++ b/tensorflow/compiler/tests/spacetobatch_op_test.py @@ -0,0 +1,266 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional tests for SpaceToBatch and BatchToSpace ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.platform import test + + +def space_to_batch_direct(input_array, block_shape, paddings): + """Direct Python implementation of space-to-batch conversion. + + This is used for tests only. + + Args: + input_array: N-D array + block_shape: 1-D array of shape [num_block_dims]. + paddings: 2-D array of shape [num_block_dims, 2]. + + Returns: + Converted tensor. + """ + input_array = np.array(input_array) + block_shape = np.array(block_shape) + num_block_dims = len(block_shape) + paddings = np.array(paddings).reshape((len(block_shape), 2)) + + padded = np.pad(input_array, + pad_width=([[0, 0]] + list(paddings) + [[0, 0]] * + (input_array.ndim - 1 - num_block_dims)), + mode="constant") + reshaped_padded_shape = [input_array.shape[0]] + output_shape = [input_array.shape[0] * np.prod(block_shape)] + for block_dim, block_shape_value in enumerate(block_shape): + reduced_size = padded.shape[block_dim + 1] // block_shape_value + reshaped_padded_shape.append(reduced_size) + output_shape.append(reduced_size) + reshaped_padded_shape.append(block_shape_value) + reshaped_padded_shape.extend(input_array.shape[num_block_dims + 1:]) + output_shape.extend(input_array.shape[num_block_dims + 1:]) + + reshaped_padded = padded.reshape(reshaped_padded_shape) + permuted_reshaped_padded = np.transpose(reshaped_padded, ( + list(np.arange(num_block_dims) * 2 + 2) + [0] + + list(np.arange(num_block_dims) * 2 + 1) + list( + np.arange(input_array.ndim - num_block_dims - 1) + 1 + num_block_dims + * 2))) + return permuted_reshaped_padded.reshape(output_shape) + + +class SpaceToBatchTest(XLATestCase): + """Tests input-output pairs for the SpaceToBatch and BatchToSpace ops.""" + + def _testPad(self, inputs, paddings, block_size, outputs): + with self.test_session() as sess, self.test_scope(): + for dtype in self.float_types: + # outputs = space_to_batch(inputs) + placeholder = array_ops.placeholder(dtype) + x_tf = gen_array_ops._space_to_batch( + placeholder, paddings, block_size=block_size) + self.assertAllEqual(sess.run(x_tf, {placeholder: inputs}), outputs) + # inputs = batch_to_space(outputs) + x_tf = gen_array_ops._batch_to_space( + placeholder, paddings, block_size=block_size) + self.assertAllEqual(sess.run(x_tf, {placeholder: outputs}), inputs) + + def _testOne(self, inputs, block_size, outputs): + paddings = np.zeros((2, 2), dtype=np.int32) + self._testPad(inputs, paddings, block_size, outputs) + + # [1, 2, 2, 1] <-> [4, 1, 1, 1] + def testSmallInput2x2(self): + x_np = [[[[1], [2]], [[3], [4]]]] + block_size = 2 + x_out = [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] + self._testOne(x_np, block_size, x_out) + + # [1, 2, 2, 1] <-> [1, 3, 3, 1] (padding) <-> [9, 1, 1, 1] + def testSmallInput2x2Pad1x0(self): + x_np = [[[[1], [2]], [[3], [4]]]] + paddings = np.array([[1, 0], [1, 0]], dtype=np.int32) + block_size = 3 + x_out = [[[[0]]], [[[0]]], [[[0]]], [[[0]]], [[[1]]], [[[2]]], [[[0]]], + [[[3]]], [[[4]]]] + self._testPad(x_np, paddings, block_size, x_out) + + # Test with depth larger than 1. + # [1, 2, 2, 3] <-> [4, 1, 1, 3] + def testDepthInput2x2(self): + x_np = [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]] + block_size = 2 + x_out = [[[[1, 2, 3]]], [[[4, 5, 6]]], [[[7, 8, 9]]], [[[10, 11, 12]]]] + self._testOne(x_np, block_size, x_out) + + # Test for larger input dimensions. + # [1, 4, 4, 1] <-> [4, 2, 2, 1] + def testLargerInput2x2(self): + x_np = [[[[1], [2], [3], [4]], [[5], [6], [7], [8]], + [[9], [10], [11], [12]], [[13], [14], [15], [16]]]] + block_size = 2 + x_out = [[[[1], [3]], [[9], [11]]], [[[2], [4]], [[10], [12]]], + [[[5], [7]], [[13], [15]]], [[[6], [8]], [[14], [16]]]] + self._testOne(x_np, block_size, x_out) + + # Test with batch larger than 1. + # [2, 2, 4, 1] <-> [8, 1, 2, 1] + def testBatchInput2x2(self): + x_np = [[[[1], [2], [3], [4]], [[5], [6], [7], [8]]], + [[[9], [10], [11], [12]], [[13], [14], [15], [16]]]] + block_size = 2 + x_out = [[[[1], [3]]], [[[9], [11]]], [[[2], [4]]], [[[10], [12]]], + [[[5], [7]]], [[[13], [15]]], [[[6], [8]]], [[[14], [16]]]] + self._testOne(x_np, block_size, x_out) + + # Tests for larger input spatial dimensions AND batch larger than 1, to ensure + # that elements are correctly laid out spatially and properly interleaved + # along the batch dimension. + # [2, 4, 4, 1] <-> [8, 2, 2, 1] + def testLargerInputBatch2x2(self): + x_np = [[[[1], [2], [3], [4]], [[5], [6], [7], [8]], + [[9], [10], [11], [12]], [[13], [14], [15], [16]]], + [[[17], [18], [19], [20]], [[21], [22], [23], [24]], + [[25], [26], [27], [28]], [[29], [30], [31], [32]]]] + x_out = [[[[1], [3]], [[9], [11]]], [[[17], [19]], [[25], [27]]], + [[[2], [4]], [[10], [12]]], [[[18], [20]], [[26], [28]]], + [[[5], [7]], [[13], [15]]], [[[21], [23]], [[29], [31]]], + [[[6], [8]], [[14], [16]]], [[[22], [24]], [[30], [32]]]] + block_size = 2 + self._testOne(x_np, block_size, x_out) + + +class SpaceToBatchNDTest(XLATestCase): + """Tests input-output pairs for the SpaceToBatchND and BatchToSpaceND ops.""" + + def _testPad(self, inputs, block_shape, paddings, outputs): + block_shape = np.array(block_shape) + paddings = np.array(paddings).reshape((len(block_shape), 2)) + with self.test_session() as sess, self.test_scope(): + for dtype in self.float_types: + placeholder = array_ops.placeholder(dtype) + # outputs = space_to_batch(inputs) + x_tf = array_ops.space_to_batch_nd(placeholder, block_shape, paddings) + self.assertAllEqual(sess.run(x_tf, {placeholder: inputs}), outputs) + # inputs = batch_to_space(outputs) + placeholder = array_ops.placeholder(dtype) + x_tf = array_ops.batch_to_space_nd(placeholder, block_shape, paddings) + self.assertAllEqual(sess.run(x_tf, {placeholder: outputs}), inputs) + + def _testDirect(self, input_shape, block_shape, paddings): + inputs = np.arange(np.prod(input_shape), dtype=np.float32) + inputs = inputs.reshape(input_shape) + self._testPad(inputs, block_shape, paddings, + space_to_batch_direct(inputs, block_shape, paddings)) + + def testZeroBlockDimsZeroRemainingDims(self): + self._testPad( + inputs=[1, 2], + block_shape=[], + paddings=[], + outputs=[1, 2],) + + def testZeroBlockDimsOneRemainingDim(self): + self._testPad( + inputs=[[1, 2], [3, 4]], + block_shape=[], + paddings=[], + outputs=[[1, 2], [3, 4]]) + + # Same thing, but with a no-op block dim. + self._testPad( + inputs=[[1, 2], [3, 4]], + block_shape=[1], + paddings=[[0, 0]], + outputs=[[1, 2], [3, 4]]) + + def testZeroBlockDimsTwoRemainingDims(self): + self._testPad( + inputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], + block_shape=[], + paddings=[], + outputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + + # Same thing, but with a no-op block dim. + self._testPad( + inputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], + block_shape=[1], + paddings=[[0, 0]], + outputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + + # Same thing, but with two no-op block dims. + self._testPad( + inputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], + block_shape=[1, 1], + paddings=[[0, 0], [0, 0]], + outputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + + def testOneBlockDimZeroRemainingDims(self): + self._testPad( + inputs=[[1, 2, 3], [4, 5, 6]], + block_shape=[2], + paddings=[1, 0], + outputs=[[0, 2], [0, 5], [1, 3], [4, 6]]) + + def testOneBlockDimOneRemainingDim(self): + self._testPad( + inputs=[[[1, 11], [2, 21], [3, 31]], [[4, 41], [5, 51], [6, 61]]], + block_shape=[2], + paddings=[1, 0], + outputs=[[[0, 0], [2, 21]], [[0, 0], [5, 51]], [[1, 11], [3, 31]], + [[4, 41], [6, 61]]]) + + def testDirect(self): + # Test with zero-size remaining dimension. + self._testDirect( + input_shape=[3, 1, 2, 0], block_shape=[3], paddings=[[0, 2]]) + + # Test with zero-size blocked dimension. + self._testDirect( + input_shape=[3, 0, 2, 5], block_shape=[3], paddings=[[0, 0]]) + + # Test with padding up from zero size. + self._testDirect( + input_shape=[3, 0, 2, 5], block_shape=[3], paddings=[[1, 2]]) + + self._testDirect( + input_shape=[3, 3, 4, 5, 2], + block_shape=[3, 4, 2], + paddings=[[1, 2], [0, 0], [3, 0]]) + + self._testDirect( + input_shape=[3, 3, 4, 5, 2], + block_shape=[3, 4, 2, 2], + paddings=[[1, 2], [0, 0], [3, 0], [0, 0]]) + + self._testDirect( + input_shape=[3, 2, 2, 3, 4, 5, 2, 5], + block_shape=[1, 1, 3, 4, 2, 2], + paddings=[[0, 0], [0, 0], [1, 2], [0, 0], [3, 0], [0, 0]]) + + self._testDirect( + input_shape=[3, 2, 2, 3, 4, 5, 2, 5], + block_shape=[1, 1, 3, 4, 2, 2, 1], + paddings=[[0, 0], [0, 0], [1, 2], [0, 0], [3, 0], [0, 0], [0, 0]]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 1e85d3a2c8b..3f324d1071e 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -209,6 +209,11 @@ class UnaryOpsTest(XLATestCase): [-3.4401896, -2.4401896, -1.4401897, -0.44018969]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + nn_ops.elu, + np.array([[-1, 0, 1]], dtype=dtype), + expected=np.array([[-0.63212056, 0, 1]], dtype=dtype)) + self._assertOpOutputMatchesExpected( nn_ops.relu, np.array([[-1, 1]], dtype=dtype), diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index 53aa749a0a9..44ff13ca34e 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -35,6 +35,9 @@ Status BackwardsConstAnalysis(const Graph& g, {"Any", "reduction_indices"}, {"ArgMax", "dimension"}, {"AvgPoolGrad", "orig_input_shape"}, + {"BatchToSpace", "crops"}, + {"BatchToSpaceND", "block_shape"}, + {"BatchToSpaceND", "crops"}, {"BroadcastGradientArgs", "s0"}, {"BroadcastGradientArgs", "s1"}, {"Concat", "concat_dim"}, @@ -69,6 +72,9 @@ Status BackwardsConstAnalysis(const Graph& g, {"ReverseV2", "axis"}, {"Slice", "begin"}, {"Slice", "size"}, + {"SpaceToBatch", "paddings"}, + {"SpaceToBatchND", "block_shape"}, + {"SpaceToBatchND", "paddings"}, {"Split", "split_dim"}, {"SplitV", "split_dim"}, {"SplitV", "size_splits"}, diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 2ee80a41e82..14d2a72f7ce 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -15,6 +15,7 @@ tf_kernel_library( srcs = [ "aggregate_ops.cc", "batch_matmul_op.cc", + "batchtospace_op.cc", "bcast_ops.cc", "bias_ops.cc", "binary_ops.cc", @@ -26,6 +27,7 @@ tf_kernel_library( "depthwise_conv_ops.cc", "diag_op.cc", "dynamic_stitch_op.cc", + "elu_op.cc", "fill_op.cc", "function_ops.cc", "identity_op.cc", @@ -49,6 +51,7 @@ tf_kernel_library( "shape_op.cc", "slice_op.cc", "softmax_op.cc", + "spacetobatch_op.cc", "split_op.cc", "strided_slice_op.cc", "tile_ops.cc", diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc new file mode 100644 index 00000000000..eb4bd47ee50 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -0,0 +1,186 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { +namespace { + +void BatchToSpace(XlaOpKernelContext* ctx, + const xla::ComputationDataHandle& input, DataType input_dtype, + const TensorShape& input_tensor_shape, + gtl::ArraySlice block_shape, + const xla::Literal& crops) { + const int input_rank = input_tensor_shape.dims(); + const gtl::InlinedVector input_shape = + input_tensor_shape.dim_sizes(); + const int block_rank = block_shape.size(); + + OP_REQUIRES( + ctx, input_rank >= 1 + block_rank, + errors::InvalidArgument("input rank should be >= ", 1 + block_rank, + " instead of ", input_rank)); + gtl::ArraySlice remainder_shape(input_shape); + remainder_shape.remove_prefix(1 + block_rank); + + OP_REQUIRES( + ctx, + xla::ShapeUtil::Rank(crops.shape()) == 2 && + block_rank == xla::ShapeUtil::GetDimension(crops.shape(), 0) && + 2 == xla::ShapeUtil::GetDimension(crops.shape(), 1), + errors::InvalidArgument("crops should have shape [", block_rank, + ", 2] instead of ", + xla::ShapeUtil::HumanString(crops.shape()))); + + xla::ComputationBuilder* b = ctx->builder(); + const int64 batch_size = input_shape[0]; + + // Compute the product of the block_shape values. + int64 block_num_elems = 1; + for (int i = 0; i < block_rank; ++i) { + block_num_elems *= block_shape[i]; + } + OP_REQUIRES(ctx, block_num_elems > 0, + errors::InvalidArgument( + "The product of the block dimensions must be positive")); + + // 1. Reshape `input` to `reshaped` of shape: + // [block_shape[0], ..., block_shape[M-1], + // batch / prod(block_shape), + // input_shape[1], ..., input_shape[N-1]] + + OP_REQUIRES( + ctx, batch_size % block_num_elems == 0, + errors::InvalidArgument("Input batch dimension (", batch_size, + ") is not divisible by product of block sizes (", + block_num_elems, ")")); + std::vector reshaped_shape(input_rank + block_rank); + std::copy(block_shape.begin(), block_shape.end(), reshaped_shape.begin()); + reshaped_shape[block_rank] = batch_size / block_num_elems; + std::copy(input_shape.begin() + 1, input_shape.end(), + reshaped_shape.begin() + block_rank + 1); + xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape); + + // 2. Permute dimensions of `reshaped` to produce `permuted` of shape + // [batch / prod(block_shape), + // + // input_shape[1], block_shape[0], + // ..., + // input_shape[M], block_shape[M-1], + // + // input_shape[M+1], ..., input_shape[N-1]] + std::vector permutation(reshaped_shape.size()); + permutation[0] = block_rank; + for (int i = 0; i < block_rank; ++i) { + permutation[1 + 2 * i] = block_rank + 1 + i; + permutation[1 + 2 * i + 1] = i; + } + std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(), + 1 + block_rank * 2); + xla::ComputationDataHandle permuted = b->Transpose(reshaped, permutation); + + // 3. Reshape `permuted` to produce `reshaped_permuted` of shape + // [batch / prod(block_shape), + // + // input_shape[1] * block_shape[0], + // ..., + // input_shape[M] * block_shape[M-1], + // + // input_shape[M+1], + // ..., + // input_shape[N-1]] + std::vector reshaped_permuted_shape(input_rank); + reshaped_permuted_shape[0] = batch_size / block_num_elems; + for (int i = 0; i < block_rank; ++i) { + reshaped_permuted_shape[1 + i] = block_shape[i] * input_shape[1 + i]; + } + std::copy(remainder_shape.begin(), remainder_shape.end(), + reshaped_permuted_shape.begin() + 1 + block_rank); + + xla::ComputationDataHandle reshaped_permuted = + b->Reshape(permuted, reshaped_permuted_shape); + + // 4. Crop the start and end of dimensions `[1, ..., M]` of + // `reshaped_permuted` according to `crops` to produce the output of shape: + // [batch / prod(block_shape), + // + // input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1], + // ..., + // input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1], + // + // input_shape[M+1], ..., input_shape[N-1]] + std::vector start_indices(input_rank, 0); + std::vector end_indices = reshaped_permuted_shape; + for (int i = 0; i < block_rank; ++i) { + int64 crop_start = xla::LiteralUtil::Get(crops, {i, 0}); + int64 crop_end = xla::LiteralUtil::Get(crops, {i, 1}); + OP_REQUIRES(ctx, crop_start >= 0 && crop_end >= 0, + errors::InvalidArgument("Crops must be non-negative")); + start_indices[1 + i] = crop_start; + end_indices[1 + i] -= crop_end; + OP_REQUIRES( + ctx, start_indices[1 + i] <= end_indices[1 + i], + errors::InvalidArgument( + "Cropped size must be non-negative: start: ", crop_start, + " end: ", crop_end, " size ", reshaped_permuted_shape[1 + i])); + } + xla::ComputationDataHandle output = + b->Slice(reshaped_permuted, start_indices, end_indices); + ctx->SetOutput(0, output); +} + +class BatchToSpaceNDOp : public XlaOpKernel { + public: + explicit BatchToSpaceNDOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + std::vector block_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &block_shape)); + + xla::Literal crops; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(2, &crops)); + + BatchToSpace(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0), + block_shape, crops); + } +}; +REGISTER_XLA_OP(Name("BatchToSpaceND"), BatchToSpaceNDOp); + +class BatchToSpaceOp : public XlaOpKernel { + public: + explicit BatchToSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_)); + OP_REQUIRES( + ctx, block_size_ > 1, + errors::InvalidArgument("Block size should be > 1: ", block_size_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::Literal crops; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(1, &crops)); + + BatchToSpace(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0), + {block_size_, block_size_}, crops); + } + + private: + int block_size_; +}; +REGISTER_XLA_OP(Name("BatchToSpace"), BatchToSpaceOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc new file mode 100644 index 00000000000..62a5e1bd421 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Native XLA implementations of XLA Elu Ops + +#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/no_op.h" + +namespace tensorflow { +namespace { + +class EluOp : public XlaOpKernel { + public: + explicit EluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + // Computes the max of the scalar input x and 0. + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + const auto zero = XlaHelpers::Zero(b, input_type(0)); + const auto one = XlaHelpers::One(b, input_type(0)); + const auto pred = b->Gt(ctx->Input(0), zero); + const auto expm1 = b->Sub(b->Exp(ctx->Input(0)), one); + ctx->SetOutput(0, b->Select(pred, ctx->Input(0), expm1)); + } +}; + +class EluGradOp : public XlaOpKernel { + public: + explicit EluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + // Return the lhs (incoming gradient) if the rhs (input feature) > 0, + // otherwise return lhs * (1 + rhs). + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + const auto zero = XlaHelpers::Zero(b, input_type(0)); + const auto one = XlaHelpers::One(b, input_type(0)); + const auto grad = ctx->Input(0); + const auto activation = ctx->Input(1); + const auto exp_grad = b->Mul(grad, b->Add(activation, one)); + const auto pred = b->Gt(activation, zero); + ctx->SetOutput(0, b->Select(pred, grad, exp_grad)); + } +}; + +REGISTER_XLA_OP(Name("Elu"), EluOp); +REGISTER_XLA_OP(Name("EluGrad"), EluGradOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc new file mode 100644 index 00000000000..f15b354cb26 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -0,0 +1,190 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { +namespace { + +void SpaceToBatch(XlaOpKernelContext* ctx, + const xla::ComputationDataHandle& input, DataType input_dtype, + const TensorShape& input_tensor_shape, + gtl::ArraySlice block_shape, + const xla::Literal& paddings) { + const int input_rank = input_tensor_shape.dims(); + const gtl::InlinedVector input_shape = + input_tensor_shape.dim_sizes(); + const int block_rank = block_shape.size(); + + OP_REQUIRES( + ctx, input_rank >= 1 + block_rank, + errors::InvalidArgument("input rank should be >= ", 1 + block_rank, + " instead of ", input_rank)); + gtl::ArraySlice remainder_shape(input_shape); + remainder_shape.remove_prefix(1 + block_rank); + + OP_REQUIRES( + ctx, + xla::ShapeUtil::Rank(paddings.shape()) == 2 && + block_rank == xla::ShapeUtil::GetDimension(paddings.shape(), 0) && + 2 == xla::ShapeUtil::GetDimension(paddings.shape(), 1), + errors::InvalidArgument("paddings should have shape [", block_rank, + ", 2] instead of ", + xla::ShapeUtil::HumanString(paddings.shape()))); + + xla::ComputationBuilder* b = ctx->builder(); + + // 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the + // input according to `paddings` to produce `padded` of shape `padded_shape`. + xla::PaddingConfig padding_config; + std::vector padded_shape(input_shape.begin(), input_shape.end()); + int64 block_num_elems = 1LL; + padding_config.add_dimensions(); // Don't pad the batch dimension. + for (int i = 0; i < block_rank; ++i) { + auto* dim = padding_config.add_dimensions(); + int64 pad_start = xla::LiteralUtil::Get(paddings, {i, 0}); + int64 pad_end = xla::LiteralUtil::Get(paddings, {i, 1}); + OP_REQUIRES(ctx, pad_start >= 0 && pad_end >= 0, + errors::InvalidArgument("Paddings must be non-negative")); + dim->set_edge_padding_low(pad_start); + dim->set_edge_padding_high(pad_end); + padded_shape[1 + i] += pad_start + pad_end; + block_num_elems *= block_shape[i]; + } + // Don't pad the remainder dimensions. + for (int i = 0; i < remainder_shape.size(); ++i) { + padding_config.add_dimensions(); + } + OP_REQUIRES(ctx, block_num_elems > 0, + errors::InvalidArgument( + "The product of the block dimensions must be positive")); + + xla::ComputationDataHandle padded = + b->Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config); + + // 2. Reshape `padded` to `reshaped_padded` of shape: + // + // [batch] + + // [padded_shape[1] / block_shape[0], + // block_shape[0], + // ..., + // padded_shape[M] / block_shape[M-1], + // block_shape[M-1]] + + // remaining_shape + const int64 batch_size = input_shape[0]; + std::vector reshaped_padded_shape(input_rank + block_rank); + reshaped_padded_shape[0] = batch_size; + for (int i = 0; i < block_rank; ++i) { + OP_REQUIRES(ctx, padded_shape[1 + i] % block_shape[i] == 0, + errors::InvalidArgument("padded_shape[", 1 + i, + "]=", padded_shape[1 + i], + " is not divisible by block_shape[", i, + "]=", block_shape[i])); + + reshaped_padded_shape[1 + i * 2] = padded_shape[1 + i] / block_shape[i]; + reshaped_padded_shape[1 + i * 2 + 1] = block_shape[i]; + } + std::copy(remainder_shape.begin(), remainder_shape.end(), + reshaped_padded_shape.begin() + 1 + 2 * block_rank); + + xla::ComputationDataHandle reshaped_padded = + b->Reshape(padded, reshaped_padded_shape); + + // 3. Permute dimensions of `reshaped_padded` to produce + // `permuted_reshaped_padded` of shape: + // + // block_shape + + // [batch] + + // [padded_shape[1] / block_shape[0], + // ..., + // padded_shape[M] / block_shape[M-1]] + + // remaining_shape + std::vector permutation(reshaped_padded_shape.size()); + for (int i = 0; i < block_rank; ++i) { + permutation[i] = 1 + 2 * i + 1; + permutation[block_rank + 1 + i] = 1 + 2 * i; + } + permutation[block_rank] = 0; + std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(), + 1 + block_rank * 2); + xla::ComputationDataHandle permuted_reshaped_padded = + b->Transpose(reshaped_padded, permutation); + + // 4. Reshape `permuted_reshaped_padded` to flatten `block_shape` into the + // batch dimension, producing an output tensor of shape: + // + // [batch * prod(block_shape)] + + // [padded_shape[1] / block_shape[0], + // ..., + // padded_shape[M] / block_shape[M-1]] + + // remaining_shape + // Determine the length of the prefix of block dims that can be combined + // into the batch dimension due to having no padding and block_shape=1. + std::vector output_shape(input_rank); + output_shape[0] = batch_size * block_num_elems; + for (int i = 0; i < block_rank; ++i) { + output_shape[1 + i] = padded_shape[1 + i] / block_shape[i]; + } + std::copy(remainder_shape.begin(), remainder_shape.end(), + output_shape.begin() + 1 + block_rank); + + xla::ComputationDataHandle output = + b->Reshape(permuted_reshaped_padded, output_shape); + ctx->SetOutput(0, output); +} + +class SpaceToBatchNDOp : public XlaOpKernel { + public: + explicit SpaceToBatchNDOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + std::vector block_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &block_shape)); + + xla::Literal paddings; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(2, &paddings)); + + SpaceToBatch(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0), + block_shape, paddings); + } +}; +REGISTER_XLA_OP(Name("SpaceToBatchND"), SpaceToBatchNDOp); + +class SpaceToBatchOp : public XlaOpKernel { + public: + explicit SpaceToBatchOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_)); + OP_REQUIRES( + ctx, block_size_ > 1, + errors::InvalidArgument("Block size should be > 1: ", block_size_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::Literal paddings; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(1, &paddings)); + + SpaceToBatch(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0), + {block_size_, block_size_}, paddings); + } + + private: + int block_size_; +}; +REGISTER_XLA_OP(Name("SpaceToBatch"), SpaceToBatchOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 53dcdec7a25..a022de36a26 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -186,6 +186,31 @@ Status XlaOpKernelContext::ConstantInputAsIntVector(int index, return LiteralToInt64Vector(literal, out); } +Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index, + xla::Literal* out) { + xla::Literal literal; + TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); + switch (literal.shape().element_type()) { + case xla::S32: + out->Clear(); + *out->mutable_shape() = literal.shape(); + out->mutable_shape()->set_element_type(xla::S64); + for (int32 x : literal.s32s()) { + out->add_s64s(x); + } + return Status::OK(); + + case xla::S64: + out->Swap(&literal); + return Status::OK(); + + default: + return errors::InvalidArgument( + "Invalid argument to ConstantInputAsInt64Literal: ", + xla::ShapeUtil::HumanString(literal.shape())); + } +} + // TODO(phawkins): validate that the dimensions form a valid shape, fail // gracefully if they do not. Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 60e3b59d32a..f97e07bea5d 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -110,6 +110,9 @@ class XlaOpKernelContext { // Converts a constant 1D int32 or int64 tensor into a vector of int64s. Status ConstantInputAsIntVector(int index, std::vector* out); + // Converts a constant int32 or int64 Tensor into an xla int64 Literal. + Status ConstantInputAsInt64Literal(int index, xla::Literal* out); + // Converts a constant 1D int32 or int64 tensor into a TensorShape. Status ConstantInputAsShape(int index, TensorShape* shape); diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index b9118fab254..695e4e7f079 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -594,8 +594,10 @@ cc_test( deps = [ ":buffer_assignment", ":computation_tracker", + ":copy_insertion", ":cpu_plugin", ":hlo", + ":hlo_ordering", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 27a1c0fec88..0969cff39ae 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -868,7 +868,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) { computation->root_instruction()->dimensions(); EXPECT_EQ(1, broadcast_dims.size()); EXPECT_TRUE(broadcast_dims[0] == 1 || broadcast_dims[0] == 2 || - broadcast_dims[3] == 3); + broadcast_dims[0] == 3); } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) { diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index e2b550fc022..931f5898002 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -41,6 +41,8 @@ limitations under the License. namespace xla { +using ::tensorflow::gtl::FlatMap; +using ::tensorflow::gtl::FlatSet; using ::tensorflow::strings::Appendf; using ::tensorflow::strings::HumanReadableNumBytes; @@ -394,8 +396,8 @@ Status GatherComputationsByAllocationType( // Sets for quickly checking membership. Computations are returned in vectors // for stable iteration. - tensorflow::gtl::FlatSet thread_local_set; - tensorflow::gtl::FlatSet global_set; + FlatSet thread_local_set; + FlatSet global_set; while (!worklist.empty()) { auto worklist_front = worklist.front(); @@ -554,10 +556,9 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, Status BufferAssigner::AssignBuffersForComputation( const HloComputation* computation, bool is_thread_local, - const tensorflow::gtl::FlatSet* hlos_to_allocate, - const tensorflow::gtl::FlatSet& colocated_buffers, - const tensorflow::gtl::FlatSet& - colocated_allocations, + const FlatSet* hlos_to_allocate, + const FlatSet& colocated_buffers, + const FlatSet& colocated_allocations, BufferAssignment* assignment) { // Buffers are sorted and assigned to BufferAllocations in decreasing order of // size. @@ -578,7 +579,7 @@ Status BufferAssigner::AssignBuffersForComputation( // Generate a post order sort of instructions for sorting of the // LogicalBuffers. - tensorflow::gtl::FlatMap post_order_position; + FlatMap post_order_position; int position = 0; for (auto* instruction : computation->MakeInstructionPostOrder()) { post_order_position.emplace(instruction, position); @@ -590,7 +591,7 @@ Status BufferAssigner::AssignBuffersForComputation( const BufferLiveness& liveness = assignment->liveness(); const std::vector* sequential_order = liveness.hlo_ordering().SequentialOrder(*computation); - tensorflow::gtl::FlatSet unassigned_temp_buffers; + FlatSet unassigned_temp_buffers; // Sort the LogicalBuffers first by size. We assign the larger LogicalBuffers // first for simplicity. This means any previously created BufferAllocation is @@ -791,7 +792,7 @@ Status BufferAssigner::AssignBuffersForComputation( Status BufferAssigner::AssignBuffersWithSequentialOrdering( const std::vector& sequence, - const tensorflow::gtl::FlatSet& buffers_to_assign, + const FlatSet& buffers_to_assign, const HloComputation& computation, BufferAssignment* assignment) { // Run the sequence of instructions through the heap simulator. The heuristic // that seems to give the best results is lazy-best-fit, with all runs of @@ -881,40 +882,137 @@ void BufferAssigner::AddSetToColocatedBufferSets( } } +// Conceptually the same as AddSetToColocatedBufferSets, but specific to the +// colocated buffers for while instructions. 'colocated_set' contains the +// buffers for a single while instruction that must be colocated. The idea here +// is to apply a memory-saving heuristic for separate while instructions whose +// buffers are disjoint in liveness, by using the colocation mechanism to force +// buffer sharing. This often reduces memory for multi-layer RNNs. +// +// TODO(b/32491382): We should be able to remove this heuristic after we +// implement module-level liveness analysis, which would let us directly detect +// buffer sharing opportunities between the while instruction buffer and the +// buffers from the predicate and body computation, as well as sharing across +// different while instructions. +void BufferAssigner::AddWhileSetToColocatedBufferSets( + const std::vector& colocated_set, + const LogicalBuffer* while_init_buffer, const HloInstruction* while_hlo, + const HloComputation& computation, const BufferLiveness& buffer_liveness, + std::vector* colocated_buffer_sets) { + CHECK(!colocated_set.empty()); + + // Parallel while loops cannot safely share colocated buffer sets. + if (buffer_liveness.hlo_ordering().SequentialOrder(computation) == nullptr) { + AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); + return; + } + + // Scan 'colocated_buffer_sets' in reverse order for locality; colocated sets + // are added in postorder over computations and instructions. + const int64 init_buffer_size = buffer_size_(*while_init_buffer); + for (int i = colocated_buffer_sets->size() - 1; i >= 0; --i) { + const ColocatedBufferSet& predecessor_set = (*colocated_buffer_sets)[i]; + + // Skip predecessor sets not associated with while loops. + if (std::all_of(predecessor_set.begin(), predecessor_set.end(), + [](const LogicalBuffer* buffer) { + return buffer->instruction()->opcode() != + HloOpcode::kWhile; + })) { + continue; + } + + // Skip predecessor sets already associated with 'while_hlo'. + if (std::any_of(predecessor_set.begin(), predecessor_set.end(), + [&while_hlo](const LogicalBuffer* buffer) { + return buffer->instruction() == while_hlo; + })) { + continue; + } + + // Build vector of predecessor while result buffers. + std::vector predecessor_while_buffers; + for (const LogicalBuffer* buffer : predecessor_set) { + if (buffer->instruction()->opcode() == HloOpcode::kWhile && + buffer_size_(*buffer) == init_buffer_size && + buffer->instruction()->parent() == &computation) { + predecessor_while_buffers.push_back(buffer); + } + } + if (predecessor_while_buffers.empty()) { + continue; + } + + // Skip predecessor set if the live range of any predecessor buffers + // overlaps with 'while_init_buffer'. Note that tuple element buffer + // forwarding can cause the same buffer to appear on both sides of the + // interference comparison below. + if (std::any_of( + predecessor_while_buffers.begin(), predecessor_while_buffers.end(), + [while_init_buffer, &buffer_liveness](const LogicalBuffer* buffer) { + return while_init_buffer->id() != buffer->id() && + buffer_liveness.MayInterfere(*while_init_buffer, *buffer); + })) { + continue; + } + + // All our checks have passed; merge 'predecessor_set' with 'colocated_set', + // and add the merged set to 'colocated_buffer_sets'. This forces the + // colocation of buffers across different while instructions. + FlatSet unique; + unique.insert(predecessor_set.begin(), predecessor_set.end()); + unique.insert(colocated_set.begin(), colocated_set.end()); + std::vector merged_set(unique.begin(), unique.end()); + AddSetToColocatedBufferSets(merged_set, colocated_buffer_sets); + return; + } + + // Failed to merge into predecessor set; add 'colocated_set' as-is. + AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); +} + namespace { + // Checks that points-to set of 'instruction' is unambiguous and distinct // (ensured by CopyInsertion), then adds the buffer from the points-to set at // 'index' to 'colocated_set'. -void AddBufferToColocatedSet(const HloInstruction* instruction, - const ShapeIndex& index, - const TuplePointsToAnalysis& points_to_analysis, - std::vector* colocated_set) { +const LogicalBuffer* AddBufferToColocatedSet( + const HloInstruction* instruction, const ShapeIndex& index, + const TuplePointsToAnalysis& points_to_analysis, + std::vector* colocated_set) { // CopyInsertion ensures root points-to set is unambiguous and distinct. const auto& points_to = points_to_analysis.GetPointsToSet(instruction); CHECK(!points_to.IsAmbiguous()); CHECK(points_to.IsDistinct()); colocated_set->push_back(points_to.element(index)[0]); + return colocated_set->back(); } + } // namespace // Builds sets of buffers in 'colocated_buffer_sets' which should be colocated // in the same allocation (currently just supports kWhile and kCall). void BufferAssigner::BuildColocatedBufferSets( - const HloModule* module, const TuplePointsToAnalysis& points_to_analysis, + const HloModule* module, const BufferLiveness& buffer_liveness, std::vector* colocated_buffer_sets) { - for (auto& computation : module->computations()) { - for (auto& instruction : computation->instructions()) { + const TuplePointsToAnalysis& points_to_analysis = + buffer_liveness.points_to_analysis(); + for (const HloComputation* computation : module->MakeComputationPostOrder()) { + for (const HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { const HloOpcode opcode = instruction->opcode(); if (opcode == HloOpcode::kWhile) { - HloInstruction* while_hlo = instruction.get(); + const HloInstruction* while_hlo = instruction; TF_CHECK_OK(ShapeUtil::ForEachSubshape( while_hlo->shape(), - [this, while_hlo, &points_to_analysis, colocated_buffer_sets]( - const Shape& /*subshape*/, const ShapeIndex& index) { + [this, while_hlo, &points_to_analysis, &buffer_liveness, + computation, colocated_buffer_sets](const Shape& /*subshape*/, + const ShapeIndex& index) { std::vector colocated_set; // Add while.init. - AddBufferToColocatedSet(while_hlo->operand(0), index, - points_to_analysis, &colocated_set); + auto* init_buffer = + AddBufferToColocatedSet(while_hlo->operand(0), index, + points_to_analysis, &colocated_set); // Add while.result. AddBufferToColocatedSet(while_hlo, index, points_to_analysis, &colocated_set); @@ -930,12 +1028,15 @@ void BufferAssigner::BuildColocatedBufferSets( AddBufferToColocatedSet( while_hlo->while_body()->root_instruction(), index, points_to_analysis, &colocated_set); - AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); + AddWhileSetToColocatedBufferSets( + colocated_set, init_buffer, while_hlo, *computation, + buffer_liveness, colocated_buffer_sets); return Status::OK(); })); } else if (opcode == HloOpcode::kCall) { - HloInstruction* call_hlo = instruction.get(); - HloInstruction* root_hlo = call_hlo->to_apply()->root_instruction(); + const HloInstruction* call_hlo = instruction; + const HloInstruction* root_hlo = + call_hlo->to_apply()->root_instruction(); TF_CHECK_OK(ShapeUtil::ForEachSubshape( call_hlo->shape(), [this, call_hlo, root_hlo, &points_to_analysis, @@ -961,8 +1062,8 @@ void BufferAssigner::BuildColocatedBufferSets( void BufferAssigner::AssignColocatedBufferSets( const std::vector& colocated_buffer_sets, BufferAssignment* assignment, - tensorflow::gtl::FlatSet* colocated_buffers, - tensorflow::gtl::FlatSet* colocated_allocations) { + FlatSet* colocated_buffers, + FlatSet* colocated_allocations) { for (const ColocatedBufferSet& colocated_buffer_set : colocated_buffer_sets) { BufferAllocation* allocation = nullptr; for (const LogicalBuffer* buffer : colocated_buffer_set) { @@ -1008,9 +1109,9 @@ StatusOr> BufferAssigner::CreateAssignment( // Set of HLO's to allocate if hlos_to_allocate is given. Passed as a set to // AssignBuffersForComputation for fast membership testing. - std::unique_ptr> hlo_set; + std::unique_ptr> hlo_set; if (hlos_to_allocate != nullptr) { - hlo_set = MakeUnique>( + hlo_set = MakeUnique>( hlos_to_allocate->begin(), hlos_to_allocate->end()); } @@ -1022,11 +1123,11 @@ StatusOr> BufferAssigner::CreateAssignment( // Once b/32491382 enables module-level liveness analysis, we may be able // to assign colocated buffers (or at least reuse their allocation for // buffers outside of the set) in AssignBuffersForComputation. - tensorflow::gtl::FlatSet colocated_buffers; - tensorflow::gtl::FlatSet colocated_allocations; + FlatSet colocated_buffers; + FlatSet colocated_allocations; if (colocate_related_buffers_) { std::vector colocated_buffer_sets; - BuildColocatedBufferSets(module, assignment->points_to_analysis(), + BuildColocatedBufferSets(module, assignment->liveness(), &colocated_buffer_sets); AssignColocatedBufferSets(colocated_buffer_sets, assignment.get(), &colocated_buffers, &colocated_allocations); diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index b82acb19b34..ec1375e24d6 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -465,7 +465,7 @@ class BufferAssigner { // ColocatedBufferSet aggregates a set of related LogicalBuffers from 'module' // which should be colocated in the same buffer allocation. void BuildColocatedBufferSets( - const HloModule* module, const TuplePointsToAnalysis& points_to_analysis, + const HloModule* module, const BufferLiveness& buffer_liveness, std::vector* colocated_buffer_sets); // For each buffer set in 'colocated_buffer_sets', assigns all buffers in the @@ -482,6 +482,14 @@ class BufferAssigner { const std::vector& colocated_set, std::vector* colocated_buffer_sets); + // Conceptually the same as AddSetToColocatedBufferSets, but specific to the + // colocated buffers for while instructions. + void AddWhileSetToColocatedBufferSets( + const std::vector& colocated_set, + const LogicalBuffer* while_init_buffer, const HloInstruction* while_hlo, + const HloComputation& computation, const BufferLiveness& buffer_liveness, + std::vector* colocated_buffer_sets); + const HloModule* module_; // Function which returns the buffer size for a given logical buffer (shape). diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index bb7342d5081..f6637d60986 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -23,10 +23,12 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/computation_tracker.h" +#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" @@ -1245,6 +1247,163 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { } } -} // namespace +class WhileBufferAssignmentTest : public HloTestBase { + protected: + std::unique_ptr BuildWhileConditionComputation( + const string& name) { + auto builder = HloComputation::Builder(name); + builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); + auto zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + auto ten = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, zero, ten)); + return builder.Build(); + } + std::unique_ptr BuildWhileBodyComputation( + const string& name) { + auto builder = HloComputation::Builder(name); + auto loop_state = builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); + auto input = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 0)); + auto weights = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); + auto output = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape_, HloOpcode::kMultiply, input, weights)); + builder.AddInstruction( + HloInstruction::CreateTuple({input, weights, output})); + return builder.Build(); + } + + void RunCopyInsertion(HloModule* module) { + CopyInsertion copy_insertion; + EXPECT_IS_OK(copy_insertion.Run(module).status()); + } + + std::unique_ptr RunBufferAssignment(HloModule* module, + int64 alignment = 1) { + auto sequence = + CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); + return BufferAssigner::Run( + module, MakeUnique(module, sequence), + ByteSizeOf, alignment) + .ConsumeValueOrDie(); + } + + static int64 ByteSizeOf(const LogicalBuffer& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), sizeof(void*)); + } + + Shape data_shape_ = ShapeUtil::MakeShape(F32, {4}); + Shape loop_state_shape_ = + ShapeUtil::MakeTupleShape({data_shape_, data_shape_, data_shape_}); +}; + +TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { + auto module = MakeUnique(TestName()); + auto builder = HloComputation::Builder("entry"); + + auto input0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape_, "input0")); + auto weights0 = builder.AddInstruction( + HloInstruction::CreateParameter(1, data_shape_, "weights0")); + auto weights1 = builder.AddInstruction( + HloInstruction::CreateParameter(2, data_shape_, "weights1")); + + auto zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + auto output0 = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + auto output1 = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + + auto cond0 = + module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); + auto body0 = + module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); + + auto tuple0 = builder.AddInstruction( + HloInstruction::CreateTuple({input0, weights0, output0})); + auto while0 = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0)); + + auto cond1 = + module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); + auto body1 = + module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); + auto input1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, while0, 2)); + auto tuple1 = builder.AddInstruction( + HloInstruction::CreateTuple({input1, weights1, output1})); + auto while1 = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1)); + + module->AddEntryComputation(builder.Build()); + RunCopyInsertion(module.get()); + auto assignment = RunBufferAssignment(module.get()); + + // While instruction 'while0' has no predecessor while instructions with + // which to share allocations. + + // While instruction 'while1' can share allocations with the following + // buffers: + // *) while0[2], while1[0] + // *) while0[1], while1[1] + EXPECT_EQ(assignment->GetUniqueSlice(while0, {2}).ConsumeValueOrDie(), + assignment->GetUniqueSlice(while1, {0}).ConsumeValueOrDie()); + EXPECT_EQ(assignment->GetUniqueSlice(while0, {1}).ConsumeValueOrDie(), + assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie()); +} + +TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { + auto module = MakeUnique(TestName()); + auto builder = HloComputation::Builder("entry"); + + auto input0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape_, "input0")); + auto weights0 = builder.AddInstruction( + HloInstruction::CreateParameter(1, data_shape_, "weights0")); + + auto zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + auto output0 = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + auto output1 = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + + auto cond0 = + module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); + auto body0 = + module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); + + auto tuple0 = builder.AddInstruction( + HloInstruction::CreateTuple({input0, weights0, output0})); + auto while0 = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0)); + + auto cond1 = + module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); + auto body1 = + module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); + + auto tuple1 = builder.AddInstruction( + HloInstruction::CreateTuple({input0, weights0, output1})); + auto while1 = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1)); + + module->AddEntryComputation(builder.Build()); + RunCopyInsertion(module.get()); + auto assignment = RunBufferAssignment(module.get()); + + EXPECT_EQ(assignment->GetUniqueSlice(while0, {2}).ConsumeValueOrDie(), + assignment->GetUniqueSlice(while1, {0}).ConsumeValueOrDie()); + EXPECT_EQ(assignment->GetUniqueSlice(while0, {1}).ConsumeValueOrDie(), + assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie()); +} + +} // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index 616b239a931..ceb0cdaa316 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -165,4 +165,17 @@ bool HloOpcodeIsComparison(HloOpcode opcode) { } } +bool HloOpcodeIsVariadic(HloOpcode opcode) { + switch (opcode) { + case HloOpcode::kCall: + case HloOpcode::kConcatenate: + case HloOpcode::kFusion: + case HloOpcode::kMap: + case HloOpcode::kTuple: + return true; + default: + return false; + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 978ed5e79b9..e2cdbfdfa7a 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -104,6 +104,9 @@ inline std::ostream& operator<<(std::ostream& os, HloOpcode opcode) { // Returns true iff the given opcode is a comparison operation. bool HloOpcodeIsComparison(HloOpcode opcode); +// Returns true iff the given opcode has variadic operands. +bool HloOpcodeIsVariadic(HloOpcode opcode); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_ diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 6e3c9830712..eb7fe467b32 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -40,6 +40,8 @@ void DumpModule(const Compiler::HloDumper& dumper_, const HloModule& module, } // namespace StatusOr HloPassPipeline::Run(HloModule* module) { + run_called_ = true; + legacy_flags::HloPassPipelineFlags* flags = legacy_flags::GetHloPassPipelineFlags(); std::vector tmp = diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h index a8c2d518730..682c4b952df 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h @@ -47,6 +47,7 @@ class HloPassPipeline : public HloPassInterface { // Returns a reference to the added pass. template T& AddPass(Args&&... args) { + CHECK(!run_called_) << "AddPass cannot be called after Run"; auto pass = new T(std::forward(args)...); passes_.push_back(std::unique_ptr(pass)); return *pass; @@ -57,6 +58,7 @@ class HloPassPipeline : public HloPassInterface { // (it is required to always return "false" from its Run() method). template T& AddInvariantChecker(Args&&... args) { + CHECK(!run_called_) << "AddInvariantChecker cannot be called after Run"; auto pass = new T(std::forward(args)...); invariant_checkers_.push_back(std::unique_ptr(pass)); return *pass; @@ -70,6 +72,7 @@ class HloPassPipeline : public HloPassInterface { Compiler::HloDumper dumper_; std::vector> passes_; std::vector> invariant_checkers_; + bool run_called_ = false; TF_DISALLOW_COPY_AND_ASSIGN(HloPassPipeline); }; diff --git a/tensorflow/compiler/xla/service/liveness_util.cc b/tensorflow/compiler/xla/service/liveness_util.cc index caaf56a5516..1f625ae0e2b 100644 --- a/tensorflow/compiler/xla/service/liveness_util.cc +++ b/tensorflow/compiler/xla/service/liveness_util.cc @@ -101,12 +101,12 @@ std::vector> GetAllUsesOfInstructionAtIndex( } // namespace // User and operand can share buffers iff both instructions emit the same shape -// and layout, and 'user' meets one of the following two qualifications: -// *) Is element-wise. +// and layout, and 'user' meets one of the following qualifications: +// *) Is element-wise. Or... // *) Is a loop fusion instruction where the only use of 'operand' at 'index' // in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root -// at operand 0. -// *) Use of 'operand' is DynamicUpdateSlice at operand index 0. +// at operand 0. Or... +// *) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index 0. bool CanShareOperandBufferWithUser( HloInstruction* operand, const ShapeIndex& operand_index, HloInstruction* user, const ShapeIndex& user_index, @@ -144,7 +144,8 @@ bool CanShareOperandBufferWithUser( break; } return false; - } else if (user->opcode() == HloOpcode::kDynamicUpdateSlice) { + } else if (user->opcode() == HloOpcode::kDynamicUpdateSlice || + user->opcode() == HloOpcode::kWhile) { // We eliminated other users in BufferLiveness::live_range_strictly_before, // so here we just need to check that the use is at operand index 0. std::vector operand_indices = user->OperandIndices(operand); diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc index 2ff71d6f3c8..079b59265ba 100644 --- a/tensorflow/compiler/xla/service/liveness_util_test.cc +++ b/tensorflow/compiler/xla/service/liveness_util_test.cc @@ -185,5 +185,73 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { *points_to_analysis_)); } +TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + Shape update_shape = ShapeUtil::MakeShape(F32, {4}); + Shape starts_shape = ShapeUtil::MakeShape(S32, {1}); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + auto update = builder.AddInstruction( + HloInstruction::CreateParameter(1, update_shape, "update")); + auto starts = builder.AddInstruction( + HloInstruction::CreateParameter(2, starts_shape, "starts")); + auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, data, update, starts)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // The DynamicUpdateSlice instruction can share with the data operand, but not + // with update or starts. + EXPECT_TRUE( + CanShareOperandBufferWithUser(data, {}, dus, {}, *points_to_analysis_)); + EXPECT_FALSE( + CanShareOperandBufferWithUser(update, {}, dus, {}, *points_to_analysis_)); + EXPECT_FALSE( + CanShareOperandBufferWithUser(starts, {}, dus, {}, *points_to_analysis_)); +} + +TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + + auto make_cond = [this, &data_shape]() { + auto builder = HloComputation::Builder(TestName() + ".Cond"); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data)); + return builder.Build(); + }; + + auto make_body = [this, &data_shape]() { + auto builder = HloComputation::Builder(TestName() + ".Body"); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, data, data)); + return builder.Build(); + }; + + module_ = MakeUnique(TestName()); + HloComputation* cond_computation = + module_->AddEmbeddedComputation(make_cond()); + HloComputation* body_computation = + module_->AddEmbeddedComputation(make_body()); + + auto builder = HloComputation::Builder(TestName()); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + auto whil = builder.AddInstruction(HloInstruction::CreateWhile( + data_shape, cond_computation, body_computation, data)); + computation_ = module_->AddEntryComputation(builder.Build()); + + RunAnalysis(); + + // The While instruction can share with the data operand. + EXPECT_TRUE( + CanShareOperandBufferWithUser(data, {}, whil, {}, *points_to_analysis_)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index 1796a732e54..16d4282466c 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -265,6 +265,37 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { *LiteralUtil::CreateR4FromArray4D(expected), *result, error_spec_); } +TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { + auto builder = HloComputation::Builder(TestName()); + Array3D input_vals(2, 3, 4); + input_vals.FillRandom(1.0); + + Array4D expected(2, 3, 4, 5); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 4; ++k) { + for (int m = 0; m < 5; ++m) { + expected(i, j, k, m) = input_vals(i, j, k); + } + } + } + } + auto input = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR3FromArray3D(input_vals))); + + // Broadcast vector in dimensions 2 and 3. + builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), input, {0, 1, 2})); + + // Create HLO module, compile, and execute. + auto hlo_module = MakeUnique(TestName()); + hlo_module->AddEntryComputation(builder.Build()); + auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + + LiteralTestUtil::ExpectNear( + *LiteralUtil::CreateR4FromArray4D(expected), *result, error_spec_); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 34fce21758b..d00a3175344 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -211,9 +211,9 @@ XLA_TEST_F(ReduceTest, ReduceR1_0_F32_To_R0) { RunR1ToR0Test(0); } XLA_TEST_F(ReduceTest, ReduceR1_1_F32_To_R0) { RunR1ToR0Test(1); } XLA_TEST_F(ReduceTest, ReduceR1_2_F32_To_R0) { RunR1ToR0Test(2); } XLA_TEST_F(ReduceTest, ReduceR1_16_F32_To_R0) { RunR1ToR0Test(16); } -XLA_TEST_F(ReduceTest, ReduceR1_240_F32_To_R0) { RunR1ToR0Test(240); } XLA_TEST_F(ReduceTest, ReduceR1_128_F32_To_R0) { RunR1ToR0Test(128); } XLA_TEST_F(ReduceTest, ReduceR1_129_F32_To_R0) { RunR1ToR0Test(129); } +XLA_TEST_F(ReduceTest, ReduceR1_240_F32_To_R0) { RunR1ToR0Test(240); } XLA_TEST_F(ReduceTest, ReduceR1_256_F32_To_R0) { RunR1ToR0Test(256); } XLA_TEST_F(ReduceTest, ReduceR1_1024_F32_To_R0) { RunR1ToR0Test(1024); } XLA_TEST_F(ReduceTest, ReduceR1_2048_F32_To_R0) { RunR1ToR0Test(2048); } @@ -221,6 +221,9 @@ XLA_TEST_F(ReduceTest, ReduceR1_16K_F32_To_R0) { RunR1ToR0Test(16 * 1024); } XLA_TEST_F(ReduceTest, ReduceR1_16KP1_F32_To_R0) { RunR1ToR0Test(16 * 1024 + 1); } +XLA_TEST_F(ReduceTest, ReduceR1_64K_F32_To_R0) { RunR1ToR0Test(64 * 1024); } +XLA_TEST_F(ReduceTest, ReduceR1_1M_F32_To_R0) { RunR1ToR0Test(1024 * 1024); } +XLA_TEST_F(ReduceTest, ReduceR1_16M_F32_To_R0) { RunR1ToR0Test(4096 * 4096); } XLA_TEST_F(ReduceTest, ReduceR2_0x0_To_R0) { RunR2ToR0Test(0, 0); } XLA_TEST_F(ReduceTest, ReduceR2_0x2_To_R0) { RunR2ToR0Test(0, 2); } diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 46eab7f02bb..ab598b8edd7 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -176,6 +176,52 @@ cc_binary( ], ) +cc_library( + name = "hlo_tfgraph_builder", + srcs = ["hlo_tfgraph_builder.cc"], + hdrs = ["hlo_tfgraph_builder.h"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + +cc_test( + name = "hlo_tfgraph_builder_test", + srcs = ["hlo_tfgraph_builder_test.cc"], + deps = [ + ":hlo_tfgraph_builder", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test_main", + ], +) + +cc_binary( + name = "dumped_computation_to_tf_graphdef", + srcs = ["dumped_computation_to_tf_graphdef.cc"], + deps = [ + ":hlo_tfgraph_builder", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service", + "//tensorflow/compiler/xla/service:hlo_graph_dumper", + "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc new file mode 100644 index 00000000000..1aa769ee5a0 --- /dev/null +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc @@ -0,0 +1,139 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Usage: dumped_computation_to_tf_graph \ +// --output_dir=/tmp/graphs/ some_binary_snapshot_proto* +// +// Dumps a tensorflow GraphDef in text format for a snapshot computation. The +// dumped graph is an HLO computation with HLO instructions as nodes and can be +// visualized on Tensorboard. Upload the dumped files on Tensorboard. +// +// some_binary_snapshot_proto is obtained by serializing the SessionModule from +// ServiceInterface::SnapshotComputation to disk. + +#include +#include +#include + +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/service.h" +#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" + +using tensorflow::Env; +using tensorflow::io::JoinPath; +using tensorflow::strings::StrAppend; + +namespace xla { +namespace tools { +namespace { + +// Dumps all computations in the module to the given directory. +void DumpTfGraph(const HloModule& module, const string& directory_path) { + Env* env = Env::Default(); + TF_CHECK_OK(env->RecursivelyCreateDir(directory_path)); + string fname = module.name(); + std::replace(fname.begin(), fname.end(), '/', '_'); + // Since the file name will be used as the top-level scope name, clean it up + // to make it a valid scope name. + CleanNodeName(&fname); + StrAppend(&fname, ".pbtxt"); + string path = JoinPath(directory_path, fname); + HloTfGraphBuilder builder; + TF_CHECK_OK(builder.AddComputation(*module.entry_computation())); + std::cout << "Dumping " << module.name() << " to " << path << std::endl; + TF_CHECK_OK(WriteTextProto(env, path, builder.GetGraphDef())); +} + +} // namespace + +void RealMain(tensorflow::gtl::ArraySlice args, + const string& output_dir) { + LocalClient* client = ClientLibrary::LocalClientOrDie(); + // To avoid adding a new flag, use local service and lower the computations + // locally. + LocalService* local_service = + ClientLibrary::GetXlaService(client->platform()); + // Build HloModule for each Computation and dump to file. + for (char* arg : args) { + SessionModule session_module; + TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, + &session_module)); + auto computation_status = client->LoadSnapshot(session_module); + if (!computation_status.ok()) { + fprintf(stderr, "could not load snapshot for %s: %s\n", arg, + computation_status.status().ToString().c_str()); + continue; + } + Computation computation = computation_status.ConsumeValueOrDie(); + + StatusOr user_computation_status = + local_service->computation_tracker().Resolve(computation.handle()); + if (!user_computation_status.ok()) { + fprintf(stderr, + "failed to resolve computation to UserComputation %s: %s\n", arg, + user_computation_status.status().ToString().c_str()); + continue; + } + + auto* user_computation = user_computation_status.ValueOrDie(); + StatusOr> module_status = + local_service->computation_tracker().BuildHloModule( + user_computation->GetVersionedHandle()); + + if (!module_status.ok()) { + fprintf(stderr, "failed to build HloModule %s: %s\n", arg, + module_status.status().ToString().c_str()); + continue; + } + + DumpTfGraph(*module_status.ValueOrDie(), output_dir); + } +} + +} // namespace tools +} // namespace xla + +int main(int argc, char** argv) { + string output_dir = ""; + const std::vector flag_list = { + tensorflow::Flag("output_dir", &output_dir, + "Directory to write GraphDef data to."), + }; + + string usage = tensorflow::Flags::Usage(argv[0], flag_list); + bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_ok || output_dir.empty()) { + LOG(QFATAL) << usage; + } + tensorflow::port::InitMain(argv[0], &argc, &argv); + + tensorflow::gtl::ArraySlice args(argv, argc); + args.pop_front(); // Pop off the binary name, argv[0] + xla::tools::RealMain(args, output_dir); + return 0; +} diff --git a/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.cc new file mode 100644 index 00000000000..fe835a20c4b --- /dev/null +++ b/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.cc @@ -0,0 +1,204 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +LIcensed under the Apache License, Version 2.0 (the "License"); +You may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" + +using ::tensorflow::GraphDef; +using ::tensorflow::NodeDef; +using ::tensorflow::TensorShapeProto; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; +using ::tensorflow::str_util::Join; + +namespace xla { +namespace tools { +namespace { + +string GetOpDefName(const HloInstruction* instruction) { + string name = StrCat("hlo-", HloOpcodeString(instruction->opcode())); + tensorflow::str_util::TitlecaseString(&name, "-"); + name.erase(std::remove(name.begin(), name.end(), '-'), name.end()); + + if (instruction->opcode() == HloOpcode::kFusion) { + string fusion_name = ToString(instruction->fusion_kind()); + StrAppend(&name, tensorflow::StringPiece(fusion_name).substr(1)); + } + return name; +} + +TensorShapeProto GetTensorShape(const HloInstruction* instruction) { + TensorShapeProto tensor_shape; + const Shape& shape = instruction->shape(); + for (auto dim : shape.dimensions()) { + tensor_shape.add_dim()->set_size(dim); + } + return tensor_shape; +} + +} // namespace + +void CleanNodeName(string* name) { + name->erase(std::remove(name->begin(), name->end(), '%'), name->end()); + const string chars_to_replace = "<>[]"; + auto pred = [&](char c) { + return std::find(chars_to_replace.begin(), chars_to_replace.end(), c) != + chars_to_replace.end(); + }; + std::replace_if(name->begin(), name->end(), pred, '_'); +} + +Status HloTfGraphBuilder::AddComputation(const HloComputation& computation) { + LOG(INFO) << "Adding computation " << computation.name(); + for (auto embedded : computation.MakeEmbeddedComputationsList()) { + LOG(INFO) << "Adding embedded computation " << embedded->name(); + for (auto& instruction : embedded->instructions()) { + TF_RETURN_IF_ERROR(AddInstruction(instruction.get())); + } + } + for (auto& instruction : computation.instructions()) { + TF_RETURN_IF_ERROR(AddInstruction(instruction.get())); + } + return Status::OK(); +} + +const GraphDef& HloTfGraphBuilder::GetGraphDef() const { return graph_def_; } + +const string& HloTfGraphBuilder::GetNodeNameForInstruction( + const HloInstruction* instruction) { + if (ContainsKey(instruction_to_node_name_, instruction)) { + return instruction_to_node_name_[instruction]; + } + // If an instruction is fused, put it in the subgraph of the fusion; + // otherwise, put it in the computation subgraph. + string node_name = + instruction->IsFused() + ? GetNodeNameForInstruction(instruction->fusion_instruction()) + : instruction->parent()->name(); + string instruction_name = instruction->name(); + if (instruction->opcode() == HloOpcode::kParameter) { + StrAppend(&instruction_name, ".", instruction->parameter_number()); + } + StrAppend(&node_name, "/", instruction_name); + CleanNodeName(&node_name); + auto ret = + instruction_to_node_name_.insert(std::make_pair(instruction, node_name)); + CHECK(ret.second); + return ret.first->second; +} + +void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction, + NodeDef* node_def) const { + auto& attrs = *node_def->mutable_attr(); + + // Set the number of arguments for instructions that have variadic operands. + if (HloOpcodeIsVariadic(instruction->opcode())) { + tensorflow::AttrValue attr_value; + attr_value.set_i(instruction->operands().size()); + attrs["arg_num"] = attr_value; + } + + // Set the node type. + attrs["type"].set_s( + xla::PrimitiveType_Name(instruction->shape().element_type())); + + // Set the shape of the output tensor. "_output_shapes" is a special attribute + // name used by Tensorboard for shapes of output tensors. + tensorflow::AttrValue shapes; + *shapes.mutable_list()->add_shape() = GetTensorShape(instruction); + attrs["_output_shapes"] = shapes; + + // Set the layout. + if (LayoutUtil::HasLayout(instruction->shape())) { + string layout_string; + if (ShapeUtil::IsTuple(instruction->shape())) { + // For tuples, emit the full shape because the layout of a tuple is not + // represented in a single Layout field. + layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape()); + } else { + layout_string = StrCat( + "{", Join(instruction->shape().layout().minor_to_major(), ","), "}"); + } + attrs["layout"].set_s(layout_string); + } + + // Set op-specific attributes. + switch (instruction->opcode()) { + case HloOpcode::kConcatenate: + case HloOpcode::kBroadcast: + case HloOpcode::kReduce: + case HloOpcode::kReverse: + case HloOpcode::kTranspose: + for (auto dim : instruction->dimensions()) { + attrs["dims"].mutable_list()->add_i(dim); + } + break; + case HloOpcode::kGetTupleElement: + attrs["index"].set_i(instruction->tuple_index()); + break; + case HloOpcode::kRng: + attrs["dist"].set_s( + RandomDistribution_Name(instruction->random_distribution())); + break; + case HloOpcode::kConstant: + if (ShapeUtil::IsScalar(instruction->shape())) { + attrs["value"].set_s( + LiteralUtil::GetAsString(instruction->literal(), {})); + } + break; + case HloOpcode::kCustomCall: + attrs["custom_call_target"].set_s(instruction->custom_call_target()); + break; + default: + break; + } +} + +Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) { + if (!visited_instructions_.insert(instruction).second) { + // Skip instructions that have already been added. + return Status::OK(); + } + + NodeDef* node_def = graph_def_.add_node(); + node_def->set_name(GetNodeNameForInstruction(instruction)); + node_def->set_op(GetOpDefName(instruction)); + SetNodeAttrs(instruction, node_def); + if (instruction->opcode() == HloOpcode::kFusion) { + for (auto& fused_instruction : instruction->fused_instructions()) { + TF_RETURN_IF_ERROR(AddInstruction(fused_instruction.get())); + } + } + // Add all edges including control edges. + for (unsigned i = 0; i < instruction->operands().size(); ++i) { + *node_def->add_input() = GetNodeNameForInstruction(instruction->operand(i)); + } + // Called computations are control dependencies. + for (const auto* called_computation : instruction->called_computations()) { + *node_def->add_input() = StrCat( + "^", GetNodeNameForInstruction(called_computation->root_instruction())); + } + return Status::OK(); +} + +} // namespace tools +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h b/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h new file mode 100644 index 00000000000..3052eae113c --- /dev/null +++ b/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h @@ -0,0 +1,59 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TOOLS_HLO_TFGRAPH_BUILDER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TOOLS_HLO_TFGRAPH_BUILDER_H_ + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/graph/graph.h" + +namespace xla { +namespace tools { + +// This constructs a tensorflow graph for HLO computations. +class HloTfGraphBuilder { + public: + // Adds a computation to the graph. + Status AddComputation(const HloComputation& computation); + + const tensorflow::GraphDef& GetGraphDef() const; + + private: + // Gets the node name of an instruction. The node name is hierarchical. For + // example, if an instruction is fused, it will be put in a subgraph of the + // fusion instruction. + const string& GetNodeNameForInstruction(const HloInstruction* instruction); + + void SetNodeAttrs(const HloInstruction* instruction, + tensorflow::NodeDef* node_def) const; + + Status AddInstruction(const HloInstruction* instruction); + + tensorflow::GraphDef graph_def_; + // This records instructions that have been visited. + std::unordered_set visited_instructions_; + // A cache that maps instruction to the node name. + std::unordered_map instruction_to_node_name_; +}; + +// Cleans the node name to make it a valid name in a tensorflow graph. +void CleanNodeName(string* name); + +} // namespace tools +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TOOLS_HLO_TFGRAPH_BUILDER_H_ diff --git a/tensorflow/compiler/xla/tools/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/tools/hlo_tfgraph_builder_test.cc new file mode 100644 index 00000000000..626bcc6d856 --- /dev/null +++ b/tensorflow/compiler/xla/tools/hlo_tfgraph_builder_test.cc @@ -0,0 +1,154 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace tools { +namespace { + +using ::tensorflow::GraphDef; + +class HloTfGraphBuilderTest : public HloTestBase { + protected: + HloTfGraphBuilderTest() {} + HloTfGraphBuilder generator_; + + // Create a computation which takes a scalar and returns its negation. + std::unique_ptr CreateNegateComputation() { + auto builder = HloComputation::Builder("Negate"); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32_, "param0")); + builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param)); + return builder.Build(); + } + + // Creates a computation which calls map with the given computation. + std::unique_ptr CreateMapComputation( + HloComputation* map_computation) { + auto builder = HloComputation::Builder("Map"); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32_, "param0")); + builder.AddInstruction( + HloInstruction::CreateMap(r0f32_, {param}, map_computation)); + return builder.Build(); + } + Shape r0f32_ = ShapeUtil::MakeShape(F32, {}); +}; + +TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) { + auto builder = HloComputation::Builder("Concatenate"); + Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); + auto param_1 = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param0")); + auto param_2 = builder.AddInstruction( + HloInstruction::CreateParameter(1, shape, "param1")); + builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(F32, {2, 4}), {param_1, param_2}, 1)); + TF_CHECK_OK(generator_.AddComputation(*builder.Build())); + GraphDef graph_def = generator_.GetGraphDef(); + EXPECT_EQ(graph_def.node_size(), 3); + const auto &node = graph_def.node(2); + EXPECT_EQ(node.name(), "Concatenate/concatenate"); + + // Check dimensions. + auto dims_value = node.attr().find("dims"); + CHECK(dims_value != node.attr().end()); + EXPECT_EQ(dims_value->second.list().i_size(), 1); + EXPECT_EQ(dims_value->second.list().i(0), 1); + + // Check shapes. + auto shape_value = node.attr().find("_output_shapes"); + CHECK(shape_value != node.attr().end()); + EXPECT_EQ(shape_value->second.list().shape_size(), 1); + EXPECT_EQ(shape_value->second.list().shape(0).dim_size(), 2); + EXPECT_EQ(shape_value->second.list().shape(0).dim(0).size(), 2); + EXPECT_EQ(shape_value->second.list().shape(0).dim(1).size(), 4); +} + +TEST_F(HloTfGraphBuilderTest, CheckScalarValue) { + auto builder = HloComputation::Builder("Const"); + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123))); + TF_CHECK_OK(generator_.AddComputation(*builder.Build())); + GraphDef graph_def = generator_.GetGraphDef(); + EXPECT_EQ(graph_def.node_size(), 1); + const auto &node = graph_def.node(0); + auto value = node.attr().find("value"); + CHECK(value != node.attr().end()); + EXPECT_EQ(value->second.s(), "123"); + + auto type = node.attr().find("type"); + CHECK(type != node.attr().end()); + EXPECT_EQ(type->second.s(), "S32"); +} + +TEST_F(HloTfGraphBuilderTest, SimpleNegateComputation) { + auto negate_computation = CreateNegateComputation(); + TF_CHECK_OK(generator_.AddComputation(*negate_computation)); + GraphDef graph_def = generator_.GetGraphDef(); + EXPECT_EQ(graph_def.node_size(), 2); + EXPECT_EQ(graph_def.node(0).name(), "Negate/param0.0"); + EXPECT_EQ(graph_def.node(0).op(), "HloParameter"); + EXPECT_EQ(graph_def.node(1).name(), "Negate/negate"); + EXPECT_EQ(graph_def.node(1).op(), "HloNegate"); + EXPECT_EQ(graph_def.node(1).input_size(), 1); + EXPECT_EQ(graph_def.node(1).input(0), "Negate/param0.0"); +} + +TEST_F(HloTfGraphBuilderTest, GreaterThanOrEqualTo) { + auto builder = HloComputation::Builder("GE"); + auto param_1 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32_, "param0")); + auto param_2 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32_, "param1")); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kGe, param_1, param_2)); + TF_CHECK_OK(generator_.AddComputation(*builder.Build())); + GraphDef graph_def = generator_.GetGraphDef(); + EXPECT_EQ(graph_def.node_size(), 3); + EXPECT_EQ(graph_def.node(0).name(), "GE/param0.0"); + EXPECT_EQ(graph_def.node(1).name(), "GE/param1.1"); + EXPECT_EQ(graph_def.node(2).input_size(), 2); + EXPECT_EQ(graph_def.node(2).name(), "GE/greater-than-or-equal-to"); + EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo"); +} + +TEST_F(HloTfGraphBuilderTest, EmbeddedComputationsDiamond) { + // Create computations with a diamond-shaped callgraph. + auto negate_computation = CreateNegateComputation(); + auto map1_computation = CreateMapComputation(negate_computation.get()); + auto map2_computation = CreateMapComputation(negate_computation.get()); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32_, "param0")); + auto map1 = builder.AddInstruction( + HloInstruction::CreateMap(r0f32_, {param}, map1_computation.get())); + auto map2 = builder.AddInstruction( + HloInstruction::CreateMap(r0f32_, {param}, map2_computation.get())); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, map1, map2)); + auto computation = builder.Build(); + TF_CHECK_OK(generator_.AddComputation(*computation)); + EXPECT_GT(generator_.GetGraphDef().node_size(), 0); +} + +} // namespace +} // namespace tools +} // namespace xla diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py index 81e40dbe5ec..c7f185aab82 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py @@ -42,12 +42,10 @@ class StochasticTensorTest(test.TestCase): sigma2 = constant_op.constant([0.1, 0.2, 0.3]) prior_default = st.StochasticTensor( - distributions.Normal( - loc=mu, scale=sigma)) + distributions.Normal(loc=mu, scale=sigma)) self.assertTrue(isinstance(prior_default.value_type, st.SampleValue)) prior_0 = st.StochasticTensor( - distributions.Normal( - loc=mu, scale=sigma), + distributions.Normal(loc=mu, scale=sigma), dist_value_type=st.SampleValue()) self.assertTrue(isinstance(prior_0.value_type, st.SampleValue)) @@ -55,8 +53,7 @@ class StochasticTensorTest(test.TestCase): prior = st.StochasticTensor(distributions.Normal(loc=mu, scale=sigma)) self.assertTrue(isinstance(prior.value_type, st.SampleValue)) likelihood = st.StochasticTensor( - distributions.Normal( - loc=prior, scale=sigma2)) + distributions.Normal(loc=prior, scale=sigma2)) self.assertTrue(isinstance(likelihood.value_type, st.SampleValue)) coll = ops.get_collection(st.STOCHASTIC_TENSOR_COLLECTION) @@ -102,8 +99,7 @@ class StochasticTensorTest(test.TestCase): with st.value_type(st.SampleValue()): prior_single = st.StochasticTensor( - distributions.Normal( - loc=mu, scale=sigma)) + distributions.Normal(loc=mu, scale=sigma)) prior_single_value = prior_single.value() self.assertEqual(prior_single_value.get_shape(), (2, 3)) @@ -113,8 +109,7 @@ class StochasticTensorTest(test.TestCase): with st.value_type(st.SampleValue(1)): prior_single = st.StochasticTensor( - distributions.Normal( - loc=mu, scale=sigma)) + distributions.Normal(loc=mu, scale=sigma)) self.assertTrue(isinstance(prior_single.value_type, st.SampleValue)) prior_single_value = prior_single.value() @@ -125,8 +120,7 @@ class StochasticTensorTest(test.TestCase): with st.value_type(st.SampleValue(2)): prior_double = st.StochasticTensor( - distributions.Normal( - loc=mu, scale=sigma)) + distributions.Normal(loc=mu, scale=sigma)) prior_double_value = prior_double.value() self.assertEqual(prior_double_value.get_shape(), (2, 2, 3)) @@ -163,8 +157,7 @@ class StochasticTensorTest(test.TestCase): # With passed-in loss_fn. dt = st.StochasticTensor( - distributions.Normal( - loc=mu, scale=sigma), + distributions.Normal(loc=mu, scale=sigma), dist_value_type=st.MeanValue(stop_gradient=True), loss_fn=sge.get_score_function_with_constant_baseline( baseline=constant_op.constant(8.0))) @@ -199,8 +192,7 @@ class ObservedStochasticTensorTest(test.TestCase): sigma = constant_op.constant([1.1, 1.2, 1.3]) obs = array_ops.zeros((2, 3)) z = st.ObservedStochasticTensor( - distributions.Normal( - loc=mu, scale=sigma), value=obs) + distributions.Normal(loc=mu, scale=sigma), value=obs) [obs_val, z_val] = sess.run([obs, z.value()]) self.assertAllEqual(obs_val, z_val) @@ -212,15 +204,13 @@ class ObservedStochasticTensorTest(test.TestCase): sigma = array_ops.placeholder(dtypes.float32) obs = array_ops.placeholder(dtypes.float32) z = st.ObservedStochasticTensor( - distributions.Normal( - loc=mu, scale=sigma), value=obs) + distributions.Normal(loc=mu, scale=sigma), value=obs) mu2 = array_ops.placeholder(dtypes.float32, shape=[None]) sigma2 = array_ops.placeholder(dtypes.float32, shape=[None]) obs2 = array_ops.placeholder(dtypes.float32, shape=[None, None]) z2 = st.ObservedStochasticTensor( - distributions.Normal( - loc=mu2, scale=sigma2), value=obs2) + distributions.Normal(loc=mu2, scale=sigma2), value=obs2) coll = ops.get_collection(st.STOCHASTIC_TENSOR_COLLECTION) self.assertEqual(coll, [z, z2]) @@ -231,22 +221,18 @@ class ObservedStochasticTensorTest(test.TestCase): self.assertRaises( ValueError, st.ObservedStochasticTensor, - distributions.Normal( - loc=mu, scale=sigma), + distributions.Normal(loc=mu, scale=sigma), value=array_ops.zeros((3,))) self.assertRaises( ValueError, st.ObservedStochasticTensor, - distributions.Normal( - loc=mu, scale=sigma), + distributions.Normal(loc=mu, scale=sigma), value=array_ops.zeros((3, 1))) self.assertRaises( ValueError, st.ObservedStochasticTensor, - distributions.Normal( - loc=mu, scale=sigma), - value=array_ops.zeros( - (1, 2), dtype=dtypes.int32)) + distributions.Normal(loc=mu, scale=sigma), + value=array_ops.zeros((1, 2), dtype=dtypes.int32)) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 470b9edb793..e17197080ab 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -135,8 +135,9 @@ from tensorflow.contrib.distributions.python.ops.wishart import * from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['ConditionalDistribution', - 'ConditionalTransformedDistribution', - 'FULLY_REPARAMETERIZED', 'NOT_REPARAMETERIZED'] +_allowed_symbols = [ + 'ConditionalDistribution', 'ConditionalTransformedDistribution', + 'FULLY_REPARAMETERIZED', 'NOT_REPARAMETERIZED' +] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py index 71460a17695..5d6e4d9197d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py @@ -488,9 +488,7 @@ class AffineBijectorTest(test.TestCase): shift=mu, scale_identity_multiplier=2., scale_perturb_diag=[2., 1], - scale_perturb_factor=[[2., 0], - [0., 0], - [0, 1]]) + scale_perturb_factor=[[2., 0], [0., 0], [0, 1]]) bijector_ref = affine_lib.Affine(shift=mu, scale_diag=[10., 2, 3]) self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" @@ -526,9 +524,7 @@ class AffineBijectorTest(test.TestCase): shift=mu, scale_diag=[2., 3, 4], scale_perturb_diag=[2., 1], - scale_perturb_factor=[[2., 0], - [0., 0], - [0, 1]]) + scale_perturb_factor=[[2., 0], [0., 0], [0, 1]]) bijector_ref = affine_lib.Affine(shift=mu, scale_diag=[10., 3, 5]) self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" @@ -561,17 +557,11 @@ class AffineBijectorTest(test.TestCase): # Corresponds to scale = [[10, 0, 0], [1, 3, 0], [2, 3, 5]] bijector = affine_lib.Affine( shift=mu, - scale_tril=[[2., 0, 0], - [1, 3, 0], - [2, 3, 4]], + scale_tril=[[2., 0, 0], [1, 3, 0], [2, 3, 4]], scale_perturb_diag=[2., 1], - scale_perturb_factor=[[2., 0], - [0., 0], - [0, 1]]) + scale_perturb_factor=[[2., 0], [0., 0], [0, 1]]) bijector_ref = affine_lib.Affine( - shift=mu, scale_tril=[[10., 0, 0], - [1, 3, 0], - [2, 3, 5]]) + shift=mu, scale_tril=[[10., 0, 0], [1, 3, 0], [2, 3, 5]]) self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [1., 2, 3] # Vector. diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py index ecf068bf6b5..cb514e625ba 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py @@ -70,7 +70,8 @@ class ChainBijectorTest(test.TestCase): softmax_centered_lib.SoftmaxCentered( event_ndims=1, validate_args=True), softmax_centered_lib.SoftmaxCentered( - event_ndims=0, validate_args=True)]) + event_ndims=0, validate_args=True) + ]) x = tensor_shape.TensorShape([]) y = tensor_shape.TensorShape([2 + 1]) self.assertAllEqual(y, bijector.forward_event_shape(x)) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py index e16f9dff225..40018de63f5 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py @@ -36,17 +36,19 @@ class SigmoidBijectorTest(test.TestCase): y = special.expit(x) ildj = -np.log(y) - np.log1p(-y) self.assertAllClose( - y, sigmoid.Sigmoid().forward(x).eval(), - atol=0., rtol=1e-2) + y, sigmoid.Sigmoid().forward(x).eval(), atol=0., rtol=1e-2) self.assertAllClose( - x, sigmoid.Sigmoid().inverse(y).eval(), - atol=0., rtol=1e-4) + x, sigmoid.Sigmoid().inverse(y).eval(), atol=0., rtol=1e-4) self.assertAllClose( - ildj, sigmoid.Sigmoid().inverse_log_det_jacobian(y).eval(), - atol=0., rtol=1e-6) + ildj, + sigmoid.Sigmoid().inverse_log_det_jacobian(y).eval(), + atol=0., + rtol=1e-6) self.assertAllClose( - -ildj, sigmoid.Sigmoid().forward_log_det_jacobian(x).eval(), - atol=0., rtol=1e-4) + -ildj, + sigmoid.Sigmoid().forward_log_det_jacobian(x).eval(), + atol=0., + rtol=1e-4) def testScalarCongruency(self): with self.test_session(): diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py index 7fee2e1f3a1..e3f6ddd8c04 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py @@ -171,11 +171,12 @@ class RelaxedBernoulli(transformed_distribution.TransformedDistribution): self._logits, self._probs = distribution_util.get_logits_and_probs( logits=logits, probs=probs, validate_args=validate_args) super(RelaxedBernoulli, self).__init__( - distribution=logistic.Logistic(self._logits / self._temperature, - 1. / self._temperature, - validate_args=validate_args, - allow_nan_stats=allow_nan_stats, - name=name + "/Logistic"), + distribution=logistic.Logistic( + self._logits / self._temperature, + 1. / self._temperature, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + name=name + "/Logistic"), bijector=sigmoid_lib.Sigmoid(validate_args=validate_args), validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/keras/python/keras/backend.py b/tensorflow/contrib/keras/python/keras/backend.py index d7c646c19a7..d1491387962 100644 --- a/tensorflow/contrib/keras/python/keras/backend.py +++ b/tensorflow/contrib/keras/python/keras/backend.py @@ -3614,7 +3614,7 @@ _config_path = os.path.expanduser(os.path.join(_keras_dir, 'keras.json')) if os.path.exists(_config_path): try: _config = json.load(open(_config_path)) - except json.decoder.JSONDecodeError: + except ValueError: _config = {} _floatx = _config.get('floatx', floatx()) assert _floatx in {'float16', 'float32', 'float64'} diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 0140f6d0d3e..13cabe6e043 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -379,7 +379,10 @@ def batch_norm(inputs, fused=False, data_format=DATA_FORMAT_NHWC, zero_debias_moving_mean=False, - scope=None): + scope=None, + renorm=False, + renorm_clipping=None, + renorm_decay=0.99): """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167. "Batch Normalization: Accelerating Deep Network Training by Reducing @@ -446,6 +449,19 @@ def batch_norm(inputs, zero_debias_moving_mean: Use zero_debias for moving_mean. It creates a new pair of variables 'moving_mean/biased' and 'moving_mean/local_step'. scope: Optional scope for `variable_scope`. + renorm: Whether to use Batch Renormalization + (https://arxiv.org/abs/1702.03275). This adds extra variables during + training. The inference is the same for either value of this parameter. + renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to + scalar `Tensors` used to clip the renorm correction. The correction + `(r, d)` is used as `corrected_value = normalized_value * r + d`, with + `r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin, + dmax are set to inf, 0, inf, respectively. + renorm_decay: Momentum used to update the moving means and standard + deviations with renorm. Unlike `momentum`, this affects training + and should be neither too small (which would add noise) nor too large + (which would give stale estimates). Note that `decay` is still applied + to get the means and variances for inference. Returns: A `Tensor` representing the output of the operation. @@ -464,6 +480,8 @@ def batch_norm(inputs, if param_regularizers is not None: raise ValueError('Regularizers are not currently ' 'supported for fused batch norm.') + if renorm: + raise ValueError('Renorm is not supported for fused batch norm.') return _fused_batch_norm( inputs, decay=decay, @@ -524,6 +542,9 @@ def batch_norm(inputs, beta_regularizer=beta_regularizer, gamma_regularizer=gamma_regularizer, trainable=trainable, + renorm=renorm, + renorm_clipping=renorm_clipping, + renorm_momentum=renorm_decay, name=sc.name, _scope=sc, _reuse=reuse) @@ -551,6 +572,9 @@ def batch_norm(inputs, # Custom updates collections are not supported because the update logic # is different in this case, in particular w.r.t. "forced updates" and # update op reuse. + if renorm: + raise ValueError('renorm is not supported with batch_weights, ' + 'updates_collections or zero_debias_moving_mean') inputs_shape = inputs.get_shape() inputs_rank = inputs_shape.ndims if inputs_rank is None: @@ -1241,6 +1265,13 @@ def flatten(inputs, def _sparse_inner_flatten(inputs, new_rank): """Helper function for `inner_flatten`.""" + inputs_rank = inputs.dense_shape.get_shape().as_list()[0] + if inputs_rank < new_rank: + raise ValueError( + 'Inputs has rank less than new_rank. {} must have rank at least' + ' {}. Received rank {}, shape {}'.format(inputs, new_rank, inputs_rank, + inputs.get_shape())) + outer_dimensions = inputs.dense_shape[:new_rank - 1] inner_dimensions = inputs.dense_shape[new_rank - 1:] new_shape = array_ops.concat((outer_dimensions, diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 2b170e92ba1..ee4ebf2c435 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1465,6 +1465,30 @@ class PartialFlattenTest(test.TestCase): flattened5 = _layers._inner_flatten(inputs, 5) self.assertEqual([2, None, 4, None, 30], flattened5.get_shape().as_list()) + def testDenseFlattenRankAssertion(self): + """Test `_inner_flatten` rank assertion for dense tensors.""" + shape = [2, 3] + new_rank = 3 + inputs = array_ops.placeholder(dtypes.int32) + inputs.set_shape(shape) + + with self.assertRaisesRegexp(ValueError, + 'inputs has rank less than new_rank'): + _layers._inner_flatten(inputs, new_rank) + + def testSparseFlattenRankAssertion(self): + """Test `_inner_flatten` rank assertion for sparse tensors.""" + shape = [2, 3] + new_rank = 3 + np.random.seed(10301) + random_ = np.random.rand(*shape) + indices, values, _ = _sparsify(random_) + inputs = sparse_tensor.SparseTensor(indices, values, shape) + + with self.assertRaisesRegexp(ValueError, + 'Inputs has rank less than new_rank'): + _layers._inner_flatten(inputs, new_rank) + class FCTest(test.TestCase): diff --git a/tensorflow/contrib/learn/python/learn/datasets/mnist.py b/tensorflow/contrib/learn/python/learn/datasets/mnist.py index 01262ff5f81..fd50070dac5 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/mnist.py +++ b/tensorflow/contrib/learn/python/learn/datasets/mnist.py @@ -27,7 +27,8 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.python.framework import dtypes -SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/' +# CVDF mirror of http://yann.lecun.com/exdb/mnist/ +SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/' def _read32(bytestream): diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 107454dca1a..29ea692f8fb 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -362,6 +362,11 @@ class BaseEstimator( self._config = config logging.info('Using config: %s', str(vars(self._config))) + if self._config.session_config is None: + self._session_config = config_pb2.ConfigProto(allow_soft_placement=True) + else: + self._session_config = self._config.session_config + # Model directory. if (model_dir is not None) and (self._config.model_dir is not None): if model_dir != self._config.model_dir: @@ -829,7 +834,7 @@ class BaseEstimator( eval_ops=update_op, final_ops=eval_dict, hooks=hooks, - config=config_pb2.ConfigProto(allow_soft_placement=True)) + config=self._session_config) current_global_step = eval_results[global_step_key] _write_dict_to_summary(eval_dir, eval_results, current_global_step) @@ -864,7 +869,7 @@ class BaseEstimator( session_creator=monitored_session.ChiefSessionCreator( checkpoint_filename_with_path=checkpoint_path, scaffold=infer_ops.scaffold, - config=config_pb2.ConfigProto(allow_soft_placement=True))) + config=self._session_config)) if not as_iterable: with mon_sess: if not mon_sess.should_stop(): @@ -976,7 +981,7 @@ class BaseEstimator( chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks, save_checkpoint_secs=0, # Saving is handled by a hook. save_summaries_steps=self._config.save_summary_steps, - config=config_pb2.ConfigProto(allow_soft_placement=True) + config=self._session_config ) as mon_sess: loss = None while not mon_sess.should_stop(): diff --git a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py index c0a39185493..c56741a4d13 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py @@ -53,12 +53,17 @@ class ModeKeys(object): EVAL = 'eval' INFER = 'infer' + @classmethod + def validate(cls, key): + if key not in (cls.TRAIN, cls.EVAL, cls.INFER): + raise ValueError('Invalid mode %s.' % key) + class ModelFnOps( collections.namedtuple('ModelFnOps', [ 'predictions', 'loss', 'train_op', 'eval_metric_ops', 'output_alternatives', 'training_chief_hooks', 'training_hooks', - 'scaffold' + 'scaffold', 'mode' ])): """Ops returned from a model_fn.""" @@ -119,6 +124,8 @@ class ModelFnOps( Raises: ValueError: If validation fails. """ + ModeKeys.validate(mode) + # Assert all ops are from the same graph. get_graph_from_inputs((predictions, loss, train_op)) @@ -183,14 +190,13 @@ class ModelFnOps( output_alternatives=output_alternatives, training_chief_hooks=training_chief_hooks, training_hooks=training_hooks, - scaffold=scaffold) + scaffold=scaffold, + mode=mode) - def estimator_spec(self, mode, default_serving_output_alternative_key=None): + def estimator_spec(self, default_serving_output_alternative_key=None): """Creates an equivalent `EstimatorSpec`. Args: - mode: One of `ModeKeys`. Specifies if this training, evaluation or - prediction. default_serving_output_alternative_key: Required for multiple heads. If you have multiple entries in `output_alternatives` dict (comparable to multiple heads), `EstimatorSpec` requires a default head that will be @@ -265,7 +271,7 @@ class ModelFnOps( return result return core_model_fn_lib.EstimatorSpec( - mode=mode, + mode=self.mode, predictions=self.predictions, loss=self.loss, train_op=self.train_op, diff --git a/tensorflow/contrib/learn/python/learn/estimators/model_fn_test.py b/tensorflow/contrib/learn/python/learn/estimators/model_fn_test.py index 51b32359a33..4f76013a2a5 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/model_fn_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/model_fn_test.py @@ -80,18 +80,20 @@ class ModelFnopsTest(test.TestCase): def testEstimatorSpec_except_export(self): predictions = self.create_predictions() - model_fn_ops = self.create_model_fn_ops(predictions, None) + model_fn_ops = self.create_model_fn_ops( + predictions, None, mode=model_fn.ModeKeys.INFER) - estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER) + estimator_spec = model_fn_ops.estimator_spec() self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec) def testEstimatorSpec_export_regression_with_scores(self): predictions = self.create_predictions() output_alternatives = {"regression_head": ( constants.ProblemType.LINEAR_REGRESSION, predictions)} - model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives) + model_fn_ops = self.create_model_fn_ops( + predictions, output_alternatives, mode=model_fn.ModeKeys.INFER) - estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER) + estimator_spec = model_fn_ops.estimator_spec() self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec) with session.Session(): @@ -108,9 +110,10 @@ class ModelFnopsTest(test.TestCase): output_alternatives = {"regression_head": ( constants.ProblemType.LINEAR_REGRESSION, output_alternatives_predictions)} - model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives) + model_fn_ops = self.create_model_fn_ops( + predictions, output_alternatives, mode=model_fn.ModeKeys.INFER) - estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER) + estimator_spec = model_fn_ops.estimator_spec() self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec) with session.Session(): @@ -124,9 +127,10 @@ class ModelFnopsTest(test.TestCase): predictions = self.create_predictions() output_alternatives = {"classification_head": ( constants.ProblemType.CLASSIFICATION, predictions)} - model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives) + model_fn_ops = self.create_model_fn_ops( + predictions, output_alternatives, mode=model_fn.ModeKeys.INFER) - estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER) + estimator_spec = model_fn_ops.estimator_spec() self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec) with session.Session(): @@ -145,9 +149,10 @@ class ModelFnopsTest(test.TestCase): del output_alternatives_predictions["scores"] output_alternatives = {"classification_head": ( constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)} - model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives) + model_fn_ops = self.create_model_fn_ops( + predictions, output_alternatives, mode=model_fn.ModeKeys.INFER) - estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER) + estimator_spec = model_fn_ops.estimator_spec() self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec) with session.Session(): @@ -167,9 +172,10 @@ class ModelFnopsTest(test.TestCase): del output_alternatives_predictions["probabilities"] output_alternatives = {"classification_head": ( constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)} - model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives) + model_fn_ops = self.create_model_fn_ops( + predictions, output_alternatives, mode=model_fn.ModeKeys.INFER) - estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER) + estimator_spec = model_fn_ops.estimator_spec() self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec) with session.Session(): @@ -187,9 +193,10 @@ class ModelFnopsTest(test.TestCase): del output_alternatives_predictions["classes"] output_alternatives = {"classification_head": ( constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)} - model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives) + model_fn_ops = self.create_model_fn_ops( + predictions, output_alternatives, mode=model_fn.ModeKeys.INFER) - estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER) + estimator_spec = model_fn_ops.estimator_spec() self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec) with session.Session(): @@ -208,9 +215,10 @@ class ModelFnopsTest(test.TestCase): [1, 2, 3]) output_alternatives = {"classification_head": ( constants.ProblemType.CLASSIFICATION, output_alternatives_predictions)} - model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives) + model_fn_ops = self.create_model_fn_ops( + predictions, output_alternatives, mode=model_fn.ModeKeys.INFER) - estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER) + estimator_spec = model_fn_ops.estimator_spec() self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec) with session.Session(): @@ -226,9 +234,10 @@ class ModelFnopsTest(test.TestCase): predictions = self.create_predictions() output_alternatives = {"logistic_head": ( constants.ProblemType.LOGISTIC_REGRESSION, predictions)} - model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives) + model_fn_ops = self.create_model_fn_ops( + predictions, output_alternatives, mode=model_fn.ModeKeys.INFER) - estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER) + estimator_spec = model_fn_ops.estimator_spec() self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec) with session.Session(): @@ -245,9 +254,10 @@ class ModelFnopsTest(test.TestCase): output_alternatives = {"unspecified_head": ( constants.ProblemType.UNSPECIFIED, predictions)} - model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives) + model_fn_ops = self.create_model_fn_ops( + predictions, output_alternatives, mode=model_fn.ModeKeys.INFER) - estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER) + estimator_spec = model_fn_ops.estimator_spec() self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec) with session.Session(): @@ -263,10 +273,10 @@ class ModelFnopsTest(test.TestCase): constants.ProblemType.LINEAR_REGRESSION, predictions), "classification_head": ( constants.ProblemType.CLASSIFICATION, predictions)} - model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives) + model_fn_ops = self.create_model_fn_ops( + predictions, output_alternatives, mode=model_fn.ModeKeys.INFER) - estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER, - "regression_head") + estimator_spec = model_fn_ops.estimator_spec("regression_head") self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec) with session.Session(): diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config.py b/tensorflow/contrib/learn/python/learn/estimators/run_config.py index bc7465bbc22..37ee814b620 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/run_config.py +++ b/tensorflow/contrib/learn/python/learn/estimators/run_config.py @@ -214,7 +214,8 @@ class RunConfig(ClusterConfig): keep_checkpoint_max=5, keep_checkpoint_every_n_hours=10000, evaluation_master='', - model_dir=None): + model_dir=None, + session_config=None): """Constructor. Note that the superclass `ClusterConfig` may set properties like @@ -246,6 +247,9 @@ class RunConfig(ClusterConfig): evaluation_master: the master on which to perform evaluation. model_dir: directory where model parameters, graph etc are saved. If `None`, see `Estimator` about where the model will be saved. + session_config: a ConfigProto used to set session parameters, or None. + Note - using this argument, it is easy to provide settings which break + otherwise perfectly good models. Use with care. """ super(RunConfig, self).__init__( master=master, evaluation_master=evaluation_master) @@ -261,6 +265,7 @@ class RunConfig(ClusterConfig): self._tf_random_seed = tf_random_seed self._save_summary_steps = save_summary_steps self._save_checkpoints_secs = save_checkpoints_secs + self._session_config = session_config if save_checkpoints_secs == RunConfig._USE_DEFAULT: if save_checkpoints_steps is None: self._save_checkpoints_secs = 600 @@ -345,6 +350,10 @@ class RunConfig(ClusterConfig): def save_checkpoints_steps(self): return self._save_checkpoints_steps + @property + def session_config(self): + return self._session_config + @property def keep_checkpoint_max(self): return self._keep_checkpoint_max diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index cecc24c17d8..4f7c72c9dda 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -118,7 +118,8 @@ class Experiment(object): occur if no new snapshot is available, hence, this is the minimum. delay_workers_by_global_step: if `True` delays training workers based on global step instead of time. - export_strategies: A list of `ExportStrategy`s, or a single one, or None. + export_strategies: Iterable of `ExportStrategy`s, or a single one, or + `None`. train_steps_per_iteration: (applies only to continuous_train_and_eval). Perform this many (integer) number of train steps for each training-evaluation iteration. With a small value, the model will be @@ -184,16 +185,19 @@ class Experiment(object): def eval_steps(self): return self._eval_steps - def _set_export_strategies(self, value): - if value is None: - self._export_strategies = [] - elif isinstance(value, list): - self._export_strategies = value[:] - elif isinstance(value, export_strategy.ExportStrategy): - self._export_strategies = [value] - else: - raise ValueError("`export_strategies` must be an ExportStrategy, " - "a list of ExportStrategies, or None.") + def _set_export_strategies(self, values): # pylint: disable=missing-docstring + export_strategies = [] + if values: + if isinstance(values, export_strategy.ExportStrategy): + export_strategies.append(values) + else: + for value in values: + if not isinstance(value, export_strategy.ExportStrategy): + raise ValueError("`export_strategies` must be an ExportStrategy," + " an iterable of ExportStrategy, or `None`," + " found %s." % value) + export_strategies.append(value) + self._export_strategies = tuple(export_strategies) def extend_train_hooks(self, additional_hooks): """Extends the hooks for training.""" diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py index 00ed062b0aa..f9f95f8e675 100644 --- a/tensorflow/contrib/learn/python/learn/experiment_test.py +++ b/tensorflow/contrib/learn/python/learn/experiment_test.py @@ -484,6 +484,25 @@ class ExperimentTest(test.TestCase): self.assertAllEqual([noop_hook, another_noop_hook], ex._train_monitors) self.assertAllEqual([noop_hook], input_hooks) + def test_invalid_export_strategies(self): + for est in self._estimators_for_tests(): + with self.assertRaisesRegexp(ValueError, 'ExportStrategy'): + experiment.Experiment( + est, + train_input_fn='train_input', + eval_input_fn='eval_input', + train_steps=100, + eval_steps=100, + export_strategies='not_an_export_strategy') + with self.assertRaisesRegexp(ValueError, 'ExportStrategy'): + experiment.Experiment( + est, + train_input_fn='train_input', + eval_input_fn='eval_input', + train_steps=100, + eval_steps=100, + export_strategies=['not_an_export_srategy']) + def test_export_strategies_reset(self): for est in self._estimators_for_tests(): eval_metrics = 'eval_metrics' if not isinstance( @@ -498,7 +517,7 @@ class ExperimentTest(test.TestCase): eval_metrics=eval_metrics, train_steps=100, eval_steps=100, - export_strategies=[export_strategy_1]) + export_strategies=(export_strategy_1,)) ex.train_and_evaluate() self.assertEqual(1, est.export_count) @@ -728,7 +747,7 @@ class ExperimentTest(test.TestCase): est, train_input_fn='train_input', eval_input_fn='eval_input', - export_strategies=[exp_strategy]) + export_strategies=(exp_strategy,)) ex.test() self.assertEqual(1, est.fit_count) self.assertEqual(1, est.eval_count) diff --git a/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py b/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py index 4a70f00407e..c302c7725a4 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py @@ -131,4 +131,5 @@ def generator_input_fn(x, target = features.pop(target_key[0]) return features, target return features + return _generator_input_fn diff --git a/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py index ae68e35c219..bc767ec18b1 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py @@ -35,7 +35,7 @@ from tensorflow.python.training import queue_runner_impl class GeneratorIoTest(test.TestCase): - + def testGeneratorInputFn(self): def generator(): diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py index 0f317b7bb04..9bdd3206b24 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py @@ -359,8 +359,9 @@ def _read_keyed_batch_examples_helper(file_pattern, # Check input parameters are given and reasonable. if (not queue_capacity) or (queue_capacity <= 0): raise ValueError('Invalid queue_capacity %s.' % queue_capacity) - if (batch_size is None) or ((not isinstance(batch_size, ops.Tensor)) and - (batch_size <= 0 or batch_size > queue_capacity)): + if (batch_size is None) or ( + (not isinstance(batch_size, ops.Tensor)) and + (batch_size <= 0 or batch_size >= queue_capacity)): raise ValueError('Invalid batch_size %s, with queue_capacity %s.' % (batch_size, queue_capacity)) if (read_batch_size is None) or ( diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py index 83643689e1a..542aaabc953 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py @@ -112,6 +112,18 @@ class GraphIOTest(test.TestCase): queue_capacity=queue_capacity, num_threads=num_threads, name=name) + self.assertRaisesRegexp( + ValueError, + "Invalid batch_size", + graph_io.read_batch_examples, + _VALID_FILE_PATTERN, + default_batch_size, + io_ops.TFRecordReader, + False, + num_epochs=None, + queue_capacity=default_batch_size, + num_threads=num_threads, + name=name) self.assertRaisesRegexp( ValueError, "Invalid queue_capacity", @@ -356,7 +368,7 @@ class GraphIOTest(test.TestCase): ] filename = self._create_temp_file("".join(json_lines)) batch_size = 10000 - queue_capacity = 10000 + queue_capacity = 100000 name = "my_large_batch" features = {"sequence": parsing_ops.FixedLenFeature([], dtypes_lib.string)} diff --git a/tensorflow/contrib/learn/python/learn/utils/gc_test.py b/tensorflow/contrib/learn/python/learn/utils/gc_test.py index d3270dcc162..9c63096d0ee 100644 --- a/tensorflow/contrib/learn/python/learn/utils/gc_test.py +++ b/tensorflow/contrib/learn/python/learn/utils/gc_test.py @@ -29,10 +29,6 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import test -def tearDownModule(): - gfile.DeleteRecursively(test.get_temp_dir()) - - class GcTest(test_util.TensorFlowTestCase): def testLargestExportVersions(self): diff --git a/tensorflow/contrib/linalg/BUILD b/tensorflow/contrib/linalg/BUILD index 9b196e2cf50..34a293f80bb 100644 --- a/tensorflow/contrib/linalg/BUILD +++ b/tensorflow/contrib/linalg/BUILD @@ -30,7 +30,7 @@ cuda_py_tests( cuda_py_tests( name = "linear_operator_addition_test", - size = "medium", + size = "small", srcs = ["python/kernel_tests/linear_operator_addition_test.py"], additional_deps = [ ":linalg_py", @@ -43,7 +43,6 @@ cuda_py_tests( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], - shard_count = 5, ) cuda_py_tests( @@ -61,7 +60,6 @@ cuda_py_tests( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], - shard_count = 5, ) cuda_py_tests( @@ -79,7 +77,6 @@ cuda_py_tests( "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", ], - shard_count = 5, ) cuda_py_tests( @@ -96,7 +93,6 @@ cuda_py_tests( "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", ], - shard_count = 5, ) cuda_py_tests( @@ -112,7 +108,6 @@ cuda_py_tests( "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", ], - shard_count = 5, ) cuda_py_tests( @@ -128,7 +123,6 @@ cuda_py_tests( "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", ], - shard_count = 5, ) cuda_py_tests( @@ -144,12 +138,11 @@ cuda_py_tests( "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", ], - shard_count = 5, ) cuda_py_tests( name = "linear_operator_util_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/linear_operator_util_test.py"], additional_deps = [ ":linalg_py", @@ -160,7 +153,6 @@ cuda_py_tests( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], - shard_count = 5, ) py_library( diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py index a06af336e71..f047f4b9787 100644 --- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py +++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py @@ -229,6 +229,29 @@ class MatmulWithBroadcastTest(test.TestCase): self.assertAllEqual(expected, result) +class MatrixAdjointTest(test.TestCase): + + def testNonBatchMatrix(self): + a = [[1, 2, 3j], [4, 5, -6j]] # Shape (2, 3) + expected = [[1, 4], [2, 5], [-3j, 6j]] # Shape (3, 2) + with self.test_session(): + a_adj = linear_operator_util.matrix_adjoint(a) + self.assertEqual((3, 2), a_adj.get_shape()) + self.assertAllClose(expected, a_adj.eval()) + + def testBatchMatrix(self): + matrix_0 = [[1j, 2, 3], [4, 5, 6]] + matrix_0_a = [[-1j, 4], [2, 5], [3, 6]] + matrix_1 = [[11, 22, 33], [44, 55, 66j]] + matrix_1_a = [[11, 44], [22, 55], [33, -66j]] + batch_matrix = [matrix_0, matrix_1] # Shape (2, 2, 3) + expected_adj = [matrix_0_a, matrix_1_a] # Shape (2, 3, 2) + with self.test_session(): + matrix_adj = linear_operator_util.matrix_adjoint(batch_matrix) + self.assertEqual((2, 3, 2), matrix_adj.get_shape()) + self.assertAllEqual(expected_adj, matrix_adj.eval()) + + class DomainDimensionStubOperator(object): def __init__(self, domain_dimension): diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_util.py b/tensorflow/contrib/linalg/python/ops/linear_operator_util.py index a52a235677f..9f8cb231693 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_util.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_util.py @@ -289,6 +289,53 @@ def matmul_with_broadcast(a, b_is_sparse=b_is_sparse) +def matrix_adjoint(a, name="matrix_adjoint"): + """Transposes last two dimensions of tensor `a`, and takes complex conjugate. + + If `a` is real valued, the result is equivalent to `matrix_transpose`. + + For example: + + ```python + # Matrix with no batch dimension. + # 'x' is [[1 2 3j] + # [4 5 -6j]] + tf.matrix_adjoint(x) ==> [[1 4] + [2 5] + [-3j 6j]] + + # Matrix with two batch dimensions. + # x.shape is [1, 2, 3, 4] + # tf.matrix_adjoint(x) is shape [1, 2, 4, 3] + ``` + + Note that `tf.matmul` provides kwargs allowing for adjoint of arguments. This + is done with minimal cost, and is preferable to using this function. E.g. + + ``` + # Good! Adjoint is taken at minimal additional cost. + tf.matmul(matrix, b, adjoint_b=True) + + # Inefficient! + tf.matmul(matrix, tf.matrix_adjoint(b)) + ``` + + Args: + a: A `Tensor` with `rank >= 2`. + name: A name for the operation (optional). + + Returns: + A batch matrix `Tensor` with same `dtype` as `a`. + + Raises: + ValueError: If `a` is determined statically to have `rank < 2`. + """ + with ops.name_scope(name, values=[a]): + a = ops.convert_to_tensor(a, name="a") + a_transpose = array_ops.matrix_transpose(a) + return math_ops.conj(a_transpose) + + def shape_tensor(shape, name=None): """Convert Tensor using default type, unless empty list or tuple.""" # Works just like random_ops._ShapeTensor. diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py new file mode 100644 index 00000000000..2ad0fd53109 --- /dev/null +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -0,0 +1,236 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A decoder that performs beam search. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.contrib.rnn import core_rnn_cell +from tensorflow.contrib.seq2seq.python.ops import decoder +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.layers import base as layers_base +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.util import nest + + +__all__ = [ + "BeamSearchDecoderOutput", + "BeamSearchDecoderState", + "BeamSearchDecoder", +] + + +class BeamSearchDecoderOutput( + collections.namedtuple("BeamSearchDecoderOutput", ("rnn_output",))): + pass + + +class BeamSearchDecoderState( + collections.namedtuple("BeamSearchDecoderState", + ("cell_state", "log_prob", "beam_ids"))): + pass + + +class BeamSearchDecoder(decoder.Decoder): + """BeamSearch sampling decoder.""" + + def __init__(self, cell, embedding, start_tokens, end_token, + initial_state, beam_width, output_layer=None): + """Initialize BeamSearchDecoder. + + Args: + cell: An `RNNCell` instance. + embedding: A callable that takes a vector tensor of `ids` (argmax ids), + or the `params` argument for `embedding_lookup`. + start_tokens: `int32` vector shaped `[batch_size]`, the start tokens. + end_token: `int32` scalar, the token that marks end of decoding. + initial_state: A (possibly nested tuple of...) tensors and TensorArrays. + beam_width: Python integer, the number of beams + output_layer: (Optional) An instance of `tf.layers.Layer`, i.e., + `tf.layers.Dense`. Optional layer to apply to the RNN output prior + to storing the result or sampling. + + Raises: + TypeError: if `cell` is not an instance of `RNNCell`, + or `output_layer` is not an instance of `tf.layers.Layer`. + ValueError: If `start_tokens` is not a vector or + `end_token` is not a scalar. + """ + if not isinstance(cell, core_rnn_cell.RNNCell): + raise TypeError("cell must be an RNNCell, received: %s" % type(cell)) + if (output_layer is not None + and not isinstance(output_layer, layers_base._Layer)): # pylint: disable=protected-access + raise TypeError( + "output_layer must be a Layer, received: %s" % type(output_layer)) + self._cell = cell + self._initial_cell_state = initial_state + self._output_layer = output_layer + + if callable(embedding): + self._embedding_fn = embedding + else: + self._embedding_fn = ( + lambda ids: embedding_ops.embedding_lookup(embedding, ids)) + + self._start_tokens = ops.convert_to_tensor( + start_tokens, dtype=dtypes.int32, name="start_tokens") + self._end_token = ops.convert_to_tensor( + end_token, dtype=dtypes.int32, name="end_token") + if self._start_tokens.get_shape().ndims != 1: + raise ValueError("start_tokens must be a vector") + self._batch_size = array_ops.size(start_tokens) + self._beam_width = beam_width + if self._end_token.get_shape().ndims != 0: + raise ValueError("end_token must be a scalar") + self._start_inputs = self._embedding_fn(self._start_tokens) + + @property + def batch_size(self): + return self._batch_size + + def _rnn_output_size(self): + size = self._cell.output_size + if self._output_layer is None: + return size + else: + # To use layer's compute_output_shape, we need to convert the + # RNNCell's output_size entries into shapes with an unknown + # batch size. We then pass this through the layer's + # compute_output_shape and read off all but the first (batch) + # dimensions to get the output size of the rnn with the layer + # applied to the top. + output_shape_with_unknown_batch = nest.map_structure( + lambda s: tensor_shape.TensorShape([None]).concatenate(s), + size) + layer_output_shape = self._output_layer._compute_output_shape( # pylint: disable=protected-access + output_shape_with_unknown_batch) + return nest.map_structure(lambda s: s[1:], layer_output_shape) + + @property + def output_size(self): + # Return the cell output and the id + prepend_beam_width = ( + lambda s: tensor_shape.TensorShape([self._beam_width]).concatenate(s)) + return BeamSearchDecoderOutput( + rnn_output=nest.map_structure( + prepend_beam_width, self._rnn_output_size())) + + @property + def output_dtype(self): + # Assume the dtype of the cell is the output_size structure + # containing the input_state's first component's dtype. + # Return that structure and int32 (the id) + dtype = nest.flatten(self._initial_cell_state)[0].dtype + return BeamSearchDecoderOutput( + rnn_output=nest.map_structure(lambda _: dtype, self._rnn_output_size())) + + def initialize(self, name=None): + """Initialize the decoder. + + Args: + name: Name scope for any created operations. + + Returns: + `(finished, first_inputs, initial_state)`. + """ + finished, first_inputs = self._finished, self._first_inputs + + initial_state = BeamSearchDecoderState( + cell_state=self._initial_cell_state, + log_probs=array_ops.zeros( + [self.batch_size, self.beam_width], + dtype=nest.flatten(self._initial_cell_state)[0].dtype), + beam_ids=tensor_array_ops.TensorArray( + size=0, dynamic_size=True, dtype=dtypes.int32, + clear_after_read=False)) + + return (finished, first_inputs, initial_state) + + def _merge_batch_beams(self, t): + t_static_shape = t.shape + t_shape = array_ops.shape(t) + static_batch_size = tensor_util.constant_value(self._batch_size) + batch_size_beam_width = ( + None if static_batch_size is None + else static_batch_size * self._beam_width) + reshaped_t = array_ops.reshape( + t, array_ops.concat( + ([self._batch_size * self._beam_width], t_shape[2:]), 0)) + reshaped_t.set_shape( + (tensor_shape.TensorShape([batch_size_beam_width]) + .concatenate(t_static_shape[2:]))) + return reshaped_t + + def _split_batch_beams(self, t): + t_static_shape = t.shape + t_shape = array_ops.shape(t) + reshaped_t = array_ops.reshape( + t, array_ops.concat( + ([self._batch_size, self._beam_width], t_shape[1:]), 0)) + static_batch_size = tensor_util.constant_value(self._batch_size) + reshaped_t.set_shape( + (tensor_shape.TensorShape([static_batch_size, self._beam_width]) + .concatenate(t_static_shape[1:]))) + return reshaped_t + + def step(self, time, inputs, state, name=None): + """Perform a decoding step. + + Args: + time: scalar `int32` tensor. + inputs: A (structure of) input tensors. + state: A (structure of) state tensors and TensorArrays. + name: Name scope for any created operations. + + Returns: + `(outputs, next_state, next_inputs, finished)`. + """ + with ops.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)): + cell_state = state.cell_state + inputs = nest.map_structure(self._merge_batch_beams, inputs) + cell_state = nest.map_structure(self._merge_batch_beams, cell_state) + cell_outputs, next_cell_state = self._cell(inputs, cell_state) + cell_outputs = nest.map_structure(self._split_batch_beams, cell_outputs) + next_cell_state = nest.map_structure(self._split_batch_beams, + next_cell_state) + + if self._output_layer is not None: + cell_outputs = self._output_layer(cell_outputs) + + # TODO(cinjon): Calculate next_log_probs, next_beam_ids, + # finished, next_inputs, final_cell_state via beam search + # via self._embedding + # .... + next_beam_ids, next_log_probs, final_cell_state, next_inputs, finished = ( + None, None, None, None, None) + + beam_ids = state.beam_ids.write(time, next_beam_ids) + + outputs = BeamSearchDecoderOutput(cell_outputs) + next_state = BeamSearchDecoderState( + log_probs=next_log_probs, + beam_ids=beam_ids, + cell_state=final_cell_state) + + return (outputs, next_state, next_inputs, finished) diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py index 1d2674af306..6338eb152e9 100644 --- a/tensorflow/contrib/seq2seq/python/ops/decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import rnn from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variable_scope from tensorflow.python.util import nest @@ -38,34 +39,7 @@ from tensorflow.python.util import nest __all__ = ["Decoder", "dynamic_decode"] -def _transpose_batch_time(x): - """Transpose the batch and time dimensions of a Tensor. - - Retains as much of the static shape information as possible. - - Args: - x: A tensor of rank 2 or higher. - - Returns: - x transposed along the first two dimensions. - - Raises: - ValueError: if `x` is rank 1 or lower. - """ - x_static_shape = x.get_shape() - if x_static_shape.ndims is not None and x_static_shape.ndims < 2: - raise ValueError( - "Expected input tensor %s to have rank at least 2, but saw shape: %s" % - (x, x_static_shape)) - x_rank = array_ops.rank(x) - x_t = array_ops.transpose( - x, array_ops.concat( - ([1, 0], math_ops.range(2, x_rank)), axis=0)) - x_t.set_shape( - tensor_shape.TensorShape([ - x_static_shape[1].value, x_static_shape[0].value - ]).concatenate(x_static_shape[2:])) - return x_t +_transpose_batch_time = rnn._transpose_batch_time # pylint: disable=protected-access @six.add_metaclass(abc.ABCMeta) diff --git a/tensorflow/contrib/session_bundle/gc_test.py b/tensorflow/contrib/session_bundle/gc_test.py index 1a8ee93cca4..8faf3ef3d4c 100644 --- a/tensorflow/contrib/session_bundle/gc_test.py +++ b/tensorflow/contrib/session_bundle/gc_test.py @@ -29,10 +29,6 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import test -def tearDownModule(): - gfile.DeleteRecursively(test.get_temp_dir()) - - class GcTest(test_util.TensorFlowTestCase): def testLargestExportVersions(self): diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index ba761cd7c6f..afcc7891b61 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -272,6 +272,7 @@ cc_library( "lib/monitoring/sampler.h", "lib/random/distribution_sampler.h", "lib/random/philox_random.h", + "lib/random/random_distributions.h", "lib/random/simple_philox.h", "lib/strings/numbers.h", "lib/strings/str_util.h", @@ -383,6 +384,7 @@ tf_cuda_library( "util/bcast.h", "util/cuda_kernel_helper.h", "util/device_name_utils.h", + "util/env_var.h", "util/events_writer.h", "util/example_proto_fast_parsing.h", "util/example_proto_helper.h", @@ -1535,7 +1537,10 @@ cc_library( tf_cuda_library( name = "direct_session_internal", srcs = ["common_runtime/direct_session.cc"], - hdrs = ["common_runtime/direct_session.h"], + hdrs = [ + "common_runtime/direct_session.h", + "util/env_var.h", + ], copts = tf_copts(), cuda_deps = [ ":gpu_tracer", diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index eda2be3e70f..768c2f6f753 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -57,6 +57,7 @@ limitations under the License. #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/core/util/env_var.h" #if GOOGLE_CUDA #include "tensorflow/core/common_runtime/gpu/gpu_tracer.h" @@ -242,6 +243,13 @@ DirectSession::DirectSession(const SessionOptions& options, thread_pools_.push_back(GlobalThreadPool(options)); owns_thread_pools_ = false; } + // The default value of sync_on_finish will be flipped soon and this + // environment variable will be removed as well. + Status status = + ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_); + if (!status.ok()) { + LOG(ERROR) << status.error_message(); + } // NOTE(mrry): We do not need to use a unique string for the session // handle, because DirectSession owns its devices. This may change // in future versions. @@ -448,7 +456,7 @@ Status DirectSession::Run(const RunOptions& run_options, if (LogMemory::IsEnabled()) { LogMemory::RecordStep(args.step_id, run_state_args.handle); } - args.sync_on_finish = true; + args.sync_on_finish = sync_on_finish_; const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE); @@ -632,7 +640,7 @@ Status DirectSession::PRunSetup(const std::vector& input_names, if (LogMemory::IsEnabled()) { LogMemory::RecordStep(args.step_id, run_state_args.handle); } - args.sync_on_finish = true; + args.sync_on_finish = sync_on_finish_; if (options_.config.graph_options().build_cost_model()) { run_state->collector.reset(new StepStatsCollector(nullptr)); diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h index 1495648631e..b9d22ac522c 100644 --- a/tensorflow/core/common_runtime/direct_session.h +++ b/tensorflow/core/common_runtime/direct_session.h @@ -247,6 +247,8 @@ class DirectSession : public Session { std::vector thread_pools_; bool owns_thread_pools_ = false; + // If true, blocks until device has finished all queued operations in a step. + bool sync_on_finish_ = true; // Schedules 'c' for execution on pool. void SchedClosure(thread::ThreadPool* pool, std::function c); diff --git a/tensorflow/core/common_runtime/shape_refiner.h b/tensorflow/core/common_runtime/shape_refiner.h index bbde0924c7f..f23f9361eb0 100644 --- a/tensorflow/core/common_runtime/shape_refiner.h +++ b/tensorflow/core/common_runtime/shape_refiner.h @@ -64,6 +64,10 @@ class ShapeRefiner { return it->second.get(); } + // Getters and setters for graph_def_version_. + int32 graph_def_version() { return graph_def_version_; } + void set_graph_def_version(int32 version) { graph_def_version_ = version; } + private: // Extracts the subgraph ending at 'node' that is statically // computable and inserts into 'out_graph'. If statically computable, @@ -100,7 +104,7 @@ class ShapeRefiner { const Node* node, int dst_idx, shape_inference::ShapeHandle* result); - const int graph_def_version_; + int32 graph_def_version_; const OpRegistryInterface* const ops_registry_; // The lifetime of the tensors are bound to the runner, so it should be the diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 537d489aae9..545ae867f62 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -41,6 +41,7 @@ limitations under the License. #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/worker.pb.h" +#include "tensorflow/core/util/env_var.h" namespace tensorflow { @@ -48,6 +49,13 @@ GraphMgr::GraphMgr(const WorkerEnv* worker_env, RendezvousMgrInterface* rendezvous_mgr) : worker_env_(worker_env), rendezvous_mgr_(rendezvous_mgr), table_(5) { CHECK(rendezvous_mgr) << "Rendezvous mgr was null"; + // The default value of sync_on_finish will be flipped soon and this + // environment variable will be removed as well. + Status status = + ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_); + if (!status.ok()) { + LOG(ERROR) << status.error_message(); + } } GraphMgr::~GraphMgr() { @@ -486,7 +494,7 @@ void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id, args.cancellation_manager = cancellation_manager; args.stats_collector = collector; args.step_container = step_container; - args.sync_on_finish = true; + args.sync_on_finish = sync_on_finish_; if (LogMemory::IsEnabled()) { LogMemory::RecordStep(args.step_id, handle); } diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h index 4477a2764be..5f51d638578 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.h +++ b/tensorflow/core/distributed_runtime/graph_mgr.h @@ -137,6 +137,9 @@ class GraphMgr { mutex mu_; int64 next_id_ GUARDED_BY(mu_) = 0; + // If true, blocks until device has finished all queued operations in a step. + bool sync_on_finish_ = true; + // Table mapping graph handles to registered graphs. // // TODO(zhifengc): If the client does not call Deregister, we'll diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 4c87a453e2a..d5e6e293d6d 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -873,6 +873,13 @@ Status BroadcastBinaryOpShapeFn(InferenceContext* c) { return Status::OK(); } +Status RandomShape(shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle out; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); + c->set_output(0, out); + return Status::OK(); +} + Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, ShapeHandle values_shape, ShapeHandle shape_shape) { // Validate ranks. diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h index 73509fb7fba..dc99e48adb9 100644 --- a/tensorflow/core/framework/common_shape_fns.h +++ b/tensorflow/core/framework/common_shape_fns.h @@ -199,6 +199,9 @@ Status ConcatV2Shape(shape_inference::InferenceContext* c); // Tested by ops/math_ops_test.cc. Status BroadcastBinaryOpShapeFn(InferenceContext* c); +// Shape function for random operations. +Status RandomShape(shape_inference::InferenceContext* c); + // Validates the 3 component tensors of a sparse tensor have the proper // shapes. This mimics SparseTensor.__init__ in python/framework/ops.py. Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index 3626de58d62..3d913cdaf0c 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -125,6 +125,23 @@ Status OpKernel::OutputRange(StringPiece output_name, int* start, } } +Status OpKernel::MakeShape(const Tensor& shape, TensorShape* out) const { + if (!IsLegacyVector(shape.shape())) { + return errors::InvalidArgument( + "shape must be a vector of {int32,int64}, got shape ", + shape.shape().DebugString()); + } + if (shape.dtype() == DataType::DT_INT32) { + auto vec = shape.flat(); + return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out); + } else if (shape.dtype() == DataType::DT_INT64) { + auto vec = shape.flat(); + return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out); + } else { + return errors::InvalidArgument("shape must be a vector of {int32,int64}."); + } +} + void AsyncOpKernel::Compute(OpKernelContext* context) { Notification n; ComputeAsync(context, [&n]() { n.Notify(); }); diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index d874b9087f1..91e6a98304d 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -151,6 +151,10 @@ class OpKernel { return shape.dims() == 1 || (allow_legacy_scalars() && shape.dims() == 0); } + // Turn a shape Tensor into a TensorShape + // TODO(irving): Move to TensorShapeUtils once !allow_legacy_scalars + Status MakeShape(const Tensor& shape, TensorShape* out) const; + private: const NodeDef def_; const DataTypeVector input_types_; diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 449d8f55f56..a990dc2f04d 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -239,8 +239,11 @@ string InferenceContext::DebugString() const { ProtoDebugString(node_def_)); } -Status InferenceContext::WithRank(ShapeHandle shape, int32 rank, +Status InferenceContext::WithRank(ShapeHandle shape, int64 rank, ShapeHandle* out) { + if (rank > kint32max) { + return errors::InvalidArgument("Rank cannot exceed kint32max"); + } const int32 existing = Rank(shape); if (existing == rank) { *out = shape; @@ -261,8 +264,11 @@ Status InferenceContext::WithRank(ShapeHandle shape, int32 rank, existing); } -Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int32 rank, +Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int64 rank, ShapeHandle* out) { + if (rank > kint32max) { + return errors::InvalidArgument("Rank cannot exceed kint32max"); + } const int32 existing = Rank(shape); if (existing >= rank) { *out = shape; @@ -276,8 +282,11 @@ Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int32 rank, " but is rank ", existing); } -Status InferenceContext::WithRankAtMost(ShapeHandle shape, int32 rank, +Status InferenceContext::WithRankAtMost(ShapeHandle shape, int64 rank, ShapeHandle* out) { + if (rank > kint32max) { + return errors::InvalidArgument("Rank cannot exceed kint32max"); + } const int32 existing = Rank(shape); if (existing == kUnknownRank) { return ReturnUnknownShape(out); @@ -470,12 +479,12 @@ Status InferenceContext::Concatenate(ShapeHandle s1, ShapeHandle s2, return ReturnCreatedShape(dims, out); } -Status InferenceContext::ReplaceDim(ShapeHandle s, int dim_index_in, +Status InferenceContext::ReplaceDim(ShapeHandle s, int64 dim_index_in, DimensionHandle new_dim, ShapeHandle* out) { if (!RankKnown(s)) { return ReturnUnknownShape(out); } - int dim_index = dim_index_in; + int64 dim_index = dim_index_in; if (dim_index < 0) { dim_index = s->dims_.size() + dim_index; } @@ -510,7 +519,8 @@ ShapeHandle InferenceContext::UnknownShape() { return shape_manager_.UnknownShape(); } -ShapeHandle InferenceContext::UnknownShapeOfRank(int32 rank) { +ShapeHandle InferenceContext::UnknownShapeOfRank(int64 rank) { + CHECK_LE(rank, kint32max) << "rank must be less than kint32max"; std::vector dims(rank); for (int32 i = 0; i < rank; ++i) { dims[i] = UnknownDim(); diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index b7f1725c5f1..5e116884c67 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -194,7 +194,7 @@ class InferenceContext { return s; } - ShapeHandle input(int idx) const { return inputs_[idx]; } + ShapeHandle input(int64 idx) const { return inputs_[idx]; } Status input(StringPiece input_name, std::vector* output) const; int num_inputs() const { return inputs_.size(); } @@ -237,7 +237,7 @@ class InferenceContext { // idx can be negative for an offset from end of dimensions. // idx must be in the range [-1 * s.rank, s.rank). - DimensionHandle Dim(ShapeHandle s, int32 idx) { + DimensionHandle Dim(ShapeHandle s, int64 idx) { if (s->rank_ == kUnknownRank) { return UnknownDim(); } @@ -277,11 +277,11 @@ class InferenceContext { // the shape with asserted rank in <*out>. Otherwise return an error. // // Note that <*out> may be set to . - Status WithRank(ShapeHandle shape, int32 rank, + Status WithRank(ShapeHandle shape, int64 rank, ShapeHandle* out) TF_MUST_USE_RESULT; - Status WithRankAtLeast(ShapeHandle shape, int32 rank, + Status WithRankAtLeast(ShapeHandle shape, int64 rank, ShapeHandle* out) TF_MUST_USE_RESULT; - Status WithRankAtMost(ShapeHandle shape, int32 rank, + Status WithRankAtMost(ShapeHandle shape, int64 rank, ShapeHandle* out) TF_MUST_USE_RESULT; // If has value , or its value is unknown, returns OK and returns @@ -332,7 +332,7 @@ class InferenceContext { // Returns in the shape from replacing with // . - Status ReplaceDim(ShapeHandle s, int dim_index, DimensionHandle new_dim, + Status ReplaceDim(ShapeHandle s, int64 dim_index, DimensionHandle new_dim, ShapeHandle* out) TF_MUST_USE_RESULT; // Returns a new shape with the given dims. The returned value is owned by @@ -344,7 +344,7 @@ class InferenceContext { ShapeHandle UnknownShape(); // Returns a shape with specified rank but unknown dims. - ShapeHandle UnknownShapeOfRank(int32 rank); + ShapeHandle UnknownShapeOfRank(int64 rank); // Returns a new shape of zero dimensions. ShapeHandle Scalar(); diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 6b3b5d36047..9d4a0a52f75 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -839,11 +839,6 @@ Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef, Graph* g, ShapeRefiner* refiner, std::vector>* return_tensors) { - ShapeRefiner default_refiner(gdef.versions().producer(), g->op_registry()); - if (refiner == nullptr) { - refiner = &default_refiner; - } - if (!opts.return_tensors.empty()) { if (return_tensors == nullptr) { return errors::InvalidArgument( @@ -857,6 +852,36 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef, return_tensors->size(), ")"); } } + + ShapeRefiner default_refiner(gdef.versions().producer(), g->op_registry()); + if (refiner == nullptr) { + refiner = &default_refiner; + } else { + // Log a warning if we are importing a GraphDef at an older + // producer version after already having added non-source/sink + // nodes to the graph in the past. + if (gdef.versions().producer() > 0 && + gdef.versions().producer() < refiner->graph_def_version() && + g->num_nodes() > 2) { + LOG(WARNING) << "Importing a graph with a lower producer version " + << gdef.versions().producer() + << " into an existing graph with producer version " + << refiner->graph_def_version() << ". Shape inference will " + << "have run different parts of the graph with different " + << "producer versions."; + } + } + + // Set the graph def version of the refiner as the min of the + // current value and the version from the graph we are about to + // import. + // + // Note: to match Run() semantics, we should re-run shape inference + // on the entire graph if the producer version has changed. For now + // we log the warning above. + refiner->set_graph_def_version( + std::min(refiner->graph_def_version(), gdef.versions().producer())); + return GraphConstructor::Construct(opts, &gdef, g, refiner, return_tensors); } diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index e20dabc8910..e3b7f322cb6 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -2271,5 +2271,176 @@ TEST_F(GraphConstructorTest, GraphDefVersionMergingDuringImport) { EXPECT_EQ(3, graph_.versions().bad_consumers(2)); } +TEST_F(GraphConstructorTest, ImportGraphDefProvidedShapeRefinerVersions) { + ImportGraphDefOptions opts; + // A valid graph at producer version 20, but one + // that would not import if the graph_def_version were 21. + string gdef_ascii = strings::StrCat(R"EOF( +node { + name: "Sum/input" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + dim { + size: 1 + } + } + tensor_content: "\001\000\000\000\002\000\000\000" + } + } + } +} +node { + name: "Sum/reduction_indices" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + dim { + size: 1 + } + } + tensor_content: "\000\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "Sum" + op: "Sum" + input: "Sum/input" + input: "Sum/reduction_indices" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "keep_dims" + value { + b: false + } + } +} +versions { + producer: 20 +})EOF"); + + // Create a shape refiner with the latest TF_GRAPH_DEF_VERSION. + // Importing the graphdef with an existing refiner should + // make the refiner inherit the graphdef version from the + // passed in graphdef since it has a lower producer. + ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry()); + ExpectOK(gdef_ascii, opts, &refiner); + + // Add another node with a higher producer + gdef_ascii = strings::StrCat(R"EOF( +node { + name: "RandomConst" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + dim { + size: 1 + } + } + tensor_content: "\001\000\000\000\002\000\000\000" + } + } + } +} +versions { + producer: 21 +})EOF"); + + ExpectOK(gdef_ascii, opts, &refiner); + // Check that the refiner's graph def version is the lowest of + // the graph defs we have seen so far. + EXPECT_EQ(20, refiner.graph_def_version()); + + // Add another node with a lower producer + gdef_ascii = strings::StrCat(R"EOF( +node { + name: "RandomConst2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + dim { + size: 1 + } + } + tensor_content: "\001\000\000\000\002\000\000\000" + } + } + } +} +versions { + producer: 17 +})EOF"); + ExpectOK(gdef_ascii, opts, &refiner); + + // Check that the refiner's graph def version is the lowest of + // the graph defs we have seen so far. + EXPECT_EQ(17, refiner.graph_def_version()); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD index c42eebae538..5d74d3d3b17 100644 --- a/tensorflow/core/grappler/BUILD +++ b/tensorflow/core/grappler/BUILD @@ -29,6 +29,16 @@ filegroup( visibility = ["//tensorflow:__subpackages__"], ) +cc_library( + name = "op_types", + srcs = ["op_types.cc"], + hdrs = ["op_types.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:protos_all_cc", + ], +) + cc_library( name = "utils", srcs = ["utils.cc"], @@ -88,6 +98,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":grappler_item", + ":op_types", ":utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib_internal", diff --git a/tensorflow/core/grappler/devices.cc b/tensorflow/core/grappler/devices.cc index d3fc9044d3b..b318ac22d4b 100644 --- a/tensorflow/core/grappler/devices.cc +++ b/tensorflow/core/grappler/devices.cc @@ -53,6 +53,22 @@ int GetNumAvailableGPUs() { return num_eligible_gpus; } +int64 AvailableGPUMemory(int gpu_id) { +#if GOOGLE_CUDA + // Look up the device, to see its attributes. + perftools::gputools::Platform* gpu_platform = GPUMachineManager(); + CHECK_LT(gpu_id, gpu_platform->VisibleDeviceCount()); + perftools::gputools::StreamExecutor* se = + gpu_platform->ExecutorForDevice(gpu_id).ValueOrDie(); + int64 total_memory, available_memory; + CHECK(se->DeviceMemoryUsage(&available_memory, &total_memory)); + + return available_memory; +#else + return 0; +#endif +} + int GetNumAvailableLogicalCPUCores() { return port::NumSchedulableCPUs(); } } // end namespace grappler diff --git a/tensorflow/core/grappler/devices.h b/tensorflow/core/grappler/devices.h index 329e8e2e655..2d6c41888d9 100644 --- a/tensorflow/core/grappler/devices.h +++ b/tensorflow/core/grappler/devices.h @@ -29,6 +29,10 @@ namespace grappler { // than 8. int GetNumAvailableGPUs(); +// Maximum amount of gpu memory available per gpu. gpu_id must be in the range +// [0, num_available_gpu) +int64 AvailableGPUMemory(int gpu_id); + // Get the number of logical CPU cores (aka hyperthreads) available. int GetNumAvailableLogicalCPUCores(); diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index 7889b0e0259..e37b908fc67 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/variable.pb.h" #include "tensorflow/core/grappler/inputs/utils.h" +#include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" @@ -90,7 +91,7 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( node.clear_device(); } - if (node.op() == "Placeholder" || node.op() == "PlaceholderV2") { + if (IsPlaceholder(node)) { if (node.attr().count("dtype") == 0) { LOG(ERROR) << "Unknown type for placeholder " << node.name() << ", skipping this input"; diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc new file mode 100644 index 00000000000..33ef498db03 --- /dev/null +++ b/tensorflow/core/grappler/op_types.cc @@ -0,0 +1,27 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/op_types.h" + +namespace tensorflow { +namespace grappler { + +bool IsPlaceholder(const NodeDef& node) { + const auto op = node.op(); + return op == "Placeholder" || op == "PlaceholderV2"; +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h new file mode 100644 index 00000000000..30a3c914116 --- /dev/null +++ b/tensorflow/core/grappler/op_types.h @@ -0,0 +1,29 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_GRAPPLER_OP_TYPES_H_ +#define TENSORFLOW_GRAPPLER_OP_TYPES_H_ + +#include "tensorflow/core/framework/node_def.pb.h" + +namespace tensorflow { +namespace grappler { + +bool IsPlaceholder(const NodeDef& node); + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_GRAPPLER_OP_TYPES_H_ diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index bd96e2b33cc..2ea150ce188 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -25,6 +25,40 @@ filegroup( visibility = ["//tensorflow:__subpackages__"], ) +cc_library( + name = "auto_parallel", + srcs = ["auto_parallel.cc"], + hdrs = [ + "auto_parallel.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":graph_optimizer", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:devices", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/clusters:cluster", + ], +) + +cc_test( + name = "auto_parallel_test", + srcs = ["auto_parallel_test.cc"], + deps = [ + ":auto_parallel", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", + ], +) + cc_library( name = "constant_folding", srcs = ["constant_folding.cc"], @@ -179,6 +213,7 @@ cc_library( ":constant_folding", ":graph_optimizer", ":layout_optimizer", + ":memory_optimizer", ":model_pruner", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/core/grappler/optimizers/auto_parallel.cc b/tensorflow/core/grappler/optimizers/auto_parallel.cc new file mode 100644 index 00000000000..77ab178653b --- /dev/null +++ b/tensorflow/core/grappler/optimizers/auto_parallel.cc @@ -0,0 +1,260 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/auto_parallel.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/devices.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { +namespace grappler { +const char kAutoParallelPrefix[] = "AutoParallel"; + +NodeDef* AutoParallel::AddNodeDivConst() { + NodeDef* node = graph_.add_node(); + node->set_name(strings::StrCat(kAutoParallelPrefix, "-Div-Const")); + node->set_op("Const"); + + AttrValue attr_data_type; + attr_data_type.set_type(DT_FLOAT); + node->mutable_attr()->insert({"dtype", attr_data_type}); + + AttrValue attr_tensor; + auto tensor = attr_tensor.mutable_tensor(); + tensor->add_float_val(static_cast(num_replicas_)); + tensor->set_dtype(DT_FLOAT); + node->mutable_attr()->insert({"value", attr_tensor}); + return node; +} + +NodeDef* AutoParallel::AddNodeDiv(const string& name, const string& input_a, + const string& input_b) { + NodeDef* node = graph_.add_node(); + node->set_name(strings::StrCat(kAutoParallelPrefix, "-Div-", name)); + node->set_op("RealDiv"); + node->add_input(input_a); + node->add_input(input_b); + AttrValue attr_type; + attr_type.set_type(DT_FLOAT); + node->mutable_attr()->insert({"T", attr_type}); + return node; +} + +NodeDef* AutoParallel::AddNodeControl(const string& name, + const std::set& deps, + GraphDef* graph) { + NodeDef* node = graph->add_node(); + node->set_name(name); + node->set_op("NoOp"); + for (const auto& dep : deps) { + node->add_input(strings::StrCat("^", dep)); + } + return node; +} + +Status AutoParallel::Initialize(const GrapplerItem& item) { + num_gpus_ = GetNumAvailableGPUs(); + LOG(INFO) << "Number of GPUs: " << num_gpus_; + item_ = &item; + graph_ = item.graph; + LOG(INFO) << "Original graph size: " << graph_.node_size(); + if (item.fetch.empty()) { + return Status(error::INVALID_ARGUMENT, "No fetch nodes provided."); + } + + if (item.MainVariables().empty()) { + return Status(error::INVALID_ARGUMENT, "No variables provided."); + } + + for (const auto& init : item.init_ops) { + VLOG(1) << "Init node: " << init; + } + + for (const auto& fetch : item.fetch) { + VLOG(1) << "Fetch node: " << fetch; + } + + for (const auto& var : item.MainVariables()) { + VLOG(2) << "Variable: " << var->name(); + } + + std::set apply_gradients_ops = {"ApplyGradientDescent", + "ApplyProximalGradientDescent", + "ApplyAdadelta", + "ApplyAdagrad", + "ApplyProximalAdagrad", + "ApplyAdagradDA", + "ApplyFtrl", + "ApplyMomentum", + "ApplyAdam", + "ApplyRMSProp", + "ApplyCenteredRMSProp"}; + const NodeDef* dequeue_node = nullptr; + for (int i = 0; i < graph_.node_size(); i++) { + all_nodes_.insert( + std::make_pair(graph_.node(i).name(), graph_.mutable_node(i))); + if (graph_.node(i).op() == "QueueDequeueManyV2") { + dequeue_node = graph_.mutable_node(i); + } + if (apply_gradients_ops.find(graph_.node(i).op()) != + apply_gradients_ops.end()) { + apply_gradients_nodes_.insert(graph_.node(i).name()); + VLOG(2) << "Apply gradients node: " << graph_.node(i).name(); + } + } + + auto div_const_node = AddNodeDivConst(); + all_nodes_.insert(std::make_pair(div_const_node->name(), div_const_node)); + std::map gradient_pos = {{"ApplyGradientDescent", 2}, + {"ApplyProximalGradientDescent", 4}, + {"ApplyAdadelta", 6}, + {"ApplyAdagrad", 3}, + {"ApplyProximalAdagrad", 5}, + {"ApplyAdagradDA", 3}, + {"ApplyFtrl", 3}, + {"ApplyMomentum", 3}, + {"ApplyAdam", 9}, + {"ApplyRMSProp", 7}, + {"ApplyCenteredRMSProp", 8}}; + for (const auto& apply_gradient_node_name : apply_gradients_nodes_) { + auto apply_gradients_op = all_nodes_[apply_gradient_node_name]->op(); + auto apply_gradients_node = all_nodes_[apply_gradient_node_name]; + + auto div_node = AddNodeDiv( + apply_gradient_node_name, + apply_gradients_node->input(gradient_pos[apply_gradients_op]), + div_const_node->name()); + all_nodes_.insert(std::make_pair(div_node->name(), div_node)); + *apply_gradients_node->mutable_input(gradient_pos[apply_gradients_op]) = + div_node->name(); + } + LOG(INFO) << "Graph size after adding div nodes: " << all_nodes_.size(); + + auto train_nodes = ComputeTransitiveFanin(graph_, item.fetch); + LOG(INFO) << "Number of training nodes: " << train_nodes.size(); + + std::vector input_nodes; + if (dequeue_node) { + LOG(INFO) << "Dequeue node: " << dequeue_node->name(); + input_nodes = ComputeTransitiveFanin(graph_, {dequeue_node->name()}); + } + LOG(INFO) << "Number of input nodes: " << input_nodes.size(); + + std::set dont_replicate_nodes; + for (const auto& variable : item.MainVariables()) { + dont_replicate_nodes.insert(variable->name()); + } + // Don't replicate all input nodes, except the dequeue node. + for (const auto& input_node : input_nodes) { + if (input_node->name() != dequeue_node->name()) { + dont_replicate_nodes.insert(input_node->name()); + } + } + + for (const auto& node : train_nodes) { + if (dont_replicate_nodes.find(node->name()) == dont_replicate_nodes.end()) { + replica_nodes_.insert(node->name()); + } + } + LOG(INFO) << "Number of replica nodes: " << replica_nodes_.size(); + + for (const auto& node : all_nodes_) { + if (replica_nodes_.find(node.first) == replica_nodes_.end()) { + shared_nodes_.insert(node.first); + } + } + LOG(INFO) << "Number of shared nodes: " << shared_nodes_.size(); + return Status::OK(); +} + +bool AutoParallel::NotSharedNode(const string& name) { + return shared_nodes_.find(name) == shared_nodes_.end(); +} + +void AutoParallel::AddSharedNodes(GraphDef* graph) { + string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", 0); + for (const auto& node : shared_nodes_) { + auto new_node = graph->add_node(); + *new_node = *all_nodes_[node]; + for (int i = 0; i < new_node->input_size(); i++) { + if (NotSharedNode(NodeName(new_node->input(i)))) { + string new_name = AddPrefixToNodeName(new_node->input(i), prefix); + *new_node->mutable_input(i) = new_name; + } + } + } +} + +void AutoParallel::AddOneReplica(GraphDef* graph, int number) { + string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", number); + for (const auto& node : replica_nodes_) { + auto new_node = graph->add_node(); + *new_node = *all_nodes_[node]; + if (NotSharedNode(new_node->name())) { + new_node->set_name(AddPrefixToNodeName(new_node->name(), prefix)); + if (num_gpus_ > 0) { + new_node->set_device(strings::StrCat("/gpu:", number % num_gpus_)); + } + for (int i = 0; i < new_node->input_size(); i++) { + if (NotSharedNode(NodeName(new_node->input(i)))) { + string new_name = AddPrefixToNodeName(new_node->input(i), prefix); + *new_node->mutable_input(i) = new_name; + } + } + } + } +} + +void AutoParallel::BuildGraph(GraphDef* graph) { + AddSharedNodes(graph); + for (int i = 0; i < num_replicas_; i++) { + AddOneReplica(graph, i); + } + std::set fetches; + for (int i = 0; i < item_->fetch.size(); i++) { + for (int j = 0; j < num_replicas_; j++) { + string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", j); + string fetch = AddPrefixToNodeName(item_->fetch[i], prefix); + fetches.insert(fetch); + } + } + string name_control = + strings::StrCat(kAutoParallelPrefix, "-Control-", "Fetch"); + auto control = AddNodeControl(name_control, fetches, graph); + + for (const auto& fetch : item_->fetch) { + AddNodeControl(fetch, {control->name()}, graph); + } + LOG(INFO) << "Parallelized graph size: " << graph->node_size(); +} + +Status AutoParallel::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) { + TF_RETURN_IF_ERROR(Initialize(item)); + BuildGraph(output); + return Status::OK(); +} + +void AutoParallel::Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, double result) { + // TODO(yaozhang): Add feedback. +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/auto_parallel.h b/tensorflow/core/grappler/optimizers/auto_parallel.h new file mode 100644 index 00000000000..cac0db2c236 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/auto_parallel.h @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_AUTO_PARALLEL_H_ +#define TENSORFLOW_GRAPPLER_OPTIMIZERS_AUTO_PARALLEL_H_ + +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { + +// Automatically parallelize a graph by splitting in the batch dimension. +class AutoParallel : public GraphOptimizer { + public: + AutoParallel(int num_replicas) : num_replicas_(num_replicas) {} + ~AutoParallel() override {} + + string name() const override { return "autoparallel"; }; + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) override; + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, double result) override; + + private: + GraphDef graph_; + std::map all_nodes_; + std::set apply_gradients_nodes_; + std::set replica_nodes_; + std::set shared_nodes_; + const GrapplerItem* item_; + int num_replicas_; + int num_gpus_; + Status Initialize(const GrapplerItem& item); + NodeDef* AddNodeDivConst(); + NodeDef* AddNodeDiv(const string& name, const string& input_a, + const string& input_b); + NodeDef* AddNodeControl(const string& name, const std::set& deps, + GraphDef* graph); + bool NotSharedNode(const string& name); + void AddSharedNodes(GraphDef* graph); + void AddOneReplica(GraphDef* graph, int number); + void BuildGraph(GraphDef* graph); +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_GRAPPLER_OPTIMIZERS_AUTO_PARALLEL_H_ diff --git a/tensorflow/core/grappler/optimizers/auto_parallel_test.cc b/tensorflow/core/grappler/optimizers/auto_parallel_test.cc new file mode 100644 index 00000000000..b7786ccd144 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/auto_parallel_test.cc @@ -0,0 +1,125 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/auto_parallel.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +class AutoParallelTest : public ::testing::Test {}; + +TEST_F(AutoParallelTest, SimpleParallel) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output constant_a = ops::Const(s.WithOpName("constant_a"), 1.0f, {1}); + Output constant_b = ops::Const(s.WithOpName("constant_b"), 1, {1}); + Output var = ops::Variable(s.WithOpName("var"), {1}, DT_FLOAT); + Output assign = ops::Assign(s.WithOpName("assign"), {var}, {constant_a}); + Output fifo_queue = ops::FIFOQueue(s.WithOpName("fifo_queue"), {DT_FLOAT}); + auto dequeue = ops::QueueDequeueMany(s.WithOpName("dequeue"), {fifo_queue}, + {constant_b}, {DT_FLOAT}); + Output add = ops::AddN(s.WithOpName("add"), {constant_a, dequeue[0]}); + Output learning_rate = ops::Const(s.WithOpName("learning_rate"), 0.01f, {1}); + Output apply_gradient = ops::ApplyGradientDescent( + s.WithOpName("apply_gradient"), {var}, {learning_rate}, {add}); + + GrapplerItem item; + item.init_ops.push_back("assign"); + item.fetch.push_back("apply_gradient"); + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + AutoParallel parallel(2); + GraphDef output; + Status status = parallel.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + EXPECT_EQ(20, output.node_size()); + + const NodeDef& node_assign = output.node(0); + EXPECT_EQ("assign", node_assign.name()); + EXPECT_EQ("AutoParallel-Replica-0-constant_a", node_assign.input(1)); + + const NodeDef& node_constant_b = output.node(1); + EXPECT_EQ("constant_b", node_constant_b.name()); + + const NodeDef& node_fifo_queue = output.node(2); + EXPECT_EQ("fifo_queue", node_fifo_queue.name()); + + const NodeDef& node_var = output.node(3); + EXPECT_EQ("var", node_var.name()); + + const NodeDef& node_div_const0 = output.node(4); + EXPECT_EQ("AutoParallel-Replica-0-AutoParallel-Div-Const", + node_div_const0.name()); + + const NodeDef& node_div0 = output.node(5); + EXPECT_EQ("AutoParallel-Replica-0-AutoParallel-Div-apply_gradient", + node_div0.name()); + const NodeDef& node_add0 = output.node(6); + EXPECT_EQ("AutoParallel-Replica-0-add", node_add0.name()); + + const NodeDef& node_gradient0 = output.node(7); + EXPECT_EQ("AutoParallel-Replica-0-apply_gradient", node_gradient0.name()); + + const NodeDef& node_constant_a0 = output.node(8); + EXPECT_EQ("AutoParallel-Replica-0-constant_a", node_constant_a0.name()); + + const NodeDef& node_dequeue0 = output.node(9); + EXPECT_EQ("AutoParallel-Replica-0-dequeue", node_dequeue0.name()); + + const NodeDef& node_learning_rate0 = output.node(10); + EXPECT_EQ("AutoParallel-Replica-0-learning_rate", node_learning_rate0.name()); + + const NodeDef& node_div_const1 = output.node(11); + EXPECT_EQ("AutoParallel-Replica-1-AutoParallel-Div-Const", + node_div_const1.name()); + + const NodeDef& node_div1 = output.node(12); + EXPECT_EQ("AutoParallel-Replica-1-AutoParallel-Div-apply_gradient", + node_div1.name()); + + const NodeDef& node_add1 = output.node(13); + EXPECT_EQ("AutoParallel-Replica-1-add", node_add1.name()); + + const NodeDef& node_gradient1 = output.node(14); + EXPECT_EQ("AutoParallel-Replica-1-apply_gradient", node_gradient1.name()); + + const NodeDef& node_constant_a1 = output.node(15); + EXPECT_EQ("AutoParallel-Replica-1-constant_a", node_constant_a1.name()); + + const NodeDef& node_dequeue1 = output.node(16); + EXPECT_EQ("AutoParallel-Replica-1-dequeue", node_dequeue1.name()); + + const NodeDef& node_learning_rate1 = output.node(17); + EXPECT_EQ("AutoParallel-Replica-1-learning_rate", node_learning_rate1.name()); + + const NodeDef& node_fetch = output.node(18); + EXPECT_EQ("AutoParallel-Control-Fetch", node_fetch.name()); + EXPECT_EQ("^AutoParallel-Replica-0-apply_gradient", node_fetch.input(0)); + EXPECT_EQ("^AutoParallel-Replica-1-apply_gradient", node_fetch.input(1)); + + const NodeDef& node_gradient = output.node(19); + EXPECT_EQ("apply_gradient", node_gradient.name()); + EXPECT_EQ("^AutoParallel-Control-Fetch", node_gradient.input(0)); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 67ffa7a4b6e..0fe9359b753 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/constant_folding.h" #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" #include "tensorflow/core/grappler/optimizers/layout_optimizer.h" +#include "tensorflow/core/grappler/optimizers/memory_optimizer.h" #include "tensorflow/core/grappler/optimizers/model_pruner.h" #include "tensorflow/core/lib/core/status.h" @@ -37,6 +38,9 @@ std::unique_ptr MetaOptimizer::NewOptimizer( if (optimizer == "layout") { graph_optimizer.reset(new LayoutOptimizer()); } + if (optimizer == "memory") { + graph_optimizer.reset(new MemoryOptimizer()); + } return graph_optimizer; } @@ -55,8 +59,13 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, optimizers.push_back( std::unique_ptr(new LayoutOptimizer())); } + if (cfg_.memory_optimization() > 0) { + optimizers.push_back( + std::unique_ptr(new MemoryOptimizer())); + } } else { - std::set avaliable_optimizers = {"pruning", "constfold", "layout"}; + std::set avaliable_optimizers = {"pruning", "constfold", "layout", + "memory"}; for (const auto& optimizer : cfg_.optimizers()) { if (avaliable_optimizers.find(optimizer) != avaliable_optimizers.end()) { optimizers.push_back(NewOptimizer(optimizer)); @@ -81,7 +90,6 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, optimizer->Optimize(nullptr, optimized_item, optimized_graph)); } } - // Copy the graph version. *optimized_graph->mutable_versions() = item.graph.versions(); diff --git a/tensorflow/core/kernels/batchtospace_op.cc b/tensorflow/core/kernels/batchtospace_op.cc index b24a8340839..99b5d3daaa4 100644 --- a/tensorflow/core/kernels/batchtospace_op.cc +++ b/tensorflow/core/kernels/batchtospace_op.cc @@ -97,6 +97,10 @@ static void BatchToSpaceOpCompute(OpKernelContext* context, for (int block_dim = 0; block_dim < block_dims; ++block_dim) { block_shape_product *= block_shape[block_dim]; } + OP_REQUIRES( + context, block_shape_product > 0, + errors::InvalidArgument("Product of block sizes must be positive, got ", + block_shape_product)); const int64 orig_input_batch_size = orig_input_tensor.dim_size(0); OP_REQUIRES( diff --git a/tensorflow/core/kernels/crop_and_resize_op.cc b/tensorflow/core/kernels/crop_and_resize_op.cc index caf73420ba9..746fe63e2a0 100644 --- a/tensorflow/core/kernels/crop_and_resize_op.cc +++ b/tensorflow/core/kernels/crop_and_resize_op.cc @@ -216,12 +216,14 @@ struct CropAndResize { const float x_lerp = in_x - left_x_index; for (int d = 0; d < depth; ++d) { - const float top_left(image(b_in, top_y_index, left_x_index, d)); - const float top_right(image(b_in, top_y_index, right_x_index, d)); - const float bottom_left( - image(b_in, bottom_y_index, left_x_index, d)); - const float bottom_right( - image(b_in, bottom_y_index, right_x_index, d)); + const float top_left( + static_cast(image(b_in, top_y_index, left_x_index, d))); + const float top_right( + static_cast(image(b_in, top_y_index, right_x_index, d))); + const float bottom_left(static_cast( + image(b_in, bottom_y_index, left_x_index, d))); + const float bottom_right(static_cast( + image(b_in, bottom_y_index, right_x_index, d))); const float top = top_left + (top_right - top_left) * x_lerp; const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp; @@ -545,12 +547,14 @@ struct CropAndResizeBackpropBoxes { const float x_lerp = in_x - left_x_index; for (int d = 0; d < depth; ++d) { - const float top_left(image(b_in, top_y_index, left_x_index, d)); - const float top_right(image(b_in, top_y_index, right_x_index, d)); - const float bottom_left( - image(b_in, bottom_y_index, left_x_index, d)); - const float bottom_right( - image(b_in, bottom_y_index, right_x_index, d)); + const float top_left( + static_cast(image(b_in, top_y_index, left_x_index, d))); + const float top_right( + static_cast(image(b_in, top_y_index, right_x_index, d))); + const float bottom_left(static_cast( + image(b_in, bottom_y_index, left_x_index, d))); + const float bottom_right(static_cast( + image(b_in, bottom_y_index, right_x_index, d))); // Compute the image gradient. float image_grad_y = (1 - x_lerp) * (bottom_left - top_left) + x_lerp * (bottom_right - top_right); @@ -606,18 +610,25 @@ inline void CheckValidBoxInd( .HostMemory("crop_size"), \ CropAndResizeOp); \ \ - REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradImage") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .HostMemory("image_size"), \ - CropAndResizeGradImageOp); \ - \ REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \ .Device(DEVICE_CPU) \ .TypeConstraint("T"), \ CropAndResizeGradBoxesOp); +TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL); + +#undef REGISTER_KERNEL + +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradImage") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .HostMemory("image_size"), \ + CropAndResizeGradImageOp); + +TF_CALL_half(REGISTER_KERNEL); TF_CALL_float(REGISTER_KERNEL); +TF_CALL_double(REGISTER_KERNEL); #undef REGISTER_KERNEL @@ -685,7 +696,7 @@ inline void CheckValidBoxInd( .TypeConstraint("T"), \ CropAndResizeGradBoxesOp); -TF_CALL_float(REGISTER_KERNEL); +TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNEL); #undef REGISTER_KERNEL diff --git a/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc b/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc index 75146b28e66..254475db465 100644 --- a/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc +++ b/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc @@ -88,26 +88,26 @@ __global__ void CropAndResizeKernel( const int right_x_index = ceilf(in_x); const float x_lerp = in_x - left_x_index; - const float top_left( + const float top_left(static_cast( image_ptr[((b_in * image_height + top_y_index) * image_width + left_x_index) * depth + - d]); - const float top_right( + d])); + const float top_right(static_cast( image_ptr[((b_in * image_height + top_y_index) * image_width + right_x_index) * depth + - d]); - const float bottom_left( + d])); + const float bottom_left(static_cast( image_ptr[((b_in * image_height + bottom_y_index) * image_width + left_x_index) * depth + - d]); - const float bottom_right( + d])); + const float bottom_right(static_cast( image_ptr[((b_in * image_height + bottom_y_index) * image_width + right_x_index) * depth + - d]); + d])); const float top = top_left + (top_right - top_left) * x_lerp; const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp; crops_ptr[out_idx] = top + (bottom - top) * y_lerp; @@ -258,26 +258,26 @@ __global__ void CropAndResizeBackpropBoxesKernel( const int right_x_index = ceilf(in_x); const float x_lerp = in_x - left_x_index; - const float top_left = + const float top_left(static_cast( image_ptr[((b_in * image_height + top_y_index) * image_width + left_x_index) * depth + - d]; - const float top_right = + d])); + const float top_right(static_cast( image_ptr[((b_in * image_height + top_y_index) * image_width + right_x_index) * depth + - d]; - const float bottom_left = + d])); + const float bottom_left(static_cast( image_ptr[((b_in * image_height + bottom_y_index) * image_width + left_x_index) * depth + - d]; - const float bottom_right = + d])); + const float bottom_right(static_cast( image_ptr[((b_in * image_height + bottom_y_index) * image_width + right_x_index) * depth + - d]; + d])); // Compute the image gradient. float image_grad_y = (1 - x_lerp) * (bottom_left - top_left) + @@ -436,7 +436,7 @@ struct CropAndResizeBackpropBoxes { template struct CropAndResizeBackpropImage; \ template struct CropAndResizeBackpropBoxes; -TF_CALL_float(DEFINE_GPU_SPECS); +TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS); #undef DEFINE_GPU_SPECS diff --git a/tensorflow/core/kernels/crop_and_resize_op_test.cc b/tensorflow/core/kernels/crop_and_resize_op_test.cc index 68e077e44df..3a7f180598e 100644 --- a/tensorflow/core/kernels/crop_and_resize_op_test.cc +++ b/tensorflow/core/kernels/crop_and_resize_op_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.h" @@ -31,9 +32,10 @@ namespace tensorflow { class CropAndResizeOpTest : public OpsTestBase { protected: + template void MakeOp(float extrapolation_value) { TF_EXPECT_OK(NodeDefBuilder("crop_and_resize_op", "CropAndResize") - .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DataTypeToEnum::value)) .Input(FakeInput(DT_FLOAT)) .Input(FakeInput(DT_INT32)) .Input(FakeInput(DT_INT32)) @@ -43,12 +45,33 @@ class CropAndResizeOpTest : public OpsTestBase { } }; -TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To1x1) { - MakeOp(0); +#define REGISTER_TEST(T) \ + TEST_F(CropAndResizeOpTest, TestCropAndResize##T) { \ + MakeOp(0); \ + AddInputFromArray(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); \ + AddInputFromArray(TensorShape({1, 4}), {0, 0, 1, 1}); \ + AddInputFromArray(TensorShape({1}), {0}); \ + AddInputFromArray(TensorShape({2}), {1, 1}); \ + TF_ASSERT_OK(RunOpKernel()); \ + \ + Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 1, 1})); \ + test::FillValues(&expected, {2.5}); \ + test::ExpectTensorEqual(expected, *GetOutput(0)); \ + } + +REGISTER_TEST(float) +REGISTER_TEST(double) +REGISTER_TEST(int8) +REGISTER_TEST(uint8) + +#undef REGISTER_TEST + +TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To1x1Uint8) { + MakeOp(0); // Input: // 1, 2 // 3, 4 - AddInputFromArray(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); + AddInputFromArray(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); AddInputFromArray(TensorShape({1, 4}), {0, 0, 1, 1}); AddInputFromArray(TensorShape({1}), {0}); AddInputFromArray(TensorShape({2}), {1, 1}); @@ -60,7 +83,7 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To1x1) { } TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To1x1Flipped) { - MakeOp(0); + MakeOp(0); // Input: // 1, 2 // 3, 4 @@ -76,7 +99,7 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To1x1Flipped) { } TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3) { - MakeOp(0); + MakeOp(0); // Input: // 1, 2 // 3, 4 @@ -97,7 +120,7 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3) { } TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3Flipped) { - MakeOp(0); + MakeOp(0); // Input: // 1, 2 // 3, 4 @@ -118,7 +141,7 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3Flipped) { } TEST_F(CropAndResizeOpTest, TestCropAndResize3x3To2x2) { - MakeOp(0); + MakeOp(0); // Input: // 1, 2, 3 // 4, 5, 6 @@ -143,7 +166,7 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize3x3To2x2) { } TEST_F(CropAndResizeOpTest, TestCropAndResize3x3To2x2Flipped) { - MakeOp(0); + MakeOp(0); // Input: // 1, 2, 3 // 4, 5, 6 @@ -169,7 +192,7 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize3x3To2x2Flipped) { TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3Extrapolated) { const float v = -1; - MakeOp(v); + MakeOp(v); // Input: // 1, 2 // 3, 4 @@ -190,7 +213,7 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3Extrapolated) { } TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3NoCrop) { - MakeOp(0); + MakeOp(0); // Input: // 1, 2 // 3, 4 @@ -208,7 +231,7 @@ TEST_F(CropAndResizeOpTest, TestCropAndResize2x2To3x3NoCrop) { } TEST_F(CropAndResizeOpTest, TestInvalidInputShape) { - MakeOp(0); + MakeOp(0); AddInputFromArray(TensorShape({2, 2, 1}), {1, 2, 3, 4}); AddInputFromArray(TensorShape({1, 4}), {0, 0, 1, 1}); AddInputFromArray(TensorShape({1}), {0}); @@ -220,7 +243,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidInputShape) { } TEST_F(CropAndResizeOpTest, TestInvalidBoxIndexShape) { - MakeOp(0); + MakeOp(0); AddInputFromArray(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); AddInputFromArray(TensorShape({1, 4}), {0, 0, 1, 1}); AddInputFromArray(TensorShape({2}), {0, 0}); @@ -233,7 +256,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndexShape) { } TEST_F(CropAndResizeOpTest, TestInvalidBoxIndex) { - MakeOp(0); + MakeOp(0); AddInputFromArray(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); AddInputFromArray(TensorShape({1, 4}), {0, 0, 1, 1}); AddInputFromArray(TensorShape({1}), {1}); diff --git a/tensorflow/core/kernels/gather_functor.cc b/tensorflow/core/kernels/gather_functor.cc index be220d5c95d..8ef027a1ddd 100644 --- a/tensorflow/core/kernels/gather_functor.cc +++ b/tensorflow/core/kernels/gather_functor.cc @@ -38,6 +38,7 @@ namespace functor { DECLARE_GPU_SPECS_INDEX(T, int64) TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS); +TF_CALL_complex64(DECLARE_GPU_SPECS); #undef DECLARE_GPU_SPECS #undef DECLARE_GPU_SPECS_INDEX diff --git a/tensorflow/core/kernels/gather_functor_gpu.cu.cc b/tensorflow/core/kernels/gather_functor_gpu.cu.cc index f1c10250786..456f4023a79 100644 --- a/tensorflow/core/kernels/gather_functor_gpu.cu.cc +++ b/tensorflow/core/kernels/gather_functor_gpu.cu.cc @@ -32,6 +32,7 @@ typedef Eigen::GpuDevice GPUDevice; DEFINE_GPU_SPECS_INDEX(T, int64); TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS); +TF_CALL_complex64(DEFINE_GPU_SPECS); #undef DEFINE_GPU_SPECS #undef DEFINE_GPU_SPECS_INDEX diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc index d8182218af1..31af37693c5 100644 --- a/tensorflow/core/kernels/gather_op.cc +++ b/tensorflow/core/kernels/gather_op.cc @@ -114,6 +114,7 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU); #define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type) TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_GPU); +TF_CALL_complex64(REGISTER_GATHER_GPU); #undef REGISTER_GATHER_GPU diff --git a/tensorflow/core/kernels/gather_op_test.cc b/tensorflow/core/kernels/gather_op_test.cc index c340223aa10..23645dafad4 100644 --- a/tensorflow/core/kernels/gather_op_test.cc +++ b/tensorflow/core/kernels/gather_op_test.cc @@ -40,9 +40,9 @@ namespace { class GatherOpTest : public OpsTestBase { protected: - void MakeOp(DataType index_type) { + void MakeOp(DataType data_type, DataType index_type) { TF_ASSERT_OK(NodeDefBuilder("myop", "Gather") - .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(data_type)) .Input(FakeInput(index_type)) .Finalize(node_def())); TF_ASSERT_OK(InitOp()); @@ -50,7 +50,7 @@ class GatherOpTest : public OpsTestBase { }; TEST_F(GatherOpTest, ScalarIndices) { - MakeOp(DT_INT32); + MakeOp(DT_FLOAT, DT_INT32); // Feed and run AddInputFromArray(TensorShape({5}), {0, 1, 2, 3, 4}); @@ -63,8 +63,26 @@ TEST_F(GatherOpTest, ScalarIndices) { test::ExpectTensorEqual(expected, *GetOutput(0)); } +TEST_F(GatherOpTest, ScalarIndices_Complex) { + MakeOp(DT_COMPLEX64, DT_INT32); + + // Feed and run + AddInputFromArray>( + TensorShape({5}), {std::complex(0, 10), std::complex(1, 11), + std::complex(2, 12), std::complex(3, 13), + std::complex(4, 14)}); + AddInputFromArray(TensorShape({}), {3}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_COMPLEX64, TensorShape({})); + test::FillValues>(&expected, + {std::complex(3, 13)}); + test::ExpectTensorEqual>(expected, *GetOutput(0)); +} + TEST_F(GatherOpTest, Simple_TwoD32) { - MakeOp(DT_INT32); + MakeOp(DT_FLOAT, DT_INT32); // Feed and run AddInputFromArray(TensorShape({5, 3}), @@ -79,7 +97,7 @@ TEST_F(GatherOpTest, Simple_TwoD32) { } TEST_F(GatherOpTest, ZeroSize_TwoD32) { - MakeOp(DT_INT32); + MakeOp(DT_FLOAT, DT_INT32); // Feed and run AddInputFromArray(TensorShape({5, 0}), {}); @@ -92,7 +110,7 @@ TEST_F(GatherOpTest, ZeroSize_TwoD32) { } TEST_F(GatherOpTest, Simple_TwoD64) { - MakeOp(DT_INT64); + MakeOp(DT_FLOAT, DT_INT64); // Feed and run AddInputFromArray(TensorShape({5, 3}), @@ -107,7 +125,7 @@ TEST_F(GatherOpTest, Simple_TwoD64) { } TEST_F(GatherOpTest, HighRank) { - MakeOp(DT_INT32); + MakeOp(DT_FLOAT, DT_INT32); // Feed and run AddInputFromArray(TensorShape({4}), {0, 1, 2, 3}); @@ -121,7 +139,7 @@ TEST_F(GatherOpTest, HighRank) { } TEST_F(GatherOpTest, Error_IndexOutOfRange) { - MakeOp(DT_INT32); + MakeOp(DT_FLOAT, DT_INT32); // Feed and run AddInputFromArray(TensorShape({5, 3}), diff --git a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc index c5d5657492a..a383cc8199a 100644 --- a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc +++ b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc @@ -295,6 +295,83 @@ static void RunFusedGraph(const GraphDef& fused_graph_def) { reinterpret_cast(output_tensor.flat().data())); } +static void CompareGraphTransferInfo(const GraphTransferInfo& gfi0, + const GraphTransferInfo& gfi1) { + LOG(INFO) << "(1) node count: " << gfi1.node_info_size() << ", " + << gfi1.const_node_info_size(); + + // 1. check node_info + ASSERT_EQ(gfi0.node_info_size(), gfi1.node_info_size()); + for (int i = 0; i < gfi0.node_info_size(); ++i) { + const GraphTransferInfo::NodeInfo& ni0 = gfi0.node_info(i); + const GraphTransferInfo::NodeInfo& ni1 = gfi1.node_info(i); + EXPECT_EQ(ni0.DebugString(), ni1.DebugString()); + EXPECT_EQ(ni0.ByteSize(), ni1.ByteSize()); + } + + // 2. check const_node_info + ASSERT_EQ(gfi0.const_node_info_size(), gfi1.const_node_info_size()); + for (int i = 0; i < gfi0.const_node_info_size(); ++i) { + const GraphTransferInfo::ConstNodeInfo& cni0 = gfi0.const_node_info(i); + const GraphTransferInfo::ConstNodeInfo& cni1 = gfi1.const_node_info(i); + ASSERT_EQ(cni0.shape_size(), cni1.shape_size()); + for (int j = 0; j < cni0.shape_size(); ++j) { + EXPECT_EQ(cni0.shape(j), cni1.shape(j)); + } + EXPECT_EQ(cni0.ByteSize(), cni1.ByteSize()); + EXPECT_EQ(cni0.DebugString(), cni1.DebugString()); + } + + // 3. check node_input_info + ASSERT_EQ(gfi0.node_input_info_size(), gfi1.node_input_info_size()); + for (int i = 0; i < gfi0.node_input_info_size(); ++i) { + const GraphTransferInfo::NodeInputInfo& nii0 = gfi0.node_input_info(i); + const GraphTransferInfo::NodeInputInfo& nii1 = gfi1.node_input_info(i); + EXPECT_EQ(nii0.ByteSize(), nii1.ByteSize()); + EXPECT_EQ(nii0.DebugString(), nii1.DebugString()); + } + + // 4. check node_output_info + ASSERT_EQ(gfi0.node_output_info_size(), gfi1.node_output_info_size()); + for (int i = 0; i < gfi0.node_output_info_size(); ++i) { + const GraphTransferInfo::NodeOutputInfo& noi0 = gfi0.node_output_info(i); + const GraphTransferInfo::NodeOutputInfo& noi1 = gfi1.node_output_info(i); + ASSERT_EQ(noi0.max_byte_size_size(), noi1.max_byte_size_size()); + for (int j = 0; j < noi0.max_byte_size_size(); ++j) { + EXPECT_EQ(noi0.max_byte_size(j), noi1.max_byte_size(j)); + } + EXPECT_EQ(noi0.ByteSize(), noi1.ByteSize()); + EXPECT_EQ(noi0.DebugString(), noi1.DebugString()); + } + + // 5. check graph_input_node_info + ASSERT_EQ(gfi0.graph_input_node_info_size(), + gfi1.graph_input_node_info_size()); + for (int i = 0; i < gfi0.graph_input_node_info_size(); ++i) { + const GraphTransferInfo::GraphInputNodeInfo& gini0 = + gfi0.graph_input_node_info(i); + const GraphTransferInfo::GraphInputNodeInfo& gini1 = + gfi0.graph_input_node_info(i); + EXPECT_EQ(gini0.ByteSize(), gini1.ByteSize()); + EXPECT_EQ(gini0.DebugString(), gini1.DebugString()); + } + + // 6. check graph_output_node_info + ASSERT_EQ(gfi0.graph_output_node_info_size(), + gfi1.graph_output_node_info_size()); + for (int i = 0; i < gfi0.graph_output_node_info_size(); ++i) { + const GraphTransferInfo::GraphOutputNodeInfo& goni0 = + gfi0.graph_output_node_info(i); + const GraphTransferInfo::GraphOutputNodeInfo& goni1 = + gfi0.graph_output_node_info(i); + EXPECT_EQ(goni0.ByteSize(), goni1.ByteSize()); + EXPECT_EQ(goni0.DebugString(), goni1.DebugString()); + } + + // 7. check destination + EXPECT_EQ(gfi0.destination(), gfi1.destination()); +} + // CAVEAT: This test only runs when you specify hexagon library using // makefile. // CAVEAT: This test is disabled by default because hexagon can keep only @@ -450,34 +527,22 @@ TEST(GraphTransferer, DISABLED_CheckShapeInferencePerformance) { prof1.Stop(); prof1.DumpStatistics("Estiame shape by shape inference"); - LOG(INFO) << "(1) node count: " << gfi1.node_info_size() << ", " - << gfi1.const_node_info_size(); + CompareGraphTransferInfo(gfi0, gfi1); - ASSERT_EQ(gfi0.node_info_size(), gfi1.node_info_size()); + const RemoteFusedGraphExecuteInfo ei0 = + BuildRemoteFusedGraphExecuteInfoWithGraphTransferInfo(gfi0); + const RemoteFusedGraphExecuteInfo ei1 = + BuildRemoteFusedGraphExecuteInfoWithGraphTransferInfo(gfi1); - ASSERT_EQ(gt0.GetGraphTransferInfo().const_node_info_size(), - gt1.GetGraphTransferInfo().const_node_info_size()); + GraphTransferInfo rgfi0; + rgfi0.ParseFromString(ei0.serialized_executor_parameters()); + GraphTransferInfo rgfi1; + rgfi1.ParseFromString(ei1.serialized_executor_parameters()); - for (int i = 0; i < gfi0.const_node_info_size(); ++i) { - const GraphTransferInfo::ConstNodeInfo& ni0 = gfi0.const_node_info(i); - const GraphTransferInfo::ConstNodeInfo& ni1 = gfi1.const_node_info(i); - ASSERT_EQ(ni0.shape_size(), ni1.shape_size()); - for (int j = 0; j < ni0.shape_size(); ++j) { - EXPECT_EQ(ni0.shape(j), ni1.shape(j)); - } - } - - ASSERT_EQ(gfi0.node_output_info_size(), gfi1.node_output_info_size()); - for (int i = 0; i < gfi0.node_output_info_size(); ++i) { - const GraphTransferInfo::NodeOutputInfo& no0 = gfi0.node_output_info(i); - const GraphTransferInfo::NodeOutputInfo& no1 = gfi1.node_output_info(i); - ASSERT_EQ(no0.max_byte_size_size(), no1.max_byte_size_size()); - for (int j = 0; j < no0.max_byte_size_size(); ++j) { - EXPECT_EQ(no0.max_byte_size(j), no1.max_byte_size(j)); - } - } + CompareGraphTransferInfo(rgfi0, rgfi1); + CompareGraphTransferInfo(gfi0, rgfi0); + CompareGraphTransferInfo(gfi1, rgfi1); } - #endif } // namespace tensorflow diff --git a/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.cc b/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.cc index 851d87b15bb..ad9200e9489 100644 --- a/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.cc +++ b/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.cc @@ -174,6 +174,7 @@ const std::unordered_map OP_NAME_TO_SOC_OP_TYPE_MAP{ {"Placeholder", SupportedOpType::NOP}, {"RequantizationRange", SupportedOpType::REQUANTIZATION_RANGE_32}, {"Requantize", SupportedOpType::REQUANTIZE_32_TO_8}, + {"QuantizedReshape", SupportedOpType::QUANTIZED_RESHAPE}, }; /* static */ const IGraphTransferOpsDefinitions& diff --git a/tensorflow/core/kernels/maxpooling_op.cc b/tensorflow/core/kernels/maxpooling_op.cc index eb590280c9e..6cb56797bff 100644 --- a/tensorflow/core/kernels/maxpooling_op.cc +++ b/tensorflow/core/kernels/maxpooling_op.cc @@ -587,8 +587,8 @@ class MaxPoolingGradGradOp : public OpKernel { errors::InvalidArgument("out_grad_backprop must be 4-dimensional")); Tensor* output = nullptr; - OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( - {2}, 0, tensor_out.shape(), &output)); + OP_REQUIRES_OK(context, + context->allocate_output(0, tensor_out.shape(), &output)); PoolParameters params{context, ksize_, stride_, padding_, data_format_, tensor_in.shape()}; diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc index 32b210ecb7f..e3a57d2f28a 100644 --- a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc +++ b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc @@ -70,7 +70,7 @@ __global__ void MaxPoolForwardNCHW(const int nthreads, const dtype* bottom_data, int wend = min(wstart + kernel_w, width); hstart = max(hstart, 0); wstart = max(wstart, 0); - dtype maxval = -FLT_MAX; + dtype maxval = Eigen::NumTraits::lowest(); int maxidx = -1; const dtype* bottom_data_n = bottom_data + n * channels * height * width; for (int h = hstart; h < hend; ++h) { @@ -312,9 +312,6 @@ __global__ void MaxPoolGradBackwardNoMaskNHWC( // bottom_offset: the pre-computed per-image offset of the maxpool output. // This is equal to Hout*Wout*C. // bottom_diff: the gradient of the gradient w.r.t. output. -// This function relies on CudaAtomicAdd to avoid race conditions. Also, before -// the kernel is run, you will need to make sure that bottom_diff is filled with -// zero first. template __global__ void MaxPoolGradBackward(const int nthreads, const dtype* top_diff, const int64* mask, const int top_offset, @@ -357,12 +354,12 @@ bool MaxPoolBackwardNoMask::operator()( const int stride_w, const int pad_t, const int pad_l, const T* top_diff, T* bottom_diff, const Eigen::GpuDevice& d) { const int kThreadsPerBlock = 1024; - const int bottom_size = batch * channels * height * width; - const int top_size = batch * channels * pooled_height * pooled_width; + const int bottom_size = batch * channels * height * width; SetZero<<<(bottom_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, d.stream()>>>(bottom_size, bottom_diff); + const int top_size = batch * channels * pooled_height * pooled_width; MaxPoolBackwardNoMaskNHWC<<<(top_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, d.stream()>>>( diff --git a/tensorflow/core/kernels/quantize_op.cc b/tensorflow/core/kernels/quantize_op.cc index 7b34c32cebd..f649287fc1d 100644 --- a/tensorflow/core/kernels/quantize_op.cc +++ b/tensorflow/core/kernels/quantize_op.cc @@ -86,6 +86,7 @@ class QuantizeV2Op : public OpKernel { fabsf(input_max_range))) / 100.0f; max_range = std::max(input_max_range, min_range + epsilon); + max_range = std::max(0.0f, max_range); Tensor* output = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output)); diff --git a/tensorflow/core/kernels/quantize_op_test.cc b/tensorflow/core/kernels/quantize_op_test.cc index 41996852f16..48bde3b4971 100644 --- a/tensorflow/core/kernels/quantize_op_test.cc +++ b/tensorflow/core/kernels/quantize_op_test.cc @@ -132,6 +132,50 @@ TEST_F(QuantizedOpTest, QuantizeV2EqualRange) { EXPECT_LT(0.0f, output_max); } +TEST_F(QuantizedOpTest, QuantizeV2MovesMinToIncludeZero) { + TF_ASSERT_OK(NodeDefBuilder("quantize_op", "QuantizeV2") + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Attr("T", DataTypeToEnum::v()) + .Attr("mode", "MIN_FIRST") + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + AddInputFromArray(TensorShape({3}), {0.1, 0.2, 0.3}); + AddInputFromArray(TensorShape({1}), {0.1}); + AddInputFromArray(TensorShape({1}), {0.3}); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_QUINT8, TensorShape({3})); + test::FillValues(&expected, {85, 170, 255}); + test::ExpectTensorEqual(expected, *GetOutput(0)); + const float output_min = GetOutput(1)->flat()(0); + const float output_max = GetOutput(2)->flat()(0); + EXPECT_NEAR(0.0f, output_min, 1e-5f); + EXPECT_NEAR(0.3f, output_max, 1e-5f); +} + +TEST_F(QuantizedOpTest, QuantizeV2MovesMaxToIncludeZero) { + TF_ASSERT_OK(NodeDefBuilder("quantize_op", "QuantizeV2") + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Attr("T", DataTypeToEnum::v()) + .Attr("mode", "MIN_FIRST") + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + AddInputFromArray(TensorShape({3}), {-0.1, -0.2, -0.3}); + AddInputFromArray(TensorShape({1}), {-0.3}); + AddInputFromArray(TensorShape({1}), {-0.1}); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_QUINT8, TensorShape({3})); + test::FillValues(&expected, {170, 85, 0}); + test::ExpectTensorEqual(expected, *GetOutput(0)); + const float output_min = GetOutput(1)->flat()(0); + const float output_max = GetOutput(2)->flat()(0); + EXPECT_NEAR(-0.3f, output_min, 1e-5f); + EXPECT_NEAR(0.0f, output_max, 1e-5f); +} + TEST_F(QuantizedOpTest, Dequantize) { TF_ASSERT_OK(NodeDefBuilder("dequantize_op", "Dequantize") .Input(FakeInput(DT_QUINT8)) diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc index 3063fedac8f..80b1be8d4ca 100644 --- a/tensorflow/core/kernels/random_op.cc +++ b/tensorflow/core/kernels/random_op.cc @@ -178,27 +178,9 @@ namespace { static Status AllocateOutputWithShape(OpKernelContext* ctx, const Tensor& shape, int index, Tensor** output) { - if (!ctx->op_kernel().IsLegacyVector(shape.shape())) { - return errors::InvalidArgument( - "shape must be a vector of {int32,int64}, got shape ", - shape.shape().DebugString()); - } - if (shape.dtype() == DataType::DT_INT32) { - auto vec = shape.flat(); - TensorShape tensor_shape; - TF_RETURN_IF_ERROR( - TensorShapeUtils::MakeShape(vec.data(), vec.size(), &tensor_shape)); - TF_RETURN_IF_ERROR(ctx->allocate_output(index, tensor_shape, output)); - } else if (shape.dtype() == DataType::DT_INT64) { - auto vec = shape.flat(); - TensorShape tensor_shape; - TF_RETURN_IF_ERROR( - TensorShapeUtils::MakeShape(vec.data(), vec.size(), &tensor_shape)); - TF_RETURN_IF_ERROR(ctx->allocate_output(index, tensor_shape, output)); - } else { - return errors::InvalidArgument("shape must be a vector of {int32,int64}."); - } - return Status::OK(); + TensorShape tensor_shape; + TF_RETURN_IF_ERROR(ctx->op_kernel().MakeShape(shape, &tensor_shape)); + return ctx->allocate_output(index, tensor_shape, output); } // For now, use the same interface as RandomOp, so we can choose either one @@ -465,6 +447,12 @@ class RandomGammaOp : public OpKernel { #define REGISTER(TYPE) \ template struct functor::FillPhiloxRandom< \ CPUDevice, random::UniformDistribution >; \ + template struct functor::FillPhiloxRandom< \ + CPUDevice, random::NormalDistribution >; \ + template struct functor::FillPhiloxRandom< \ + CPUDevice, \ + random::TruncatedNormalDistribution< \ + random::SingleSampleAdapter, TYPE> >; \ REGISTER_KERNEL_BUILDER( \ Name("RandomUniform") \ .Device(DEVICE_CPU) \ diff --git a/tensorflow/core/kernels/random_poisson_op.cc b/tensorflow/core/kernels/random_poisson_op.cc index 553a4a7f939..66123e47c6e 100644 --- a/tensorflow/core/kernels/random_poisson_op.cc +++ b/tensorflow/core/kernels/random_poisson_op.cc @@ -291,33 +291,15 @@ class RandomPoissonOp : public OpKernel { const Tensor& shape_t = ctx->input(0); const Tensor& rate_t = ctx->input(1); - OP_REQUIRES(ctx, - TensorShapeUtils::IsVector(shape_t.shape()) && - (shape_t.dtype() == DataType::DT_INT32 || - shape_t.dtype() == DataType::DT_INT64), - errors::InvalidArgument( - "shape must be a vector of {int32,int64}, got shape: ", - shape_t.DebugString())); TensorShape samples_shape; - if (shape_t.dtype() == DataType::DT_INT32) { - auto vec = shape_t.flat(); - OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(vec.data(), vec.size(), - &samples_shape)); - } else if (shape_t.dtype() == DataType::DT_INT64) { - auto vec = shape_t.flat(); - OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(vec.data(), vec.size(), - &samples_shape)); - } + OP_REQUIRES_OK(ctx, MakeShape(shape_t, &samples_shape)); const int64 num_samples = samples_shape.num_elements(); - OP_REQUIRES(ctx, num_samples > 0, - errors::InvalidArgument( - "Input shape should have non-zero element count, got: ", - num_samples)); samples_shape.AppendShape(rate_t.shape()); // Allocate output samples. Tensor* samples_t = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, samples_shape, &samples_t)); + if (num_samples == 0) return; const auto rate_flat = rate_t.flat().data(); const int64 num_rate = rate_t.NumElements(); diff --git a/tensorflow/core/kernels/save_restore_v2_ops.cc b/tensorflow/core/kernels/save_restore_v2_ops.cc index 2e099565783..c665bc5b03c 100644 --- a/tensorflow/core/kernels/save_restore_v2_ops.cc +++ b/tensorflow/core/kernels/save_restore_v2_ops.cc @@ -47,8 +47,9 @@ void ValidateInputs(bool is_save_op, OpKernelContext* context, context, prefix.NumElements() == 1, errors::InvalidArgument("Input prefix should have a single element, got ", prefix.NumElements(), " instead.")); - OP_REQUIRES(context, TensorShapeUtils::IsVector(tensor_names.shape()) && - TensorShapeUtils::IsVector(shape_and_slices.shape()), + OP_REQUIRES(context, + TensorShapeUtils::IsVector(tensor_names.shape()) && + TensorShapeUtils::IsVector(shape_and_slices.shape()), errors::InvalidArgument( "Input tensor_names and shape_and_slices " "should be an 1-D tensors, got ", @@ -105,6 +106,7 @@ class SaveV2 : public OpKernel { const auto& shape_and_slices_flat = shape_and_slices.flat(); BundleWriter writer(Env::Default(), prefix_string); + OP_REQUIRES_OK(context, writer.status()); VLOG(1) << "BundleWriter, prefix_string: " << prefix_string; for (int i = 0; i < num_tensors; ++i) { diff --git a/tensorflow/core/kernels/spacetobatch_op.cc b/tensorflow/core/kernels/spacetobatch_op.cc index 3815716ccd9..c513683918e 100644 --- a/tensorflow/core/kernels/spacetobatch_op.cc +++ b/tensorflow/core/kernels/spacetobatch_op.cc @@ -100,6 +100,10 @@ void SpaceToBatchOpCompute(OpKernelContext* context, for (int block_dim = 0; block_dim < block_dims; ++block_dim) { block_shape_product *= block_shape[block_dim]; } + OP_REQUIRES( + context, block_shape_product > 0, + errors::InvalidArgument("Product of block sizes must be positive, got ", + block_shape_product)); const int internal_block_dims = block_dims - removed_prefix_block_dims - removed_suffix_block_dims; diff --git a/tensorflow/core/lib/random/philox_random.h b/tensorflow/core/lib/random/philox_random.h index 1fec5a3b441..b2adb4462ba 100644 --- a/tensorflow/core/lib/random/philox_random.h +++ b/tensorflow/core/lib/random/philox_random.h @@ -101,12 +101,15 @@ class Array { // 2. PhiloxRandom is compilable by gcc and nvcc. class PhiloxRandom { public: - typedef Array ResultType; - typedef uint32 ResultElementType; + using ResultType = Array; + using ResultElementType = uint32; // The number of elements that will be returned. static const int kResultElementCount = 4; // Cost of generation of a single element (in cycles). static const int kElementCost = 10; + // The type for the 64-bit key stored in the form of two 32-bit uint + // that are used in the diffusion process. + using Key = Array; PHILOX_DEVICE_INLINE PhiloxRandom() {} @@ -125,6 +128,9 @@ class PhiloxRandom { counter_[3] = static_cast(seed_hi >> 32); } + PHILOX_DEVICE_INLINE + PhiloxRandom(ResultType counter, Key key) : counter_(counter), key_(key) {} + // Skip the specified number of samples of 128-bits in the current stream. PHILOX_DEVICE_INLINE void Skip(uint64 count) { @@ -178,10 +184,6 @@ class PhiloxRandom { } private: - // The type for the 64-bit key stored in the form of two 32-bit uint - // that are used in the diffusion process. - typedef Array Key; - // We use the same constants as recommended by the original paper. static const uint32 kPhiloxW32A = 0x9E3779B9; static const uint32 kPhiloxW32B = 0xBB67AE85; diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index e81490c4988..e2e07a4bf19 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -41,10 +41,10 @@ Status GetAxisForPackAndUnpack(InferenceContext* c, int32 rank_after_pack, } template -std::vector AsInt64(const Tensor* tensor, int num_elements) { +std::vector AsInt64(const Tensor* tensor, int64 num_elements) { std::vector ret(num_elements); auto data = tensor->vec(); - for (int i = 0; i < num_elements; ++i) { + for (int64 i = 0; i < num_elements; ++i) { ret[i] = data(i); } return ret; @@ -52,11 +52,11 @@ std::vector AsInt64(const Tensor* tensor, int num_elements) { template Status PadKnown(InferenceContext* c, ShapeHandle input, - const Tensor* paddings_t, int32 num_dims) { + const Tensor* paddings_t, int64 num_dims) { // paddings_t is known. std::vector dims(num_dims); auto paddings_data = paddings_t->matrix(); - for (int i = 0; i < num_dims; ++i) { + for (int64 i = 0; i < num_dims; ++i) { const T pad0 = paddings_data(i, 0); const T pad1 = paddings_data(i, 1); if (pad0 < 0 || pad1 < 0) { @@ -1244,9 +1244,12 @@ REGISTER_OP("_ParallelConcatStart") .Attr("dtype: type") .SetIsStateful() .SetShapeFn([](InferenceContext* c) { - ShapeHandle out; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); - c->set_output(0, out); + TensorShapeProto shape_proto; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_proto)); + ShapeHandle output_shape; + TF_RETURN_IF_ERROR( + c->MakeShapeFromShapeProto(shape_proto, &output_shape)); + c->set_output(0, output_shape); return Status::OK(); }) .Doc(R"doc( @@ -2644,10 +2647,10 @@ output: The padded tensor. namespace { template Status MirrorPadKnown(InferenceContext* c, ShapeHandle input, - const Tensor* paddings_t, int32 input_rank) { + const Tensor* paddings_t, int64 input_rank) { auto paddings_data = paddings_t->matrix(); std::vector dims(input_rank); - for (int i = 0; i < input_rank; ++i) { + for (int64 i = 0; i < input_rank; ++i) { const int64 pad0 = static_cast(paddings_data(i, 0)); const int64 pad1 = static_cast(paddings_data(i, 1)); if (pad0 < 0 || pad1 < 0) { diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc index bc99fb09e5e..adb1320fc70 100644 --- a/tensorflow/core/ops/array_ops_test.cc +++ b/tensorflow/core/ops/array_ops_test.cc @@ -1626,4 +1626,16 @@ TEST(ArrayOpsTest, QuantizedConcat_ShapeFn) { // Note that other cases of concat are covered in the Concat tests. } +TEST(StateOpsTest, _ParallelConcatStart_ShapeFn) { + ShapeInferenceTestOp op("_ParallelConcatStart"); + TensorShape shape({1, 2, 3}); + TensorShapeProto shape_proto; + shape.AsProto(&shape_proto); + TF_ASSERT_OK(NodeDefBuilder("test", "_ParallelConcatStart") + .Attr("shape", shape_proto) + .Attr("dtype", DT_FLOAT) + .Finalize(&op.node_def)); + INFER_OK(op, "", "[1,2,3]"); +} + } // end namespace tensorflow diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index 10b5df91f18..7e7d499f888 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -120,7 +120,7 @@ REGISTER_OP("DynamicStitch") TF_RETURN_IF_ERROR(c->GetAttr("N", &num_partitions)); ShapeHandle extra_shape = c->UnknownShape(); - for (int i = 0; i < num_partitions; ++i) { + for (int64 i = 0; i < num_partitions; ++i) { ShapeHandle indices_shape = c->input(i); ShapeHandle data_shape = c->input(i + num_partitions); if (!c->RankKnown(indices_shape)) { diff --git a/tensorflow/core/ops/random_ops.cc b/tensorflow/core/ops/random_ops.cc index 7b2da9d8e6d..392ac320103 100644 --- a/tensorflow/core/ops/random_ops.cc +++ b/tensorflow/core/ops/random_ops.cc @@ -23,17 +23,6 @@ using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; -namespace { - -Status RandomShape(InferenceContext* c) { - ShapeHandle out; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); - c->set_output(0, out); - return Status::OK(); -} - -} // namepsace - REGISTER_OP("RandomUniform") .Input("shape: T") .SetIsStateful() @@ -42,7 +31,7 @@ REGISTER_OP("RandomUniform") .Attr("seed2: int = 0") .Attr("dtype: {half,float,double}") .Attr("T: {int32, int64}") - .SetShapeFn(RandomShape) + .SetShapeFn(shape_inference::RandomShape) .Doc(R"doc( Outputs random values from a uniform distribution. @@ -69,7 +58,7 @@ REGISTER_OP("RandomUniformInt") .Attr("seed2: int = 0") .Attr("Tout: {int32, int64}") .Attr("T: {int32, int64}") - .SetShapeFn(RandomShape) + .SetShapeFn(shape_inference::RandomShape) .Doc(R"doc( Outputs random integers from a uniform distribution. @@ -100,7 +89,7 @@ REGISTER_OP("RandomStandardNormal") .Attr("seed2: int = 0") .Attr("dtype: {half,float,double}") .Attr("T: {int32, int64}") - .SetShapeFn(RandomShape) + .SetShapeFn(shape_inference::RandomShape) .Doc(R"doc( Outputs random values from a normal distribution. @@ -128,7 +117,7 @@ REGISTER_OP("ParameterizedTruncatedNormal") .Attr("seed2: int = 0") .Attr("dtype: {half,float,double}") .Attr("T: {int32, int64}") - .SetShapeFn(RandomShape) + .SetShapeFn(shape_inference::RandomShape) .Doc(R"doc( Outputs random values from a normal distribution. The parameters may each be a scalar which applies to the entire output, or a vector of length shape[0] which @@ -158,7 +147,7 @@ REGISTER_OP("TruncatedNormal") .Attr("seed2: int = 0") .Attr("dtype: {half,float,double}") .Attr("T: {int32, int64}") - .SetShapeFn(RandomShape) + .SetShapeFn(shape_inference::RandomShape) .Doc(R"doc( Outputs random values from a truncated normal distribution. diff --git a/tensorflow/core/ops/set_ops.cc b/tensorflow/core/ops/set_ops.cc index fad70072071..85d1335dcf9 100644 --- a/tensorflow/core/ops/set_ops.cc +++ b/tensorflow/core/ops/set_ops.cc @@ -235,7 +235,7 @@ REGISTER_OP("SparseToSparseSetOperation") DimensionHandle input1_rank_dim = c->Dim(input1_shape_shape, 0); DimensionHandle output_rank_dim; if (c->ValueKnown(input0_rank_dim)) { - const int32 input0_rank = c->Value(input0_rank_dim); + const int64 input0_rank = c->Value(input0_rank_dim); if (input0_rank < 2) { return errors::InvalidArgument("Input 0, expected rank >= 2, got ", input0_rank, "."); @@ -244,7 +244,7 @@ REGISTER_OP("SparseToSparseSetOperation") c->WithValue(input1_rank_dim, input0_rank, &input1_rank_dim)); output_rank_dim = input0_rank_dim; } else if (c->ValueKnown(input1_rank_dim)) { - const int32 input1_rank = c->Value(input1_rank_dim); + const int64 input1_rank = c->Value(input1_rank_dim); if (input1_rank < 2) { return errors::InvalidArgument("Input 1, expected rank >= 2, got ", input1_rank, "."); diff --git a/tensorflow/core/platform/cpu_info.cc b/tensorflow/core/platform/cpu_info.cc index 9edf2de64ca..906826e6f83 100644 --- a/tensorflow/core/platform/cpu_info.cc +++ b/tensorflow/core/platform/cpu_info.cc @@ -68,7 +68,7 @@ int GetXCR0EAX() { // Structure for basic CPUID info class CPUIDInfo { -public: + public: CPUIDInfo() : have_adx_(0), have_aes_(0), diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto index 6e9eff62254..63821cb55ef 100644 --- a/tensorflow/core/protobuf/rewriter_config.proto +++ b/tensorflow/core/protobuf/rewriter_config.proto @@ -10,6 +10,15 @@ message RewriterConfig { bool optimize_tensor_layout = 1; bool disable_model_pruning = 2; bool constant_folding = 3; + + enum MemOptType { + // Fully disabled + NO_MEM_OPT = 0; + // Driven by manual annotations + MANUAL = 1; + } + MemOptType memory_optimization = 4; + // If non-empty, will use this as an alternative way to specify a list of // optimizations to turn on and the order of the optimizations. repeated string optimizers = 100; diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h index 8bb4ca8ff84..8a3f6c587ed 100644 --- a/tensorflow/core/util/cuda_kernel_helper.h +++ b/tensorflow/core/util/cuda_kernel_helper.h @@ -128,6 +128,28 @@ __device__ __host__ inline T ldg(const T* address) { #endif } +template <> +__device__ __host__ inline std::complex ldg( + const std::complex* address) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350 + float2 mem = __ldg(reinterpret_cast(address)); + return std::complex(mem.x, mem.y); +#else + return *address; +#endif +} + +template <> +__device__ __host__ inline std::complex ldg( + const std::complex* address) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350 + double2 mem = __ldg(reinterpret_cast(address)); + return std::complex(mem.x, mem.y); +#else + return *address; +#endif +} + // CUDA provides atomic ops, but not for all types. We provide wrappers // for some ops and provide implementation for all reasonable types. #define CUDA_ATOMIC_WRAPPER(op, T) \ diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc index b8989b2c3ed..80a910e6890 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc @@ -249,8 +249,10 @@ BundleWriter::BundleWriter(Env* env, StringPiece prefix) random::New64())), out_(nullptr), size_(0) { - status_ = - env_->CreateDir(io::Dirname(prefix_).ToString()); // Ignores errors. + status_ = env_->CreateDir(io::Dirname(prefix_).ToString()); + if (!status_.ok() && !errors::IsAlreadyExists(status_)) { + return; + } const string filename = DataFilename(prefix_, 0, 1); std::unique_ptr wrapper; status_ = env_->NewWritableFile(tmp_data_path_, &wrapper); @@ -264,9 +266,9 @@ BundleWriter::BundleWriter(Env* env, StringPiece prefix) BundleWriter::~BundleWriter() { CHECK(out_ == nullptr); } Status BundleWriter::Add(StringPiece key, const Tensor& val) { + if (!status_.ok()) return status_; CHECK_NE(key, kHeaderEntryKey); const string key_string = key.ToString(); - if (!status_.ok()) return status_; if (entries_.find(key_string) != entries_.end()) { status_ = errors::InvalidArgument("Adding duplicate key: ", key); return status_; @@ -301,14 +303,14 @@ Status BundleWriter::AddSlice(StringPiece full_tensor_key, const TensorShape& full_tensor_shape, const TensorSlice& slice_spec, const Tensor& slice_tensor) { + if (!status_.ok()) return status_; + CHECK_NE(full_tensor_key, kHeaderEntryKey); + // If just a singleton full slice, use the regular Add() to be more efficient. if (IsFullSlice(slice_spec, full_tensor_shape)) { return Add(full_tensor_key, slice_tensor); } - CHECK_NE(full_tensor_key, kHeaderEntryKey); - if (!status_.ok()) return status_; - // Inserts/updates the full tensor's metadata entry. // // In the case of a sharded save, MergeBundles() is responsible for merging @@ -516,7 +518,8 @@ Status MergeBundles(Env* env, gtl::ArraySlice prefixes, // Merges all metadata tables. // TODO(zhifengc): KeyValue sorter if it becomes too big. MergeState merge; - env->CreateDir(io::Dirname(merged_prefix).ToString()).IgnoreError(); + Status status = env->CreateDir(io::Dirname(merged_prefix).ToString()); + if (!status.ok() && !errors::IsAlreadyExists(status)) return status; for (int i = 0; i < prefixes.size(); ++i) { TF_RETURN_IF_ERROR(MergeOneBundle(env, prefixes[i], &merge)); } @@ -534,7 +537,6 @@ Status MergeBundles(Env* env, gtl::ArraySlice prefixes, std::unique_ptr merged_metadata; TF_RETURN_IF_ERROR( env->NewWritableFile(MetaFilename(merged_prefix), &merged_metadata)); - Status status; { table::TableBuilder builder(table::Options(), merged_metadata.get()); // Header entry. diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.h b/tensorflow/core/util/tensor_bundle/tensor_bundle.h index bca3910f59c..676bfe4df69 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.h +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.h @@ -100,6 +100,10 @@ extern const int kTensorBundleVersion; extern const char* const kHeaderEntryKey; // Builds a string-string table of tensor names to BundleEntryProto (metadata). +// +// On construction, attempts to create a directory given by the dirname of +// "prefix", so "status()" must be checked before calling any member functions. +// // All threads accessing the same BundleWriter must synchronize. class BundleWriter { public: diff --git a/tensorflow/docs_src/get_started/get_started.md b/tensorflow/docs_src/get_started/get_started.md index 6ae61b43a0e..6116c7d87fb 100644 --- a/tensorflow/docs_src/get_started/get_started.md +++ b/tensorflow/docs_src/get_started/get_started.md @@ -323,6 +323,10 @@ When run, it produces W: [-0.9999969] b: [ 0.99999082] loss: 5.69997e-11 ``` +Notice that the loss is a very small number (close to zero). If you run this +program your loss will not be exactly the same, because the model is initialized +with random values. + This more complicated program can still be visualized in TensorBoard ![TensorBoard final model visualization](../images/getting_started_final.png) diff --git a/tensorflow/docs_src/programmers_guide/debugger.md b/tensorflow/docs_src/programmers_guide/debugger.md index 6f442e6e0c4..10eebf6f42a 100644 --- a/tensorflow/docs_src/programmers_guide/debugger.md +++ b/tensorflow/docs_src/programmers_guide/debugger.md @@ -130,6 +130,8 @@ Try the following commands at the `tfdbg>` prompt (referencing the code at | `lo -r hidden/Relu:0` | List the recipients of the output of the node `hidden/Relu`, recursively—i.e., the output recipient tree. | | `lt -n softmax.*` | List all dumped tensors whose names match the regular-expression pattern `softmax.*`. | | `lt -t MatMul` | List all dumped tensors whose node type is `MatMul`. | +| `ls` | List all Python source files responsible for constructing the nodes (and tensors) in the current graph. | +| `ls -n softmax.*` | List Python source files responsible for constructing the nodes whose names match the pattern `softmax.*`. | | `ps /path/to/source.py` | Print the Python source file source.py, with the lines annotated with the ops created at each of them, respectively. | | `ps -t /path/to/source.py` | Same as the command above, but perform annotation using dumped Tensors, instead of ops. | | `ps -b 30 /path/to/source.py` | Annotate source.py beginning at line 30. | diff --git a/tensorflow/docs_src/tutorials/recurrent.md b/tensorflow/docs_src/tutorials/recurrent.md index a1c0532f5a2..8cc6cf15ef8 100644 --- a/tensorflow/docs_src/tutorials/recurrent.md +++ b/tensorflow/docs_src/tutorials/recurrent.md @@ -173,15 +173,22 @@ final_state = state ## Run the Code -Start by cloning the [TensorFlow models repo](https://github.com/tensorflow/models) from GitHub. -You'll also need to download the PTB dataset, as discussed at the beginning of -this tutorial; we'll assume the dataset is located in `/tmp/simple-examples/data`. +Before running the code, download the PTB dataset, as discussed at the beginning +of this tutorial. Then, extract the PTB dataset underneath your home directory +as follows: -Run the following commands: +```bsh +tar xvfz simple-examples.tgz -C $HOME +``` +_(Note: On Windows, you may need to use +[other tools](https://wiki.haskell.org/How_to_unpack_a_tar_file_in_Windows).)_ -```bash +Now, clone the [TensorFlow models repo](https://github.com/tensorflow/models) +from GitHub. Run the following commands: + +```bsh cd models/tutorials/rnn/ptb -python ptb_word_lm.py --data_path=/tmp/simple-examples/data/ --model=small +python ptb_word_lm.py --data_path=$HOME/simple-examples/data/ --model=small ``` There are 3 supported model configurations in the tutorial code: "small", diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 6900ac9a4f4..5b50df3ed34 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -1094,8 +1094,9 @@ class BaseSession(SessionInterface): if tensors_to_delete: feeds = {} fetches = [] - for tensor_handle in tensors_to_delete: + for deleter_key, tensor_handle in enumerate(tensors_to_delete): holder, deleter = session_ops._get_handle_deleter(self.graph, + deleter_key, tensor_handle) feeds[holder] = tensor_handle fetches.append(deleter) diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc index 99c154bd997..930eb5f283f 100644 --- a/tensorflow/python/client/tf_session_helper.cc +++ b/tensorflow/python/client/tf_session_helper.cc @@ -375,6 +375,8 @@ Status GetPyArrayDescrForTensor(const TF_Tensor* tensor, PyObject* fields = PyList_New(1); PyList_SetItem(fields, 0, field); int convert_result = PyArray_DescrConverter(fields, descr); + Py_CLEAR(field); + Py_CLEAR(fields); if (convert_result != 1) { return errors::Internal("Failed to create numpy array description for ", "TF_RESOURCE-type tensor"); diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index 56dd7ceba52..0b87660e5dd 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -185,6 +185,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":base_ui", + ":cli_shared", ":command_parser", ":curses_widgets", ":debugger_cli_common", @@ -375,6 +376,7 @@ py_test( ":debug_utils", ":source_utils", "//tensorflow/python:client", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", diff --git a/tensorflow/python/debug/cli/analyzer_cli.py b/tensorflow/python/debug/cli/analyzer_cli.py index 0c8004e2545..95d3f3f2498 100644 --- a/tensorflow/python/debug/cli/analyzer_cli.py +++ b/tensorflow/python/debug/cli/analyzer_cli.py @@ -345,6 +345,25 @@ class DebugAnalyzer(object): help="Print source beginning at line number (1-based.)") self._arg_parsers["print_source"] = ap + # Parser for list_source. + ap = argparse.ArgumentParser( + description="List source files responsible for constructing nodes and " + "tensors present in the run().", + usage=argparse.SUPPRESS) + ap.add_argument( + "-p", + "--path_filter", + type=str, + default="", + help="Regular expression filter for file path.") + ap.add_argument( + "-n", + "--node_name_filter", + type=str, + default="", + help="Regular expression filter for node name.") + self._arg_parsers["list_source"] = ap + # TODO(cais): Implement list_nodes. def add_tensor_filter(self, filter_name, filter_callable): @@ -979,6 +998,15 @@ class DebugAnalyzer(object): return output + def _reconstruct_print_source_command(self, + parsed, + line_begin_decrease=0, + max_elements_per_line_increase=0): + return "ps %s %s -b %d -m %d" % ( + parsed.source_file_path, "-t" if parsed.tensors else "", + max(parsed.line_begin - line_begin_decrease, 1), + parsed.max_elements_per_line + max_elements_per_line_increase) + def print_source(self, args, screen_info=None): """Print the content of a source file.""" del screen_info # Unused. @@ -1000,12 +1028,20 @@ class DebugAnalyzer(object): labeled_source_lines = [] if parsed.line_begin > 1: - labeled_source_lines.append( - RL("(... Omitted %d source lines ...)" % (parsed.line_begin - 1), - "bold")) + omitted_info_line = RL( + "(... Omitted %d source lines ...) " % (parsed.line_begin - 1), + "bold") + omitted_info_line += RL( + "+5", + debugger_cli_common.MenuItem( + None, + self._reconstruct_print_source_command( + parsed, line_begin_decrease=5))) + labeled_source_lines.append(omitted_info_line) for i, line in enumerate(source_lines[parsed.line_begin - 1:]): - annotated_line = RL("L%d" % (i + parsed.line_begin), "yellow") + annotated_line = RL("L%d" % (i + parsed.line_begin), + cli_shared.COLOR_YELLOW) annotated_line += " " * (line_num_width - len(annotated_line)) annotated_line += line labeled_source_lines.append(annotated_line) @@ -1014,11 +1050,17 @@ class DebugAnalyzer(object): sorted_elements = sorted(source_annotation[i + parsed.line_begin]) for k, element in enumerate(sorted_elements): if k >= parsed.max_elements_per_line: - labeled_source_lines.append( - " (... Omitted %d of %d %s ...)" % ( - len(sorted_elements) - parsed.max_elements_per_line, - len(sorted_elements), - "tensor(s)" if parsed.tensors else "op(s)")) + omitted_info_line = RL(" (... Omitted %d of %d %s ...) " % ( + len(sorted_elements) - parsed.max_elements_per_line, + len(sorted_elements), + "tensor(s)" if parsed.tensors else "op(s)")) + omitted_info_line += RL( + "+5", + debugger_cli_common.MenuItem( + None, + self._reconstruct_print_source_command( + parsed, max_elements_per_line_increase=5))) + labeled_source_lines.append(omitted_info_line) break label = RL(" " * 4) @@ -1026,7 +1068,7 @@ class DebugAnalyzer(object): debug_data.get_node_name(element)): attribute = debugger_cli_common.MenuItem("", "pt %s" % element) else: - attribute = "blue" + attribute = cli_shared.COLOR_BLUE label += RL(element, attribute) labeled_source_lines.append(label) @@ -1036,6 +1078,105 @@ class DebugAnalyzer(object): _add_main_menu(output, node_name=None) return output + def _make_source_table(self, source_list, is_tf_py_library): + """Make a table summarizing the source files that create nodes and tensors. + + Args: + source_list: List of source files and related information as a list of + tuples (file_path, is_tf_library, num_nodes, num_tensors, num_dumps, + first_line). + is_tf_py_library: (`bool`) whether this table is for files that belong + to the TensorFlow Python library. + + Returns: + The table as a `debugger_cli_common.RichTextLines` object. + """ + path_head = "Source file path" + num_nodes_head = "#(nodes)" + num_tensors_head = "#(tensors)" + num_dumps_head = "#(tensor dumps)" + + if is_tf_py_library: + # Use color to mark files that are guessed to belong to TensorFlow Python + # library. + color = cli_shared.COLOR_GRAY + lines = [RL("TensorFlow Python library file(s):", color)] + else: + color = cli_shared.COLOR_WHITE + lines = [RL("File(s) outside TensorFlow Python library:", color)] + + if not source_list: + lines.append(RL("[No files.]")) + lines.append(RL()) + return debugger_cli_common.rich_text_lines_from_rich_line_list(lines) + + path_column_width = max( + max([len(item[0]) for item in source_list]), len(path_head)) + 1 + num_nodes_column_width = max( + max([len(str(item[2])) for item in source_list]), + len(num_nodes_head)) + 1 + num_tensors_column_width = max( + max([len(str(item[3])) for item in source_list]), + len(num_tensors_head)) + 1 + + head = RL(path_head + " " * (path_column_width - len(path_head)), color) + head += RL(num_nodes_head + " " * ( + num_nodes_column_width - len(num_nodes_head)), color) + head += RL(num_tensors_head + " " * ( + num_tensors_column_width - len(num_tensors_head)), color) + head += RL(num_dumps_head, color) + + lines.append(head) + + for item in source_list: + path_attributes = [debugger_cli_common.MenuItem( + None, "ps %s -b %d" % (item[0], item[5])), color] + + line = RL(item[0], path_attributes) + line += " " * (path_column_width - len(line)) + line += RL( + str(item[2]) + " " * (num_nodes_column_width - len(str(item[2]))), + color) + line += RL( + str(item[3]) + " " * (num_tensors_column_width - len(str(item[3]))), + color) + line += RL(str(item[4]), color) + lines.append(line) + lines.append(RL()) + + return debugger_cli_common.rich_text_lines_from_rich_line_list(lines) + + def list_source(self, args, screen_info=None): + """List Python source files that constructed nodes and tensors.""" + del screen_info # Unused. + + parsed = self._arg_parsers["list_source"].parse_args(args) + source_list = source_utils.list_source_files_against_dump( + self._debug_dump, + path_regex_whitelist=parsed.path_filter, + node_name_regex_whitelist=parsed.node_name_filter) + + top_lines = [ + RL("List of source files that created nodes in this run", "bold")] + if parsed.path_filter: + top_lines.append( + RL("File path regex filter: \"%s\"" % parsed.path_filter)) + if parsed.node_name_filter: + top_lines.append( + RL("Node name regex filter: \"%s\"" % parsed.node_name_filter)) + top_lines.append(RL()) + output = debugger_cli_common.rich_text_lines_from_rich_line_list(top_lines) + if not source_list: + output.append("[No source file information.]") + return output + + output.extend(self._make_source_table( + [item for item in source_list if not item[1]], False)) + output.extend(self._make_source_table( + [item for item in source_list if item[1]], True)) + _add_main_menu(output, node_name=None) + return output + def _list_inputs_or_outputs(self, recursive, node_name, @@ -1395,6 +1536,11 @@ def create_analyzer_ui(debug_dump, analyzer.print_source, analyzer.get_help("print_source"), prefix_aliases=["ps"]) + cli.register_command_handler( + "list_source", + analyzer.list_source, + analyzer.get_help("list_source"), + prefix_aliases=["ls"]) dumped_tensor_names = [] for datum in debug_dump.dumped_tensor_data: diff --git a/tensorflow/python/debug/cli/analyzer_cli_test.py b/tensorflow/python/debug/cli/analyzer_cli_test.py index bb2d72e2e4b..185d395126b 100644 --- a/tensorflow/python/debug/cli/analyzer_cli_test.py +++ b/tensorflow/python/debug/cli/analyzer_cli_test.py @@ -28,6 +28,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.debug.cli import analyzer_cli +from tensorflow.python.debug.cli import cli_shared from tensorflow.python.debug.cli import command_parser from tensorflow.python.debug.cli import debugger_cli_common from tensorflow.python.debug.lib import debug_data @@ -569,6 +570,11 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase): cls._analyzer.print_source, cls._analyzer.get_help("print_source"), prefix_aliases=["ps"]) + cls._registry.register_command_handler( + "list_source", + cls._analyzer.list_source, + cls._analyzer.get_help("list_source"), + prefix_aliases=["ls"]) @classmethod def tearDownClass(cls): @@ -906,7 +912,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase): ["ERROR: There is no node named \"bar\" in the partition graphs"], out.lines) # Check color indicating error. - self.assertEqual({0: [(0, 59, "red")]}, out.font_attr_segs) + self.assertEqual({0: [(0, 59, cli_shared.COLOR_RED)]}, out.font_attr_segs) check_main_menu(self, out, list_tensors_enabled=True) def testPrintTensor(self): @@ -1172,7 +1178,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase): out.font_attr_segs[index + 1][0][2].content) # simple_mul_add/u/Assign is not used in this run because the Variable has # already been initialized. - self.assertEqual("blue", out.font_attr_segs[index + 2][0][2]) + self.assertEqual(cli_shared.COLOR_BLUE, out.font_attr_segs[index + 2][0][2]) self.assertEqual("pt simple_mul_add/u/read", out.font_attr_segs[index + 3][0][2].content) @@ -1234,6 +1240,12 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase): screen_info={"cols": 80}) self.assertIn("Omitted 2 source lines", out.lines[0]) + self.assertTrue(out.lines[0].endswith("+5")) + expand_lines_command = out.font_attr_segs[0][-1][2].content + self.assertStartsWith(expand_lines_command, + "ps %s " % self._curr_file_path) + self.assertIn("-b 1", expand_lines_command) + self.assertIsNone(self._findSourceLine(out, 1)) self.assertIsNone(self._findSourceLine(out, 2)) self.assertIsNotNone(self._findSourceLine(out, 3)) @@ -1250,7 +1262,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase): out.font_attr_segs[index + 1][0][2].content) # simple_mul_add/u/Assign is not used in this run because the Variable has # already been initialized. - self.assertEqual("blue", out.font_attr_segs[index + 2][0][2]) + self.assertEqual(cli_shared.COLOR_BLUE, out.font_attr_segs[index + 2][0][2]) self.assertEqual("pt simple_mul_add/u/read", out.font_attr_segs[index + 3][0][2].content) @@ -1266,10 +1278,81 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase): ["L%d u = variables.Variable(u_init, name=u_name)" % self._u_line_number, " simple_mul_add/u", - " (... Omitted 2 of 3 op(s) ...)"], + " (... Omitted 2 of 3 op(s) ...) +5"], out.lines[index : index + 3]) self.assertEqual("pt simple_mul_add/u", out.font_attr_segs[index + 1][0][2].content) + more_elements_command = out.font_attr_segs[index + 2][-1][2].content + self.assertStartsWith(more_elements_command, + "ps %s " % self._curr_file_path) + self.assertIn(" -m 6", more_elements_command) + + def testListSourceWorks(self): + self._debug_dump.set_python_graph(self._sess.graph) + out = self._registry.dispatch_command("list_source", []) + + non_tf_lib_files_start = [ + i for i in xrange(len(out.lines)) + if out.lines[i].startswith("Source file path")][0] + 1 + non_tf_lib_files_end = [ + i for i in xrange(len(out.lines)) + if out.lines[i].startswith("TensorFlow Python library file(s):")][0] - 1 + non_tf_lib_files = [ + line.split(" ")[0] for line + in out.lines[non_tf_lib_files_start : non_tf_lib_files_end]] + self.assertIn(self._curr_file_path, non_tf_lib_files) + + # Check that the TF library files are marked with special color attribute. + for i in xrange(non_tf_lib_files_end + 1, len(out.lines)): + if not out.lines[i]: + continue + for attr_seg in out.font_attr_segs[i]: + self.assertTrue(cli_shared.COLOR_GRAY in attr_seg[2] or + attr_seg[2] == cli_shared.COLOR_GRAY) + + def testListSourceWithNodeNameFilterWithMatchesWorks(self): + self._debug_dump.set_python_graph(self._sess.graph) + out = self._registry.dispatch_command("list_source", ["-n", ".*/read"]) + + self.assertStartsWith(out.lines[1], "Node name regex filter: \".*/read\"") + + non_tf_lib_files_start = [ + i for i in xrange(len(out.lines)) + if out.lines[i].startswith("Source file path")][0] + 1 + non_tf_lib_files_end = [ + i for i in xrange(len(out.lines)) + if out.lines[i].startswith("TensorFlow Python library file(s):")][0] - 1 + non_tf_lib_files = [ + line.split(" ")[0] for line + in out.lines[non_tf_lib_files_start : non_tf_lib_files_end]] + self.assertIn(self._curr_file_path, non_tf_lib_files) + + # Check that the TF library files are marked with special color attribute. + for i in xrange(non_tf_lib_files_end + 1, len(out.lines)): + if not out.lines[i]: + continue + for attr_seg in out.font_attr_segs[i]: + self.assertTrue(cli_shared.COLOR_GRAY in attr_seg[2] or + attr_seg[2] == cli_shared.COLOR_GRAY) + + def testListSourceWithNodeNameFilterWithNoMatchesWorks(self): + self._debug_dump.set_python_graph(self._sess.graph) + out = self._registry.dispatch_command("list_source", ["-n", "^$"]) + + self.assertEqual([ + "List of source files that created nodes in this run", + "Node name regex filter: \"^$\"", "", + "[No source file information.]"], out.lines) + + def testListSourceWithPathAndNodeNameFiltersWorks(self): + self._debug_dump.set_python_graph(self._sess.graph) + out = self._registry.dispatch_command( + "list_source", ["-p", self._curr_file_path, "-n", ".*read"]) + + self.assertEqual([ + "List of source files that created nodes in this run", + "File path regex filter: \"%s\"" % self._curr_file_path, + "Node name regex filter: \".*read\"", ""], out.lines[:4]) class AnalyzerCLIPrintLargeTensorTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/debug/cli/cli_shared.py b/tensorflow/python/debug/cli/cli_shared.py index b1953479502..8ff09167614 100644 --- a/tensorflow/python/debug/cli/cli_shared.py +++ b/tensorflow/python/debug/cli/cli_shared.py @@ -32,6 +32,16 @@ RL = debugger_cli_common.RichLine # when printing the value of the tensor. DEFAULT_NDARRAY_DISPLAY_THRESHOLD = 2000 +COLOR_BLACK = "black" +COLOR_BLUE = "blue" +COLOR_CYAN = "cyan" +COLOR_GRAY = "gray" +COLOR_GREEN = "green" +COLOR_MAGENTA = "magenta" +COLOR_RED = "red" +COLOR_WHITE = "white" +COLOR_YELLOW = "yellow" + def bytes_to_readable_str(num_bytes, include_b=False): """Generate a human-readable string representing number of bytes. @@ -154,7 +164,7 @@ def error(msg): """ return debugger_cli_common.rich_text_lines_from_rich_line_list([ - RL("ERROR: " + msg, "red")]) + RL("ERROR: " + msg, COLOR_RED)]) def _get_fetch_name(fetch): diff --git a/tensorflow/python/debug/cli/curses_ui.py b/tensorflow/python/debug/cli/curses_ui.py index d8d3bce3de7..b7549b406b6 100644 --- a/tensorflow/python/debug/cli/curses_ui.py +++ b/tensorflow/python/debug/cli/curses_ui.py @@ -20,6 +20,7 @@ from __future__ import print_function import collections import curses from curses import textpad +import os import signal import sys import threading @@ -27,6 +28,7 @@ import threading from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.debug.cli import base_ui +from tensorflow.python.debug.cli import cli_shared from tensorflow.python.debug.cli import command_parser from tensorflow.python.debug.cli import curses_widgets from tensorflow.python.debug.cli import debugger_cli_common @@ -42,6 +44,9 @@ _SCROLL_HOME = "home" _SCROLL_END = "end" _SCROLL_TO_LINE_INDEX = "scroll_to_line_index" +_COLOR_READY_COLORTERMS = ["gnome-terminal", "xfce4-terminal"] +_COLOR_ENABLED_TERM = "xterm-256color" + def _get_command_from_line_attr_segs(mouse_x, attr_segs): """Attempt to extract command from the attribute segments of a line. @@ -77,7 +82,7 @@ class ScrollBar(object): event in the screen region it occupies. """ - BASE_ATTR = "black_on_white" + BASE_ATTR = cli_shared.COLOR_BLACK + "_on_" + cli_shared.COLOR_WHITE def __init__(self, min_x, @@ -225,27 +230,36 @@ class CursesUI(base_ui.BaseUI): } _FOREGROUND_COLORS = { - "white": curses.COLOR_WHITE, - "red": curses.COLOR_RED, - "green": curses.COLOR_GREEN, - "yellow": curses.COLOR_YELLOW, - "blue": curses.COLOR_BLUE, - "cyan": curses.COLOR_CYAN, - "magenta": curses.COLOR_MAGENTA, - "black": curses.COLOR_BLACK, + cli_shared.COLOR_WHITE: curses.COLOR_WHITE, + cli_shared.COLOR_RED: curses.COLOR_RED, + cli_shared.COLOR_GREEN: curses.COLOR_GREEN, + cli_shared.COLOR_YELLOW: curses.COLOR_YELLOW, + cli_shared.COLOR_BLUE: curses.COLOR_BLUE, + cli_shared.COLOR_CYAN: curses.COLOR_CYAN, + cli_shared.COLOR_MAGENTA: curses.COLOR_MAGENTA, + cli_shared.COLOR_BLACK: curses.COLOR_BLACK, } _BACKGROUND_COLORS = { - "white": curses.COLOR_WHITE, - "black": curses.COLOR_BLACK, + "transparent": -1, + cli_shared.COLOR_WHITE: curses.COLOR_WHITE, + cli_shared.COLOR_BLACK: curses.COLOR_BLACK, } # Font attribute for search and highlighting. - _SEARCH_HIGHLIGHT_FONT_ATTR = "black_on_white" - _ARRAY_INDICES_COLOR_PAIR = "black_on_white" - _ERROR_TOAST_COLOR_PAIR = "red_on_white" - _INFO_TOAST_COLOR_PAIR = "blue_on_white" - _STATUS_BAR_COLOR_PAIR = "black_on_white" - _UI_WAIT_COLOR_PAIR = "magenta_on_white" + _SEARCH_HIGHLIGHT_FONT_ATTR = ( + cli_shared.COLOR_BLACK + "_on_" + cli_shared.COLOR_WHITE) + _ARRAY_INDICES_COLOR_PAIR = ( + cli_shared.COLOR_BLACK + "_on_" + cli_shared.COLOR_WHITE) + _ERROR_TOAST_COLOR_PAIR = ( + cli_shared.COLOR_RED + "_on_" + cli_shared.COLOR_WHITE) + _INFO_TOAST_COLOR_PAIR = ( + cli_shared.COLOR_BLUE + "_on_" + cli_shared.COLOR_WHITE) + _STATUS_BAR_COLOR_PAIR = ( + cli_shared.COLOR_BLACK + "_on_" + cli_shared.COLOR_WHITE) + _UI_WAIT_COLOR_PAIR = ( + cli_shared.COLOR_MAGENTA + "_on_" + cli_shared.COLOR_WHITE) + _NAVIGATION_WARNING_COLOR_PAIR = ( + cli_shared.COLOR_RED + "_on_" + cli_shared.COLOR_WHITE) _UI_WAIT_MESSAGE = "Processing..." @@ -370,29 +384,43 @@ class CursesUI(base_ui.BaseUI): Creates curses stdscr and initialize the color pairs for display. """ - + # If the terminal type is color-ready, enable it. + if os.getenv("COLORTERM") in _COLOR_READY_COLORTERMS: + os.environ["TERM"] = _COLOR_ENABLED_TERM self._stdscr = curses.initscr() self._command_window = None + self._screen_color_init() - # Prepare color pairs. + def _screen_color_init(self): + """Initialization of screen colors.""" curses.start_color() - + curses.use_default_colors() self._color_pairs = {} color_index = 0 + # Prepare color pairs. for fg_color in self._FOREGROUND_COLORS: for bg_color in self._BACKGROUND_COLORS: - color_index += 1 curses.init_pair(color_index, self._FOREGROUND_COLORS[fg_color], self._BACKGROUND_COLORS[bg_color]) color_name = fg_color - if bg_color != "black": + if bg_color != "transparent": color_name += "_on_" + bg_color self._color_pairs[color_name] = curses.color_pair(color_index) + # Try getting color(s) available only under 256-color support. + try: + color_index += 1 + curses.init_pair(color_index, 245, -1) + self._color_pairs[cli_shared.COLOR_GRAY] = curses.color_pair(color_index) + except curses.error: + # Use fall-back color(s): + self._color_pairs[cli_shared.COLOR_GRAY] = ( + self._color_pairs[cli_shared.COLOR_GREEN]) + # A_BOLD or A_BLINK is not really a "color". But place it here for # convenience. self._color_pairs["bold"] = curses.A_BOLD @@ -400,7 +428,7 @@ class CursesUI(base_ui.BaseUI): self._color_pairs["underline"] = curses.A_UNDERLINE # Default color pair to use when a specified color pair does not exist. - self._default_color_pair = self._color_pairs["white"] + self._default_color_pair = self._color_pairs[cli_shared.COLOR_WHITE] def _screen_launch(self, enable_mouse_on_start): """Launch the curses screen.""" @@ -588,7 +616,7 @@ class CursesUI(base_ui.BaseUI): scroll_position = item.scroll_position else: self._toast("At the LATEST in navigation history!", - color="red_on_white") + color=self._NAVIGATION_WARNING_COLOR_PAIR) return else: if self._nav_history.can_go_back(): @@ -596,7 +624,7 @@ class CursesUI(base_ui.BaseUI): scroll_position = item.scroll_position else: self._toast("At the OLDEST in navigation history!", - color="red_on_white") + color=self._NAVIGATION_WARNING_COLOR_PAIR) return self._display_output(item.screen_output) @@ -959,7 +987,7 @@ class CursesUI(base_ui.BaseUI): self._curr_wrapped_output.lines.append("Output cut off at %d lines!" % self.max_output_lines) self._curr_wrapped_output.font_attr_segs[self.max_output_lines] = [ - (0, len(output.lines[-1]), "magenta") + (0, len(output.lines[-1]), cli_shared.COLOR_MAGENTA) ] self._display_nav_bar() @@ -1518,7 +1546,9 @@ class CursesUI(base_ui.BaseUI): pad, _, _ = self._display_lines( debugger_cli_common.RichTextLines( - message, font_attr_segs={0: [(0, len(message), color or "white")]}), + message, + font_attr_segs={ + 0: [(0, len(message), color or cli_shared.COLOR_WHITE)]}), 0) right_end = min(len(message), self._max_x - 2) diff --git a/tensorflow/python/debug/cli/stepper_cli.py b/tensorflow/python/debug/cli/stepper_cli.py index aee08498321..94eb2754da2 100644 --- a/tensorflow/python/debug/cli/stepper_cli.py +++ b/tensorflow/python/debug/cli/stepper_cli.py @@ -68,19 +68,19 @@ class NodeStepperCLI(object): _UPDATED_ATTRIBUTE = "bold" _STATE_COLORS = { - STATE_CONT: "green", - STATE_DIRTY_VARIABLE: "magenta", - STATE_DUMPED_INTERMEDIATE: "blue", - STATE_OVERRIDDEN: "yellow", - STATE_IS_PLACEHOLDER: "cyan", - STATE_UNFEEDABLE: "red", + STATE_CONT: cli_shared.COLOR_GREEN, + STATE_DIRTY_VARIABLE: cli_shared.COLOR_MAGENTA, + STATE_DUMPED_INTERMEDIATE: cli_shared.COLOR_BLUE, + STATE_OVERRIDDEN: cli_shared.COLOR_YELLOW, + STATE_IS_PLACEHOLDER: cli_shared.COLOR_CYAN, + STATE_UNFEEDABLE: cli_shared.COLOR_RED, } _FEED_COLORS = { - stepper.NodeStepper.FEED_TYPE_CLIENT: "white", - stepper.NodeStepper.FEED_TYPE_HANDLE: "green", - stepper.NodeStepper.FEED_TYPE_OVERRIDE: "yellow", - stepper.NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE: "blue", + stepper.NodeStepper.FEED_TYPE_CLIENT: cli_shared.COLOR_WHITE, + stepper.NodeStepper.FEED_TYPE_HANDLE: cli_shared.COLOR_GREEN, + stepper.NodeStepper.FEED_TYPE_OVERRIDE: cli_shared.COLOR_YELLOW, + stepper.NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE: cli_shared.COLOR_BLUE, } def __init__(self, node_stepper): diff --git a/tensorflow/python/debug/lib/source_utils.py b/tensorflow/python/debug/lib/source_utils.py index cc949932cb1..b8a5daf860b 100644 --- a/tensorflow/python/debug/lib/source_utils.py +++ b/tensorflow/python/debug/lib/source_utils.py @@ -18,13 +18,47 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import os +import re + +_TENSORFLOW_BASEDIR = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname( + os.path.normpath(os.path.abspath(__file__)))))) def _convert_watch_key_to_tensor_name(watch_key): return watch_key[:watch_key.rfind(":")] +def _guess_is_tensorflow_py_library(py_file_path): + """Guess whether a Python source file is a part of the tensorflow library. + + Special cases: + 1) Returns False for unit-test files in the library (*_test.py), + 2) Returns False for files under python/debug/examples. + + Args: + py_file_path: full path of the Python source file in question. + + Returns: + (`bool`) Whether the file is a part of the tensorflow library. + + Raises: + ValueError: if py_file_path does not end with ".py". + """ + + if not py_file_path.endswith(".py"): + raise ValueError( + "Input file path (%s) is not a Python source file." % py_file_path) + py_file_path = os.path.normpath(os.path.abspath(py_file_path)) + + return (py_file_path.startswith(_TENSORFLOW_BASEDIR) and + not py_file_path.endswith("_test.py") and + not os.path.dirname(py_file_path).endswith( + os.path.normpath("python/debug/examples"))) + + def annotate_source(dump, source_file_path, do_dumped_tensors=False, @@ -61,21 +95,16 @@ def annotate_source(dump, raise ValueError("Cannot perform source annotation due to a lack of set " "Python graph in the dump object") - source_file_path = os.path.normpath(source_file_path) + source_file_path = os.path.normpath(os.path.abspath(source_file_path)) line_to_op_names = {} for op in py_graph.get_operations(): - try: - traceback = dump.node_traceback(op.name) - except KeyError: - pass - - for file_path, line_number, _, _ in reversed(traceback): + for file_path, line_number, _, _ in reversed(dump.node_traceback(op.name)): if (min_line is not None and line_number < min_line or max_line is not None and line_number >= max_line): continue - if os.path.normpath(file_path) != source_file_path: + if os.path.normpath(os.path.abspath(file_path)) != source_file_path: continue if do_dumped_tensors: @@ -95,3 +124,103 @@ def annotate_source(dump, break return line_to_op_names + + +def list_source_files_against_dump(dump, + path_regex_whitelist=None, + node_name_regex_whitelist=None): + """Generate a list of source files with information regarding ops and tensors. + + Args: + dump: (`DebugDumpDir`) A `DebugDumpDir` object of which the Python graph + has been loaded. + path_regex_whitelist: A regular-expression filter for source file path. + node_name_regex_whitelist: A regular-expression filter for node names. + + Returns: + A list of tuples regarding the Python source files involved in constructing + the ops and tensors contained in `dump`. Each tuple is: + (source_file_path, is_tf_library, num_nodes, num_tensors, num_dumps, + first_line) + + is_tf_library: (`bool`) A guess of whether the file belongs to the + TensorFlow Python library. + num_nodes: How many nodes were created by lines of this source file. + These include nodes with dumps and those without. + num_tensors: How many Tensors were created by lines of this source file. + These include Tensors with dumps and those without. + num_dumps: How many debug Tensor dumps were from nodes (and Tensors) + that were created by this source file. + first_line: The first line number (1-based) that created any nodes or + Tensors in this source file. + + The list is sorted by ascending order of source_file_path. + + Raises: + ValueError: If the dump object does not have a Python graph set. + """ + + py_graph = dump.python_graph + if not py_graph: + raise ValueError("Cannot generate source list due to a lack of set " + "Python graph in the dump object") + + path_to_node_names = collections.defaultdict(set) + path_to_tensor_names = collections.defaultdict(set) + path_to_first_line = {} + tensor_name_to_num_dumps = {} + + path_regex = (re.compile(path_regex_whitelist) + if path_regex_whitelist else None) + node_name_regex = (re.compile(node_name_regex_whitelist) + if node_name_regex_whitelist else None) + + to_skip_file_paths = set() + for op in py_graph.get_operations(): + if node_name_regex and not node_name_regex.match(op.name): + continue + + for file_path, line_number, _, _ in dump.node_traceback(op.name): + file_path = os.path.normpath(os.path.abspath(file_path)) + if (file_path in to_skip_file_paths or + path_regex and not path_regex.match(file_path) or + not os.path.isfile(file_path)): + to_skip_file_paths.add(file_path) + continue + + path_to_node_names[file_path].add(op.name) + if file_path in path_to_first_line: + if path_to_first_line[file_path] > line_number: + path_to_first_line[file_path] = line_number + else: + path_to_first_line[file_path] = line_number + + for output_tensor in op.outputs: + tensor_name = output_tensor.name + path_to_tensor_names[file_path].add(tensor_name) + + watch_keys = dump.debug_watch_keys(op.name) + for watch_key in watch_keys: + node_name, output_slot, debug_op = watch_key.split(":") + tensor_name = "%s:%s" % (node_name, output_slot) + if tensor_name not in tensor_name_to_num_dumps: + tensor_name_to_num_dumps[tensor_name] = len( + dump.get_tensors(node_name, int(output_slot), debug_op)) + + path_to_num_dumps = {} + for path in path_to_tensor_names: + path_to_num_dumps[path] = sum( + tensor_name_to_num_dumps.get(tensor_name, 0) + for tensor_name in path_to_tensor_names[path]) + + output = [] + for file_path in path_to_node_names: + output.append(( + file_path, + _guess_is_tensorflow_py_library(file_path), + len(path_to_node_names.get(file_path, {})), + len(path_to_tensor_names.get(file_path, {})), + path_to_num_dumps.get(file_path, 0), + path_to_first_line[file_path])) + + return sorted(output, key=lambda x: x[0]) diff --git a/tensorflow/python/debug/lib/source_utils_test.py b/tensorflow/python/debug/lib/source_utils_test.py index 5d28bff2072..138c75de31c 100644 --- a/tensorflow/python/debug/lib/source_utils_test.py +++ b/tensorflow/python/debug/lib/source_utils_test.py @@ -33,6 +33,7 @@ from tensorflow.python.debug.lib import source_utils from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -42,6 +43,37 @@ def line_number_above(): return inspect.stack()[1][2] - 1 +class GuessIsTensorFlowLibraryTest(test_util.TensorFlowTestCase): + + def setUp(self): + self.curr_file_path = os.path.normpath(os.path.abspath(__file__)) + + def tearDown(self): + ops.reset_default_graph() + + def testGuessedBaseDirIsProbablyCorrect(self): + self.assertEqual( + "tensorflow", os.path.basename(source_utils._TENSORFLOW_BASEDIR)) + + def testUnitTestFileReturnsFalse(self): + self.assertFalse(source_utils._guess_is_tensorflow_py_library( + self.curr_file_path)) + + def _disabledtestSourceUtilModuleReturnsTrue(self): + self.assertTrue(source_utils._guess_is_tensorflow_py_library( + source_utils.__file__)) + + def testFileInPythonKernelsPathReturnsTrue(self): + x = constant_op.constant(42.0, name="x") + self.assertTrue(source_utils._guess_is_tensorflow_py_library( + x.op.traceback[-1][0])) + + def testNonPythonFileRaisesException(self): + with self.assertRaisesRegexp(ValueError, r"is not a Python source file"): + source_utils._guess_is_tensorflow_py_library( + os.path.join(os.path.dirname(self.curr_file_path), "foo.cc")) + + class SourceHelperTest(test_util.TensorFlowTestCase): def createAndRunGraphHelper(self): @@ -199,5 +231,131 @@ class SourceHelperTest(test_util.TensorFlowTestCase): os.remove(unrelated_source_path) +class ListSourceAgainstDumpTest(test_util.TensorFlowTestCase): + + def createAndRunGraphWithWhileLoop(self): + """Create and run a TensorFlow Graph with a while loop to generate dumps.""" + + self.dump_root = self.get_temp_dir() + self.curr_file_path = os.path.abspath( + inspect.getfile(inspect.currentframe())) + + # Run a simple TF graph to generate some debug dumps that can be used in + # source annotation. + with session.Session() as sess: + loop_body = lambda i: math_ops.add(i, 2) + self.traceback_first_line = line_number_above() + + loop_cond = lambda i: math_ops.less(i, 16) + + i = constant_op.constant(10, name="i") + loop = control_flow_ops.while_loop(loop_cond, loop_body, [i]) + + run_options = config_pb2.RunOptions(output_partition_graphs=True) + debug_utils.watch_graph( + run_options, sess.graph, debug_urls=["file://%s" % self.dump_root]) + run_metadata = config_pb2.RunMetadata() + sess.run(loop, options=run_options, run_metadata=run_metadata) + + self.dump = debug_data.DebugDumpDir( + self.dump_root, partition_graphs=run_metadata.partition_graphs) + self.dump.set_python_graph(sess.graph) + + def setUp(self): + self.createAndRunGraphWithWhileLoop() + + def tearDown(self): + if os.path.isdir(self.dump_root): + shutil.rmtree(self.dump_root) + ops.reset_default_graph() + + def testGenerateSourceList(self): + source_list = source_utils.list_source_files_against_dump(self.dump) + + # Assert that the file paths are sorted and unique. + file_paths = [item[0] for item in source_list] + self.assertEqual(sorted(file_paths), file_paths) + self.assertEqual(len(set(file_paths)), len(file_paths)) + + # Assert that each item of source_list has length 6. + for item in source_list: + self.assertTrue(isinstance(item, tuple)) + self.assertEqual(6, len(item)) + + # The while loop body should have executed 3 times. The following table + # lists the tensors and how many times each of them is dumped. + # Tensor name # of times dumped: + # i:0 1 + # while/Enter:0 1 + # while/Merge:0 4 + # while/Merge:1 4 + # while/Less/y:0 4 + # while/Less:0 4 + # while/LoopCond:0 4 + # while/Switch:0 1 + # while/Swtich:1 3 + # while/Identity:0 3 + # while/Add/y:0 3 + # while/Add:0 3 + # while/NextIteration:0 3 + # while/Exit:0 1 + # ---------------------------- + # (Total) 39 + # + # The total number of nodes is 12. + # The total number of tensors is 14 (2 of the nodes have 2 outputs: + # while/Merge, while/Switch). + + _, is_tf_py_library, num_nodes, num_tensors, num_dumps, first_line = ( + source_list[file_paths.index(self.curr_file_path)]) + self.assertFalse(is_tf_py_library) + self.assertEqual(12, num_nodes) + self.assertEqual(14, num_tensors) + self.assertEqual(39, num_dumps) + self.assertEqual(self.traceback_first_line, first_line) + + def testGenerateSourceListWithNodeNameFilter(self): + source_list = source_utils.list_source_files_against_dump( + self.dump, node_name_regex_whitelist=r"while/Add.*") + + # Assert that the file paths are sorted. + file_paths = [item[0] for item in source_list] + self.assertEqual(sorted(file_paths), file_paths) + self.assertEqual(len(set(file_paths)), len(file_paths)) + + # Assert that each item of source_list has length 4. + for item in source_list: + self.assertTrue(isinstance(item, tuple)) + self.assertEqual(6, len(item)) + + # Due to the node-name filtering the result should only contain 2 nodes + # and 2 tensors. The total number of dumped tensors should be 6: + # while/Add/y:0 3 + # while/Add:0 3 + _, is_tf_py_library, num_nodes, num_tensors, num_dumps, _ = ( + source_list[file_paths.index(self.curr_file_path)]) + self.assertFalse(is_tf_py_library) + self.assertEqual(2, num_nodes) + self.assertEqual(2, num_tensors) + self.assertEqual(6, num_dumps) + + def testGenerateSourceListWithPathRegexFilter(self): + curr_file_basename = os.path.basename(self.curr_file_path) + source_list = source_utils.list_source_files_against_dump( + self.dump, + path_regex_whitelist=( + ".*" + curr_file_basename.replace(".", "\\.") + "$")) + + self.assertEqual(1, len(source_list)) + (file_path, is_tf_py_library, num_nodes, num_tensors, num_dumps, + first_line) = source_list[0] + self.assertEqual(self.curr_file_path, file_path) + self.assertFalse(is_tf_py_library) + self.assertEqual(12, num_nodes) + self.assertEqual(14, num_tensors) + self.assertEqual(39, num_dumps) + self.assertEqual(self.traceback_first_line, first_line) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index 616b7ae49b1..f1471a515f0 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -108,18 +108,19 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - ":checkpoint_utils", ":export", ":model_fn", ":run_config", "//tensorflow/core:protos_all_py", + "//tensorflow/python:client", "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:metrics", "//tensorflow/python:platform", + "//tensorflow/python:random_seed", "//tensorflow/python:summary", "//tensorflow/python:training", + "//tensorflow/python:util", "//tensorflow/python/saved_model:builder", "//tensorflow/python/saved_model:tag_constants", ], @@ -131,20 +132,31 @@ py_test( srcs_version = "PY2AND3", deps = [ ":estimator", + ":export", ":model_fn", ":numpy_io", ":run_config", + "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", + "//tensorflow/python:data_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:init_ops", "//tensorflow/python:layers", + "//tensorflow/python:lib", + "//tensorflow/python:metrics", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:platform", "//tensorflow/python:saver_test_utils", "//tensorflow/python:session", "//tensorflow/python:state_ops", "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python:variables", "//tensorflow/python/ops/losses", "//tensorflow/python/saved_model:loader", + "//tensorflow/python/saved_model:tag_constants", ], ) diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 36918af5529..80c5bbf6848 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -141,6 +141,11 @@ class Estimator(object): logging.info('Using config: %s', str(vars(self._config))) + if self._config.session_config is None: + self._session_config = config_pb2.ConfigProto(allow_soft_placement=True) + else: + self._session_config = self._config.session_config + self._device_fn = _get_replica_device_setter(self._config) if model_fn is None: @@ -317,7 +322,7 @@ class Estimator(object): session_creator=training.ChiefSessionCreator( checkpoint_filename_with_path=checkpoint_path, scaffold=estimator_spec.scaffold, - config=config_pb2.ConfigProto(allow_soft_placement=True)), + config=self._session_config), hooks=hooks) as mon_sess: while not mon_sess.should_stop(): preds_evaluated = mon_sess.run(predictions) @@ -552,7 +557,8 @@ class Estimator(object): training.Saver( sharded=True, max_to_keep=self._config.keep_checkpoint_max, - defer_build=True)) + defer_build=True, + save_relative_paths=True)) chief_hooks = [] if (self._config.save_checkpoints_secs or @@ -579,7 +585,7 @@ class Estimator(object): chief_only_hooks=chief_hooks + estimator_spec.training_chief_hooks, save_checkpoint_secs=0, # Saving is handled by a hook. save_summaries_steps=self._config.save_summary_steps, - config=config_pb2.ConfigProto(allow_soft_placement=True)) as mon_sess: + config=self._session_config) as mon_sess: loss = None while not mon_sess.should_stop(): _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss]) @@ -634,7 +640,7 @@ class Estimator(object): eval_ops=update_op, final_ops=eval_dict, hooks=hooks, - config=config_pb2.ConfigProto(allow_soft_placement=True)) + config=self._session_config) _write_dict_to_summary( output_dir=eval_dir, @@ -643,12 +649,6 @@ class Estimator(object): return eval_results - def _verify_default_metric_key(self, metric_key, eval_dict): - if metric_key in six.iterkeys(eval_dict): - raise ValueError( - 'Metric with name `%s` is not allowed, because Estimator ' - 'already defines a default metric with the same name.' % metric_key) - def _check_hooks_type(hooks): """Returns hooks if all are SessionRunHook, raises TypeError otherwise.""" diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index a1659156a62..84813073d35 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -23,6 +23,8 @@ import tempfile import numpy as np +from google.protobuf import text_format + from tensorflow.python.client import session from tensorflow.python.estimator import estimator from tensorflow.python.estimator import model_fn as model_fn_lib @@ -34,6 +36,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.layers import layers +from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops @@ -48,6 +51,7 @@ from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import tag_constants +from tensorflow.python.training import checkpoint_state_pb2 from tensorflow.python.training import saver from tensorflow.python.training import saver_test_utils from tensorflow.python.training import session_run_hook @@ -236,6 +240,40 @@ class EstimatorTrainTest(test.TestCase): self.assertEqual( 5, estimator._load_global_step_from_checkpoint_dir(est.model_dir)) + def test_checkpoint_contains_relative_paths(self): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator( + model_dir=tmpdir, + model_fn=model_fn_global_step_incrementer) + est.train(dummy_input_fn, steps=5) + + checkpoint_file_content = file_io.read_file_to_string( + os.path.join(tmpdir, 'checkpoint')) + ckpt = checkpoint_state_pb2.CheckpointState() + text_format.Merge(checkpoint_file_content, ckpt) + self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5') + self.assertAllEqual( + ['model.ckpt-1', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths) + + def test_train_save_copy_reload(self): + tmpdir = tempfile.mkdtemp() + model_dir1 = os.path.join(tmpdir, 'model_dir1') + est1 = estimator.Estimator( + model_dir=model_dir1, + model_fn=model_fn_global_step_incrementer) + est1.train(dummy_input_fn, steps=5) + + model_dir2 = os.path.join(tmpdir, 'model_dir2') + os.renames(model_dir1, model_dir2) + est2 = estimator.Estimator( + model_dir=model_dir2, + model_fn=model_fn_global_step_incrementer) + self.assertEqual( + 5, estimator._load_global_step_from_checkpoint_dir(est2.model_dir)) + est2.train(dummy_input_fn, steps=5) + self.assertEqual( + 10, estimator._load_global_step_from_checkpoint_dir(est2.model_dir)) + def test_steps0_raises_error(self): est = estimator.Estimator( model_fn=_model_fn_with_eval_metric_ops) diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py index c6e6c609917..79b55c68532 100644 --- a/tensorflow/python/estimator/run_config.py +++ b/tensorflow/python/estimator/run_config.py @@ -72,6 +72,10 @@ class RunConfig(object): def save_checkpoints_secs(self): return 600 + @property + def session_config(self): + return None + @property def save_checkpoints_steps(self): return None diff --git a/tensorflow/python/kernel_tests/gather_op_test.py b/tensorflow/python/kernel_tests/gather_op_test.py index dac8d58b356..1f161e59cd0 100644 --- a/tensorflow/python/kernel_tests/gather_op_test.py +++ b/tensorflow/python/kernel_tests/gather_op_test.py @@ -31,61 +31,80 @@ from tensorflow.python.platform import test class GatherTest(test.TestCase): use_gpu = False + def _buildParams(self, data, dtype): + data = data.astype(dtype.as_numpy_dtype) + # For complex types, add an index-dependent imaginary component so we can + # tell we got the right value. + if dtype.is_complex: + return data + 10j * data + return data + def testScalar1D(self): with self.test_session(use_gpu=self.use_gpu): - params = constant_op.constant([0, 1, 2, 3, 7, 5]) - indices = constant_op.constant(4) - gather_t = array_ops.gather(params, indices) - gather_val = gather_t.eval() - self.assertAllEqual(7, gather_val) - self.assertEqual([], gather_t.get_shape()) + data = np.array([0, 1, 2, 3, 7, 5]) + for dtype in (dtypes.float32, dtypes.complex64, dtypes.complex128): + params_np = self._buildParams(data, dtype) + params = constant_op.constant(params_np) + indices = constant_op.constant(4) + gather_t = array_ops.gather(params, indices) + gather_val = gather_t.eval() + self.assertAllEqual(params_np[4], gather_val) + self.assertEqual([], gather_t.get_shape()) def testScalar2D(self): with self.test_session(use_gpu=self.use_gpu): - params = constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8], - [9, 10, 11], [12, 13, 14]]) - indices = constant_op.constant(2) - gather_t = array_ops.gather(params, indices) - gather_val = gather_t.eval() - self.assertAllEqual([6, 7, 8], gather_val) - self.assertEqual([3], gather_t.get_shape()) + data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], + [9, 10, 11], [12, 13, 14]]) + for dtype in (dtypes.float32, dtypes.complex64, dtypes.complex128): + params_np = self._buildParams(data, dtype) + params = constant_op.constant(params_np) + indices = constant_op.constant(2) + gather_t = array_ops.gather(params, indices) + gather_val = gather_t.eval() + self.assertAllEqual(params_np[2], gather_val) + self.assertEqual([3], gather_t.get_shape()) def testSimpleTwoD32(self): with self.test_session(use_gpu=self.use_gpu): - params = constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8], - [9, 10, 11], [12, 13, 14]]) - indices = constant_op.constant([0, 4, 0, 2]) - gather_t = array_ops.gather(params, indices) - gather_val = gather_t.eval() - self.assertAllEqual([[0, 1, 2], [12, 13, 14], [0, 1, 2], [6, 7, 8]], - gather_val) - self.assertEqual([4, 3], gather_t.get_shape()) + data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], + [9, 10, 11], [12, 13, 14]]) + for dtype in (dtypes.float32, dtypes.complex64, dtypes.complex128): + params_np = self._buildParams(data, dtype) + params = constant_op.constant(params_np) + indices = constant_op.constant([0, 4, 0, 2]) + gather_t = array_ops.gather(params, indices) + gather_val = gather_t.eval() + self.assertAllEqual(params_np[[0, 4, 0, 2]], gather_val) + self.assertEqual([4, 3], gather_t.get_shape()) def testHigherRank(self): np.random.seed(1) # We check that scalar and empty shapes work as well for shape in (7, 0), (4, 3, 2): for indices_shape in (), (0,), (3, 0), (3, 5): - params = np.random.randn(*shape) - indices = np.random.randint(shape[0], size=indices_shape) - with self.test_session(use_gpu=self.use_gpu): - tf_params = constant_op.constant(params) - tf_indices = constant_op.constant(indices) - gather = array_ops.gather(tf_params, tf_indices) - self.assertAllEqual(params[indices], gather.eval()) - self.assertEqual(indices.shape + params.shape[1:], gather.get_shape()) - # Test gradients - gather_grad = np.random.randn(*gather.get_shape().as_list()) - params_grad, indices_grad = gradients_impl.gradients( - gather, [tf_params, tf_indices], gather_grad) - self.assertEqual(indices_grad, None) - self.assertEqual(type(params_grad), ops.IndexedSlices) - params_grad = ops.convert_to_tensor(params_grad) - correct_params_grad = np.zeros(shape) - for i, g in zip(indices.flat, - gather_grad.reshape((indices.size,) + shape[1:])): - correct_params_grad[i] += g - self.assertAllClose(correct_params_grad, params_grad.eval()) + for dtype in (dtypes.float32, dtypes.complex64, dtypes.complex128): + params = self._buildParams(np.random.randn(*shape), dtype) + indices = np.random.randint(shape[0], size=indices_shape) + with self.test_session(use_gpu=self.use_gpu): + tf_params = constant_op.constant(params) + tf_indices = constant_op.constant(indices) + gather = array_ops.gather(tf_params, tf_indices) + self.assertAllEqual(params[indices], gather.eval()) + self.assertEqual(indices.shape + params.shape[1:], + gather.get_shape()) + # Test gradients + gather_grad = np.random.randn(*gather.get_shape().as_list()).astype( + dtype.as_numpy_dtype) + params_grad, indices_grad = gradients_impl.gradients( + gather, [tf_params, tf_indices], gather_grad) + self.assertEqual(indices_grad, None) + self.assertEqual(type(params_grad), ops.IndexedSlices) + params_grad = ops.convert_to_tensor(params_grad) + correct_params_grad = np.zeros(shape).astype(dtype.as_numpy_dtype) + for i, g in zip(indices.flat, + gather_grad.reshape((indices.size,) + shape[1:])): + correct_params_grad[i] += g + self.assertAllClose(correct_params_grad, params_grad.eval()) def testUnknownIndices(self): params = constant_op.constant([[0, 1, 2]]) @@ -103,7 +122,7 @@ class GatherTest(test.TestCase): def testEmptySlices(self): with self.test_session(use_gpu=self.use_gpu): - for dtype in np.float32, np.float64: + for dtype in np.float32, np.float64, np.complex64, np.complex128: for itype in np.int32, np.int64: params = np.zeros((7, 0), dtype=dtype) indices = np.array([3, 4], dtype=itype) diff --git a/tensorflow/python/kernel_tests/linalg_ops_test.py b/tensorflow/python/kernel_tests/linalg_ops_test.py index ff299e65116..153d4ab6623 100644 --- a/tensorflow/python/kernel_tests/linalg_ops_test.py +++ b/tensorflow/python/kernel_tests/linalg_ops_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for tensorflow.python.ops.special_math_ops.""" +"""Tests for tensorflow.python.ops.linalg_ops.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py index 86593828345..c998f57da7e 100644 --- a/tensorflow/python/layers/normalization.py +++ b/tensorflow/python/layers/normalization.py @@ -69,6 +69,19 @@ class BatchNormalization(base._Layer): # pylint: disable=protected-access trainable: Boolean, if `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). name: A string, the name of the layer. + renorm: Whether to use Batch Renormalization + (https://arxiv.org/abs/1702.03275). This adds extra variables during + training. The inference is the same for either value of this parameter. + renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to + scalar `Tensors` used to clip the renorm correction. The correction + `(r, d)` is used as `corrected_value = normalized_value * r + d`, with + `r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin, + dmax are set to inf, 0, inf, respectively. + renorm_momentum: Momentum used to update the moving means and standard + deviations with renorm. Unlike `momentum`, this affects training + and should be neither too small (which would add noise) nor too large + (which would give stale estimates). Note that `momentum` is still applied + to get the means and variances for inference. """ def __init__(self, @@ -85,6 +98,9 @@ class BatchNormalization(base._Layer): # pylint: disable=protected-access gamma_regularizer=None, trainable=True, name=None, + renorm=False, + renorm_clipping=None, + renorm_momentum=0.99, **kwargs): super(BatchNormalization, self).__init__( name=name, trainable=trainable, **kwargs) @@ -99,6 +115,15 @@ class BatchNormalization(base._Layer): # pylint: disable=protected-access self.moving_variance_initializer = moving_variance_initializer self.beta_regularizer = beta_regularizer self.gamma_regularizer = gamma_regularizer + self.renorm = renorm + if renorm: + renorm_clipping = renorm_clipping or {} + keys = ['rmax', 'rmin', 'dmax'] + if set(renorm_clipping) - set(keys): + raise ValueError('renorm_clipping %s contains keys not in %s' % + (renorm_clipping, keys)) + self.renorm_clipping = renorm_clipping + self.renorm_momentum = renorm_momentum def build(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape) @@ -148,9 +173,90 @@ class BatchNormalization(base._Layer): # pylint: disable=protected-access shape=(param_dim,), initializer=self.moving_variance_initializer, trainable=False) + if self.renorm: + # Create variables to maintain the moving mean and standard deviation. + # These are used in training and thus are different from the moving + # averages above. The renorm variables are colocated with moving_mean + # and moving_variance. + # NOTE: below, the outer `with device` block causes the current device + # stack to be cleared. The nested ones use a `lambda` to set the desired + # device and ignore any devices that may be set by the custom getter. + def _renorm_variable(name, shape): + var = vs.get_variable(name, + shape=shape, + initializer=init_ops.zeros_initializer(), + trainable=False) + return var + with ops.device(None): + with ops.device(lambda _: self.moving_mean.device): + self.renorm_mean = _renorm_variable('renorm_mean', (param_dim,)) + self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ()) + # We initialize renorm_stddev to 0, and maintain the (0-initialized) + # renorm_stddev_weight. This allows us to (1) mix the average + # stddev with the minibatch stddev early in training, and (2) compute + # the unbiased average stddev by dividing renorm_stddev by the weight. + with ops.device(lambda _: self.moving_variance.device): + self.renorm_stddev = _renorm_variable('renorm_stddev', (param_dim,)) + self.renorm_stddev_weight = _renorm_variable( + 'renorm_stddev_weight', ()) finally: vs.get_variable_scope().set_partitioner(partitioner) + def _renorm_correction_and_moments(self, mean, variance, training): + """Returns the correction and update values for renorm.""" + stddev = math_ops.sqrt(variance + self.epsilon) + # Compute the average mean and standard deviation, as if they were + # initialized with this batch's moments. + mixed_renorm_mean = (self.renorm_mean + + (1. - self.renorm_mean_weight) * mean) + mixed_renorm_stddev = (self.renorm_stddev + + (1. - self.renorm_stddev_weight) * stddev) + # Compute the corrections for batch renorm. + r = stddev / mixed_renorm_stddev + d = (mean - mixed_renorm_mean) / mixed_renorm_stddev + # Ensure the corrections use pre-update moving averages. + with ops.control_dependencies([r, d]): + mean = array_ops.identity(mean) + stddev = array_ops.identity(stddev) + rmin, rmax, dmax = [self.renorm_clipping.get(key) + for key in ['rmin', 'rmax', 'dmax']] + if rmin is not None: + r = math_ops.maximum(r, rmin) + if rmax is not None: + r = math_ops.minimum(r, rmax) + if dmax is not None: + d = math_ops.maximum(d, -dmax) + d = math_ops.minimum(d, dmax) + # When not training, use r=1, d=0, and decay=1 meaning no updates. + r = _smart_select(training, lambda: r, lambda: array_ops.ones_like(r)) + d = _smart_select(training, lambda: d, lambda: array_ops.zeros_like(d)) + decay = _smart_select(training, lambda: self.renorm_momentum, lambda: 1.) + def _update_renorm_variable(var, weight, value): + """Updates a moving average and weight, returns the unbiased value.""" + # Update the variables without zero debiasing. The debiasing will be + # accomplished by dividing the exponential moving average by the weight. + # For example, after a single update, the moving average would be + # (1-decay) * value. and the weight will be 1-decay, with their ratio + # giving value. + new_var = moving_averages.assign_moving_average( + var, value, decay, zero_debias=False) + new_weight = moving_averages.assign_moving_average( + weight, 1., decay, zero_debias=False) + return new_var / new_weight + + with ops.colocate_with(self.moving_mean): + new_mean = _update_renorm_variable(self.renorm_mean, + self.renorm_mean_weight, + mean) + with ops.colocate_with(self.moving_variance): + new_stddev = _update_renorm_variable(self.renorm_stddev, + self.renorm_stddev_weight, + stddev) + # Make sqrt(moving_variance + epsilon) = new_stddev. + new_variance = math_ops.square(new_stddev) - self.epsilon + + return (r, d, new_mean, new_variance) + def call(self, inputs, training=False): # First, compute the axes along which to reduce the mean / variance, # as well as the broadcast shape to be used for all parameters. @@ -164,82 +270,66 @@ class BatchNormalization(base._Layer): # pylint: disable=protected-access # Determines whether broadcasting is needed. needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1]) + scale, offset = self.gamma, self.beta + # Determine a boolean value for `training`: could be True, False, or None. training_value = utils.constant_value(training) - - if needs_broadcasting: - # In this case we must explictly broadcast all parameters. - if self.center: - broadcast_beta = array_ops.reshape(self.beta, broadcast_shape) - else: - broadcast_beta = None - if self.scale: - broadcast_gamma = array_ops.reshape(self.gamma, broadcast_shape) - else: - broadcast_gamma = None - if training_value is not False: - if needs_broadcasting: - broadcast_mean, broadcast_variance = nn.moments( - inputs, reduction_axes, keep_dims=True) - mean = array_ops.reshape(broadcast_mean, [-1]) - variance = array_ops.reshape(broadcast_variance, [-1]) + # Some of the computations here are not necessary when training==False + # but not a constant. However, this makes the code simpler. + mean, variance = nn.moments(inputs, reduction_axes) + if self.renorm: + r, d, new_mean, new_variance = self._renorm_correction_and_moments( + mean, variance, training) + # When training, the normalized values (say, x) will be transformed as + # x * gamma + beta without renorm, and (x * r + d) * gamma + beta + # = x * (r * gamma) + (d * gamma + beta) with renorm. + scale = array_ops.stop_gradient(r, name='renorm_r') + offset = array_ops.stop_gradient(d, name='renorm_d') + if self.gamma is not None: + scale *= self.gamma + offset *= self.gamma + if self.beta is not None: + offset += self.beta else: - mean, variance = nn.moments(inputs, reduction_axes) + new_mean, new_variance = mean, variance + + # Update moving averages when training, and prevent updates otherwise. + decay = _smart_select(training, lambda: self.momentum, lambda: 1.) + mean_update = moving_averages.assign_moving_average( + self.moving_mean, new_mean, decay, zero_debias=False) + variance_update = moving_averages.assign_moving_average( + self.moving_variance, new_variance, decay, zero_debias=False) - # Prepare updates if necessary. if not self.updates: - mean_update = moving_averages.assign_moving_average( - self.moving_mean, mean, self.momentum, zero_debias=False) - variance_update = moving_averages.assign_moving_average( - self.moving_variance, variance, self.momentum, zero_debias=False) # In the future this should be refactored into a self.add_update # methods in order to allow for instance-based BN layer sharing # across unrelated input streams (e.g. like in Keras). self.updates.append(mean_update) self.updates.append(variance_update) - # Normalize batch. We do this inside separate functions for training - # and inference so as to avoid evaluating both branches. - def normalize_in_test(): - if needs_broadcasting: - broadcast_moving_mean = array_ops.reshape(self.moving_mean, - broadcast_shape) - broadcast_moving_variance = array_ops.reshape(self.moving_variance, - broadcast_shape) - return nn.batch_normalization(inputs, - broadcast_moving_mean, - broadcast_moving_variance, - broadcast_beta, - broadcast_gamma, - self.epsilon) - else: - return nn.batch_normalization(inputs, - self.moving_mean, - self.moving_variance, - self.beta if self.center else None, - self.gamma if self.scale else None, - self.epsilon) + mean = _smart_select(training, + lambda: mean, + lambda: self.moving_mean) + variance = _smart_select(training, + lambda: variance, + lambda: self.moving_variance) - def normalize_in_training(): - if needs_broadcasting: - return nn.batch_normalization(inputs, - broadcast_mean, - broadcast_variance, - broadcast_beta, - broadcast_gamma, - self.epsilon) - else: - return nn.batch_normalization(inputs, - mean, - variance, - self.beta if self.center else None, - self.gamma if self.scale else None, - self.epsilon) + else: + mean, variance = self.moving_mean, self.moving_variance - return utils.smart_cond(training, - normalize_in_training, - normalize_in_test) + def _broadcast(v): + if needs_broadcasting and v is not None: + # In this case we must explictly broadcast all parameters. + return array_ops.reshape(v, broadcast_shape) + return v + + return nn.batch_normalization(inputs, + _broadcast(mean), + _broadcast(variance), + _broadcast(offset), + _broadcast(scale), + self.epsilon) def batch_normalization(inputs, @@ -257,7 +347,10 @@ def batch_normalization(inputs, training=False, trainable=True, name=None, - reuse=None): + reuse=None, + renorm=False, + renorm_clipping=None, + renorm_momentum=0.99): """Functional interface for the batch normalization layer. Reference: http://arxiv.org/abs/1502.03167 @@ -294,6 +387,19 @@ def batch_normalization(inputs, name: String, the name of the layer. reuse: Boolean, whether to reuse the weights of a previous layer by the same name. + renorm: Whether to use Batch Renormalization + (https://arxiv.org/abs/1702.03275). This adds extra variables during + training. The inference is the same for either value of this parameter. + renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to + scalar `Tensors` used to clip the renorm correction. The correction + `(r, d)` is used as `corrected_value = normalized_value * r + d`, with + `r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin, + dmax are set to inf, 0, inf, respectively. + renorm_momentum: Momentum used to update the moving means and standard + deviations with renorm. Unlike `momentum`, this affects training + and should be neither too small (which would add noise) nor too large + (which would give stale estimates). Note that `momentum` is still applied + to get the means and variances for inference. Returns: Output tensor. @@ -311,6 +417,9 @@ def batch_normalization(inputs, beta_regularizer=beta_regularizer, gamma_regularizer=gamma_regularizer, trainable=trainable, + renorm=renorm, + renorm_clipping=renorm_clipping, + renorm_momentum=renorm_momentum, name=name, _reuse=reuse, _scope=name) @@ -321,3 +430,39 @@ def batch_normalization(inputs, BatchNorm = BatchNormalization batch_norm = batch_normalization + + +# Helper function + + +def _smart_select(pred, fn_then, fn_else): + """Selects fn_then() or fn_else() based on the value of pred. + + The purpose of this function is the same as `utils.smart_cond`. However, at + the moment there is a bug (b/36297356) that seems to kick in only when + `smart_cond` delegates to `tf.cond`, which sometimes results in the training + hanging when using parameter servers. This function will output the result + of `fn_then` or `fn_else` if `pred` is known at graph construction time. + Otherwise, it will use `tf.where` which will result in some redundant work + (both branches will be computed but only one selected). However, the tensors + involved will usually be small (means and variances in batchnorm), so the + cost will be small and will not be incurred at all if `pred` is a constant. + + Args: + pred: A boolean scalar `Tensor`. + fn_then: A callable to use when pred==True. + fn_else: A callable to use when pred==False. + + Returns: + A `Tensor` whose value is fn_then() or fn_else() based on the value of pred. + """ + pred_value = utils.constant_value(pred) + if pred_value: + return fn_then() + elif pred_value is False: + return fn_else() + t_then = array_ops.expand_dims(fn_then(), 0) + t_else = array_ops.expand_dims(fn_else(), 0) + pred = array_ops.reshape(pred, [1]) + result = array_ops.where(pred, t_then, t_else) + return array_ops.squeeze(result, [0]) diff --git a/tensorflow/python/layers/normalization_test.py b/tensorflow/python/layers/normalization_test.py index 91b7cb6f483..0f82f73ea48 100644 --- a/tensorflow/python/layers/normalization_test.py +++ b/tensorflow/python/layers/normalization_test.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.layers import normalization as normalization_layers from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope @@ -513,6 +514,64 @@ class BNTest(test.TestCase): _ = bn.apply(inputs, training=training) self.assertEqual(len(bn.losses), 1) + def testRenorm(self): + shape = (4, 3) + xt = array_ops.placeholder(dtypes.float32, shape) + momentum = 0.99 + renorm_momentum = 0.8 + rmax = 1.1 + rmin = 0.9 + dmax = 0.1 + gamma = 2. + beta = 3. + epsilon = 0.001 + bn = normalization_layers.BatchNormalization( + axis=1, + gamma_initializer=init_ops.constant_initializer(gamma), + beta_initializer=init_ops.constant_initializer(beta), + epsilon=epsilon, + momentum=momentum, + renorm=True, + renorm_clipping={'rmax': rmax, 'rmin': rmin, 'dmax': dmax}, + renorm_momentum=renorm_momentum) + training = array_ops.placeholder(dtypes.bool) + yt = bn.apply(xt, training=training) + + moving_mean = 0. + moving_variance = 1. + renorm_mean = renorm_stddev = 0. + renorm_weight = 0. + with self.test_session(use_gpu=True) as sess: + sess.run(variables.global_variables_initializer()) + for _ in range(5): + x = np.random.random(shape) + + mean = x.mean(0) + stddev = np.sqrt(x.var(0) + epsilon) + adj_mean = renorm_mean + (1. - renorm_weight) * mean + adj_stddev = renorm_stddev + (1. - renorm_weight) * stddev + r = (stddev / adj_stddev).clip(rmin, rmax) + d = ((mean - adj_mean) / adj_stddev).clip(-dmax, dmax) + y_train = ((x - mean) / stddev * r + d) * gamma + beta + renorm_mean += (mean - renorm_mean) * (1. - renorm_momentum) + renorm_stddev += (stddev - renorm_stddev) * (1. - renorm_momentum) + renorm_weight += (1. - renorm_weight) * (1. - renorm_momentum) + moving_mean += (renorm_mean / renorm_weight - + moving_mean) * (1. - momentum) + moving_variance += ((renorm_stddev / renorm_weight) ** 2 - epsilon - + moving_variance) * (1. - momentum) + + y_test = ((x - moving_mean) / (moving_variance + epsilon) ** 0.5 * + gamma) + beta + + yt_val_train, _, _ = sess.run([yt] + bn.updates, + feed_dict={xt: x, training: True}) + yt_val_test, _, _ = sess.run([yt] + bn.updates, + feed_dict={xt: x, training: False}) + + self.assertAllClose(y_train, yt_val_train, atol=1e-5) + self.assertAllClose(y_test, yt_val_test, atol=1e-5) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 60057b9ab1e..45efc51d5c9 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -196,7 +196,7 @@ def broadcast_dynamic_shape(shape_x, shape_y): Args: shape_x: A rank 1 integer `Tensor`, representing the shape of x. - shape_y: A rank 1 integer `Tensor`, representing the shape of x. + shape_y: A rank 1 integer `Tensor`, representing the shape of y. Returns: A rank 1 integer `Tensor` representing the broadcasted shape. """ @@ -1292,6 +1292,17 @@ def matrix_transpose(a, name="matrix_transpose"): # tf.matrix_transpose(x) is shape [1, 2, 4, 3] ``` + Note that `tf.matmul` provides kwargs allowing for transpose of arguments. + This is done with minimal cost, and is preferable to using this function. E.g. + + ``` + # Good! Transpose is taken at minimal additional cost. + tf.matmul(matrix, b, transpose_b=True) + + # Inefficient! + tf.matmul(matrix, tf.matrix_transpose(b)) + ``` + Args: a: A `Tensor` with `rank >= 2`. name: A name for the operation (optional). diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 99d29a37193..66ccedf546e 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -375,7 +375,9 @@ def with_space_to_batch( input_shape_list = input.get_shape().as_list() input_spatial_shape = [input_shape_list[i] for i in spatial_dims] if input_spatial_shape is None or None in input_spatial_shape: - input_spatial_shape = array_ops.gather(array_ops.shape(input), spatial_dims) + input_shape_tensor = array_ops.shape(input) + input_spatial_shape = array_ops.stack( + [input_shape_tensor[i] for i in spatial_dims]) paddings, crops = array_ops.required_space_to_batch_paddings( input_shape=input_spatial_shape, @@ -2021,7 +2023,7 @@ def top_k(input, k=1, sorted=True, name=None): def conv1d(value, filters, stride, padding, use_cudnn_on_gpu=None, data_format=None, name=None): - """Computes a 1-D convolution given 3-D input and filter tensors. + r"""Computes a 1-D convolution given 3-D input and filter tensors. Given an input tensor of shape [batch, in_width, in_channels] diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 86e0cae27ac..77f0468c017 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -197,8 +197,10 @@ class ResourceVariable(object): self._initialize_op = gen_resource_variable_ops.assign_variable_op( self._handle, self._initial_value, name=n) with ops.name_scope("Read"), ops.colocate_with(self._handle): - value = gen_resource_variable_ops.read_variable_op( - self._handle, dtype=self._dtype) + # Manually assign reads to the handle's device to avoid log messages. + with ops.device(self._handle.device): + value = gen_resource_variable_ops.read_variable_op( + self._handle, dtype=self._dtype) self._graph_element = value if caching_device is not None: # Variables may be created in a tf.device() or ops.colocate_with() @@ -276,8 +278,9 @@ class ResourceVariable(object): """A cached operation which reads the value of this variable.""" if self._cached_value is not None: return self._cached_value - return gen_resource_variable_ops.read_variable_op( - self._handle, dtype=self._dtype) + with ops.device(self._handle.device): + return gen_resource_variable_ops.read_variable_op( + self._handle, dtype=self._dtype) def _as_graph_element(self): """Conversion function for Graph.as_graph_element().""" @@ -318,8 +321,9 @@ class ResourceVariable(object): the read operation. """ with ops.name_scope("Read"): - value = gen_resource_variable_ops.read_variable_op( - self._handle, dtype=self._dtype) + with ops.device(self._handle.device): + value = gen_resource_variable_ops.read_variable_op( + self._handle, dtype=self._dtype) # Return an identity so it can get placed on whatever device the context # specifies instead of the device where the variable is. return array_ops.identity(value) diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index 162b13ec212..1051478a7f7 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -37,6 +37,36 @@ _state_size_with_prefix = rnn_cell_impl._state_size_with_prefix # pylint: enable=protected-access +def _transpose_batch_time(x): + """Transpose the batch and time dimensions of a Tensor. + + Retains as much of the static shape information as possible. + + Args: + x: A tensor of rank 2 or higher. + + Returns: + x transposed along the first two dimensions. + + Raises: + ValueError: if `x` is rank 1 or lower. + """ + x_static_shape = x.get_shape() + if x_static_shape.ndims is not None and x_static_shape.ndims < 2: + raise ValueError( + "Expected input tensor %s to have rank at least 2, but saw shape: %s" % + (x, x_static_shape)) + x_rank = array_ops.rank(x) + x_t = array_ops.transpose( + x, array_ops.concat( + ([1, 0], math_ops.range(2, x_rank)), axis=0)) + x_t.set_shape( + tensor_shape.TensorShape([ + x_static_shape[1].value, x_static_shape[0].value + ]).concatenate(x_static_shape[2:])) + return x_t + + def _infer_state_dtype(explicit_dtype, state): """Infer the dtype of an RNN state. @@ -492,8 +522,8 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, if not time_major: # (B,T,D) => (T,B,D) - flat_input = tuple(array_ops.transpose(input_, [1, 0, 2]) - for input_ in flat_input) + flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input] + flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input) parallel_iterations = parallel_iterations or 32 if sequence_length is not None: @@ -556,11 +586,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, # to shape [batch, time, depth] if not time_major: # (T,B,D) => (B,T,D) - flat_output = nest.flatten(outputs) - flat_output = [array_ops.transpose(output, [1, 0, 2]) - for output in flat_output] - outputs = nest.pack_sequence_as( - structure=outputs, flat_sequence=flat_output) + outputs = nest.map_structure(_transpose_batch_time, outputs) return (outputs, final_state) @@ -1003,34 +1029,20 @@ def raw_rnn(cell, loop_fn, def _copy_some_through(current, candidate): """Copy some tensors through via array_ops.where.""" - current_flat = nest.flatten(current) - candidate_flat = nest.flatten(candidate) - # pylint: disable=g-long-lambda,cell-var-from-loop - result_flat = [ - _on_device( - lambda: array_ops.where( - elements_finished, current_i, candidate_i), - device=candidate_i.op.device) - for (current_i, candidate_i) in zip(current_flat, candidate_flat)] - # pylint: enable=g-long-lambda,cell-var-from-loop - return nest.pack_sequence_as( - structure=current, flat_sequence=result_flat) + def copy_fn(cur_i, cand_i): + return _on_device( + lambda: array_ops.where(elements_finished, cur_i, cand_i), + device=cand_i.op.device) + return nest.map_structure(copy_fn, current, candidate) emit_output = _copy_some_through(zero_emit, emit_output) next_state = _copy_some_through(state, next_state) - emit_output_flat = nest.flatten(emit_output) - emit_ta_flat = nest.flatten(emit_ta) + emit_ta = nest.map_structure( + lambda ta, emit: ta.write(time, emit), emit_ta, emit_output) elements_finished = math_ops.logical_or(elements_finished, next_finished) - emit_ta_flat = [ - ta.write(time, emit) - for (ta, emit) in zip(emit_ta_flat, emit_output_flat)] - - emit_ta = nest.pack_sequence_as( - structure=emit_structure, flat_sequence=emit_ta_flat) - return (next_time, elements_finished, next_input, emit_ta, next_state, loop_state) diff --git a/tensorflow/python/ops/session_ops.py b/tensorflow/python/ops/session_ops.py index 0a06982ad7c..3d038cfd8a0 100644 --- a/tensorflow/python/ops/session_ops.py +++ b/tensorflow/python/ops/session_ops.py @@ -116,7 +116,7 @@ class TensorHandle(object): raise TypeError("Persistent tensor %s may have already been deleted." % self.handle) self._auto_gc_enabled = False - holder, deleter = _get_handle_deleter(self._session.graph, self._handle) + holder, deleter = _get_handle_deleter(self._session.graph, 0, self._handle) self._session.run(deleter, feed_dict={holder: self.handle}) def get_raw_handle(self): @@ -141,11 +141,6 @@ class TensorHandle(object): handle_parts = str(handle).split(";") return handle_parts[0] + ";" + handle_parts[-1] - @staticmethod - def _get_deleter_key(handle): - """The graph key for deleter.""" - return str(handle).split(";")[-1] - @staticmethod def _get_mover_key(feeder, handle): """The graph key for mover.""" @@ -302,10 +297,9 @@ def _get_handle_mover(graph, feeder, handle): return result -def _get_handle_deleter(graph, handle): +def _get_handle_deleter(graph, deleter_key, handle): """Return a deletion subgraph for this handle.""" - graph_key = TensorHandle._get_deleter_key(handle) - result = graph._handle_deleters.get(graph_key) + result = graph._handle_deleters.get(deleter_key) if result is None: # Create deleter if we haven't done it. handle_device = TensorHandle._get_device_name(handle) @@ -313,5 +307,5 @@ def _get_handle_deleter(graph, handle): holder = array_ops.placeholder(dtypes.string) deleter = gen_data_flow_ops._delete_session_tensor(holder) result = (holder, deleter) - graph._handle_deleters[graph_key] = result + graph._handle_deleters[deleter_key] = result return result diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py index 2a64cb7b705..f46f56cbb71 100644 --- a/tensorflow/python/ops/state_ops.py +++ b/tensorflow/python/ops/state_ops.py @@ -243,8 +243,9 @@ def assign_add(ref, value, use_locking=None, name=None): def assign(ref, value, validate_shape=None, use_locking=None, name=None): """Update 'ref' by assigning 'value' to it. - This operation outputs "ref" after the assignment is done. - This makes it easier to chain operations that need to use the reset value. + This operation outputs a Tensor that holds the new value of 'ref' after + the value has been assigned. This makes it easier to chain operations + that need to use the reset value. Args: ref: A mutable `Tensor`. @@ -261,8 +262,8 @@ def assign(ref, value, validate_shape=None, use_locking=None, name=None): name: A name for the operation (optional). Returns: - Same as "ref". Returned as a convenience for operations that want - to use the new value after the variable has been reset. + A `Tensor` that will hold the new value of 'ref' after + the assignment has completed. """ if ref.dtype._is_ref_dtype: return gen_state_ops.assign( diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 19c5d3c3ea0..b3745fa4e69 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -974,6 +974,8 @@ class VariableScope(object): partitioner = self._partitioner if dtype is None: dtype = self._dtype + if use_resource is None: + use_resource = self._use_resource if self._custom_getter is not None: raise ValueError( diff --git a/tensorflow/python/training/adam.py b/tensorflow/python/training/adam.py index 111461f7842..4b0ef50df5d 100644 --- a/tensorflow/python/training/adam.py +++ b/tensorflow/python/training/adam.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables from tensorflow.python.training import optimizer @@ -154,7 +155,7 @@ class AdamOptimizer(optimizer.Optimizer): math_ops.cast(self._epsilon_t, grad.dtype.base_dtype), grad, use_locking=self._use_locking) - def _apply_sparse(self, grad, var): + def _apply_sparse_shared(self, grad, var, indices, scatter_add): beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype) beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype) lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) @@ -164,23 +165,39 @@ class AdamOptimizer(optimizer.Optimizer): lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) # m_t = beta1 * m + (1 - beta1) * g_t m = self.get_slot(var, "m") - m_scaled_g_values = grad.values * (1 - beta1_t) + m_scaled_g_values = grad * (1 - beta1_t) m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking) - m_t = state_ops.scatter_add(m_t, grad.indices, m_scaled_g_values, - use_locking=self._use_locking) + with ops.control_dependencies([m_t]): + m_t = scatter_add(m, indices, m_scaled_g_values) # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) v = self.get_slot(var, "v") - v_scaled_g_values = (grad.values * grad.values) * (1 - beta2_t) - v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking) - v_t = state_ops.scatter_add(v_t, grad.indices, v_scaled_g_values, - use_locking=self._use_locking) + v_scaled_g_values = (grad * grad) * (1 - beta2_t) + v_t = state_ops.assign(v, v * beta2_t) + with ops.control_dependencies([v_t]): + v_t = scatter_add(v, indices, v_scaled_g_values) v_sqrt = math_ops.sqrt(v_t) var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking) return control_flow_ops.group(*[var_update, m_t, v_t]) + def _apply_sparse(self, grad, var): + return self._apply_sparse_shared( + grad.values, var, grad.indices, + lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda + x, i, v, use_locking=self._use_locking)) + + def _resource_scatter_add(self, x, i, v): + with ops.control_dependencies( + [resource_variable_ops.resource_scatter_add( + x.handle, i, v)]): + return x.value() + + def _resource_apply_sparse(self, grad, var, indices): + return self._apply_sparse_shared( + grad, var, indices, self._resource_scatter_add) + def _finish(self, update_ops, name_scope): # Update the power accumulators. with ops.control_dependencies(update_ops): diff --git a/tensorflow/python/training/adam_test.py b/tensorflow/python/training/adam_test.py index 00ff5d9b9d4..62b171e234e 100644 --- a/tensorflow/python/training/adam_test.py +++ b/tensorflow/python/training/adam_test.py @@ -52,7 +52,7 @@ def adam_update_numpy(param, class AdamOptimizerTest(test.TestCase): - def testSparse(self): + def doTestSparse(self, use_resource=False): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: with self.test_session(): # Initialize variables for numpy implementation. @@ -62,8 +62,12 @@ class AdamOptimizerTest(test.TestCase): var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) - var0 = variables.Variable(var0_np) - var1 = variables.Variable(var1_np) + if use_resource: + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + else: + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) grads0_np_indices = np.array([0, 1], dtype=np.int32) grads0 = ops.IndexedSlices( constant_op.constant(grads0_np), @@ -95,6 +99,12 @@ class AdamOptimizerTest(test.TestCase): self.assertAllCloseAccordingToType(var0_np, var0.eval()) self.assertAllCloseAccordingToType(var1_np, var1.eval()) + def testSparse(self): + self.doTestSparse(use_resource=False) + + def testResourceSparse(self): + self.doTestSparse(use_resource=True) + def testSparseDevicePlacement(self): for index_dtype in [dtypes.int32, dtypes.int64]: with self.test_session(force_gpu=test.is_gpu_available()): diff --git a/tensorflow/python/training/device_setter.py b/tensorflow/python/training/device_setter.py index 7f403f49275..85ee10379ad 100644 --- a/tensorflow/python/training/device_setter.py +++ b/tensorflow/python/training/device_setter.py @@ -198,7 +198,7 @@ def replica_device_setter(ps_tasks=0, ps_device="/job:ps", if ps_ops is None: # TODO(sherrym): Variables in the LOCAL_VARIABLES collection should not be # placed in the parameter server. - ps_ops = ["Variable", "VariableV2"] + ps_ops = ["Variable", "VariableV2", "VarHandleOp"] if not merge_devices: logging.warning( diff --git a/tensorflow/python/training/device_setter_test.py b/tensorflow/python/training/device_setter_test.py index e05f0f6a1c7..bc29e0d21c5 100644 --- a/tensorflow/python/training/device_setter_test.py +++ b/tensorflow/python/training/device_setter_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import device_setter @@ -46,6 +47,12 @@ class DeviceSetterTest(test.TestCase): self.assertDeviceEqual("/job:ps/task:1", w.initializer.device) self.assertDeviceEqual("/job:worker/cpu:0", a.device) + def testResource(self): + with ops.device( + device_setter.replica_device_setter(cluster=self._cluster_spec)): + v = resource_variable_ops.ResourceVariable([1, 2]) + self.assertDeviceEqual("/job:ps/task:0", v.device) + def testPS2TasksWithClusterSpecClass(self): with ops.device( device_setter.replica_device_setter(cluster=self._cluster_spec)): diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index cf8692eda13..6d6128d2079 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -252,7 +252,7 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name save_summaries_secs=None, config=None, stop_grace_period_secs=120, - log_step_count_steps=10000): + log_step_count_steps=100): """Creates a `MonitoredSession` for training. For a chief, this utility sets proper session initializer/restorer. It also diff --git a/tensorflow/tensorboard/backend/application.py b/tensorflow/tensorboard/backend/application.py index 005d1830390..974762822fc 100644 --- a/tensorflow/tensorboard/backend/application.py +++ b/tensorflow/tensorboard/backend/application.py @@ -61,6 +61,7 @@ DATA_PREFIX = '/data' LOGDIR_ROUTE = '/logdir' RUNS_ROUTE = '/runs' PLUGIN_PREFIX = '/plugin' +PLUGINS_LISTING_ROUTE = '/plugins_listing' SCALARS_ROUTE = '/' + event_accumulator.SCALARS IMAGES_ROUTE = '/' + event_accumulator.IMAGES AUDIO_ROUTE = '/' + event_accumulator.AUDIO @@ -152,30 +153,34 @@ class TensorBoardWSGIApp(object): reload_multiplexer(self._multiplexer, path_to_run) self.data_applications = { - DATA_PREFIX + LOGDIR_ROUTE: - self._serve_logdir, - DATA_PREFIX + SCALARS_ROUTE: - self._serve_scalars, - DATA_PREFIX + GRAPH_ROUTE: - self._serve_graph, - DATA_PREFIX + RUN_METADATA_ROUTE: - self._serve_run_metadata, - DATA_PREFIX + HISTOGRAMS_ROUTE: - self._serve_histograms, - DATA_PREFIX + COMPRESSED_HISTOGRAMS_ROUTE: - self._serve_compressed_histograms, - DATA_PREFIX + IMAGES_ROUTE: - self._serve_images, - DATA_PREFIX + INDIVIDUAL_IMAGE_ROUTE: - self._serve_image, + '/app.js': + self._serve_js, DATA_PREFIX + AUDIO_ROUTE: self._serve_audio, + DATA_PREFIX + COMPRESSED_HISTOGRAMS_ROUTE: + self._serve_compressed_histograms, + DATA_PREFIX + GRAPH_ROUTE: + self._serve_graph, + DATA_PREFIX + HISTOGRAMS_ROUTE: + self._serve_histograms, + DATA_PREFIX + IMAGES_ROUTE: + self._serve_images, DATA_PREFIX + INDIVIDUAL_AUDIO_ROUTE: self._serve_individual_audio, + DATA_PREFIX + INDIVIDUAL_IMAGE_ROUTE: + self._serve_image, + DATA_PREFIX + LOGDIR_ROUTE: + self._serve_logdir, + # TODO(chizeng): Delete this RPC once we have skylark rules that obviate + # the need for the frontend to determine which plugins are active. + DATA_PREFIX + PLUGINS_LISTING_ROUTE: + self._serve_plugins_listing, + DATA_PREFIX + RUN_METADATA_ROUTE: + self._serve_run_metadata, DATA_PREFIX + RUNS_ROUTE: self._serve_runs, - '/app.js': - self._serve_js + DATA_PREFIX + SCALARS_ROUTE: + self._serve_scalars, } # Serve the routes from the registered plugins using their name as the route @@ -488,6 +493,21 @@ class TensorBoardWSGIApp(object): }) return query_string + @wrappers.Request.application + def _serve_plugins_listing(self, request): + """Serves an object mapping plugin name to whether it is enabled. + + Args: + request: The werkzeug.Request object. + + Returns: + A werkzeug.Response object. + """ + return http_util.Respond( + request, + {plugin.plugin_name: plugin.is_active() for plugin in self._plugins}, + 'application/json') + @wrappers.Request.application def _serve_runs(self, request): """WSGI app serving a JSON object about runs and tags. diff --git a/tensorflow/tensorboard/backend/application_test.py b/tensorflow/tensorboard/backend/application_test.py index 454ba63e752..002709cd5b0 100644 --- a/tensorflow/tensorboard/backend/application_test.py +++ b/tensorflow/tensorboard/backend/application_test.py @@ -51,6 +51,40 @@ from tensorflow.tensorboard.backend.event_processing import event_multiplexer from tensorflow.tensorboard.plugins import base_plugin +class FakePlugin(base_plugin.TBPlugin): + """A plugin with no functionality.""" + + def __init__(self, plugin_name, is_active_value): + """Constructs a fake plugin. + + Args: + plugin_name: The name of this plugin. + is_active_value: Whether the plugin is active. + """ + self.plugin_name = plugin_name + self._is_active_value = is_active_value + + def get_plugin_apps(self, multiplexer, logdir): + """Returns a mapping from routes to handlers offered by this plugin. + + Args: + multiplexer: The event multiplexer. + logdir: The path to the directory containing logs. + + Returns: + An empty dict. This plugin offers no routes. + """ + return {} + + def is_active(self): + """Returns whether this plugin is active. + + Returns: + A boolean. Whether this plugin is active. + """ + return self._is_active_value + + class TensorboardServerTest(test.TestCase): _only_use_meta_graph = False # Server data contains only a GraphDef @@ -62,7 +96,10 @@ class TensorboardServerTest(test.TestCase): multiplexer = event_multiplexer.EventMultiplexer( size_guidance=application.DEFAULT_SIZE_GUIDANCE, purge_orphaned_data=True) - plugins = [] + plugins = [ + FakePlugin(plugin_name='foo', is_active_value=True), + FakePlugin(plugin_name='bar', is_active_value=False) + ] app = application.TensorBoardWSGIApp( self.temp_dir, plugins, multiplexer, reload_interval=0) try: @@ -124,6 +161,12 @@ class TensorboardServerTest(test.TestCase): parsed_object = self._getJson('/data/logdir') self.assertEqual(parsed_object, {'logdir': self.temp_dir}) + def testPluginsListing(self): + """Test the format of the data/plugins_listing endpoint.""" + parsed_object = self._getJson('/data/plugins_listing') + # Plugin foo is active. Plugin bar is not. + self.assertEqual(parsed_object, {'foo': True, 'bar': False}) + def testRuns(self): """Test the format of the /data/runs endpoint.""" run_json = self._getJson('/data/runs') @@ -484,29 +527,21 @@ class TensorboardSimpleServerConstructionTest(test.TestCase): class TensorBoardApplcationConstructionTest(test.TestCase): def testExceptions(self): - - class UnnamedPlugin(base_plugin.TBPlugin): - - def get_plugin_apps(self): - pass - - class MockPlugin(UnnamedPlugin): - plugin_name = 'mock' - - class OtherMockPlugin(UnnamedPlugin): - plugin_name = 'mock' - logdir = '/fake/foo' multiplexer = event_multiplexer.EventMultiplexer() # Fails if there is an unnamed plugin with self.assertRaises(ValueError): - plugins = [UnnamedPlugin()] + # This plugin lacks a name. + plugins = [FakePlugin(plugin_name=None, is_active_value=True)] application.TensorBoardWSGIApp(logdir, plugins, multiplexer, 0) # Fails if there are two plugins with same name with self.assertRaises(ValueError): - plugins = [MockPlugin(), OtherMockPlugin()] + plugins = [ + FakePlugin(plugin_name='foo', is_active_value=True), + FakePlugin(plugin_name='foo', is_active_value=True), + ] application.TensorBoardWSGIApp(logdir, plugins, multiplexer, 0) diff --git a/tensorflow/tensorboard/backend/event_processing/event_accumulator.py b/tensorflow/tensorboard/backend/event_processing/event_accumulator.py index beba28da060..d5a91bbb6a2 100644 --- a/tensorflow/tensorboard/backend/event_processing/event_accumulator.py +++ b/tensorflow/tensorboard/backend/event_processing/event_accumulator.py @@ -438,6 +438,14 @@ class EventAccumulator(object): """ return self._health_pills.Items(node_name) + def GetOpsWithHealthPills(self): + """Determines which ops have at least 1 health pill event. + + Returns: + A list of names of ops with at least 1 health pill event. + """ + return self._health_pills.Keys() + def Graph(self): """Return the graph definition, if there is one. diff --git a/tensorflow/tensorboard/backend/event_processing/event_accumulator_test.py b/tensorflow/tensorboard/backend/event_processing/event_accumulator_test.py index 38a8cd915fe..3734e470b69 100644 --- a/tensorflow/tensorboard/backend/event_processing/event_accumulator_test.py +++ b/tensorflow/tensorboard/backend/event_processing/event_accumulator_test.py @@ -297,8 +297,6 @@ class MockingEventAccumulatorTest(EventAccumulatorTest): acc = ea.EventAccumulator(gen) gen.AddHealthPill(13371337, 41, 'Add', 0, range(1, 13)) gen.AddHealthPill(13381338, 42, 'Add', 1, range(42, 54)) - - acc = ea.EventAccumulator(gen) acc.Reload() # Retrieve the health pills for each node name. @@ -321,6 +319,14 @@ class MockingEventAccumulatorTest(EventAccumulatorTest): value=range(42, 54)), gotten_events[1]) + def testGetOpsWithHealthPills(self): + gen = _EventGenerator(self) + acc = ea.EventAccumulator(gen) + gen.AddHealthPill(13371337, 41, 'Add', 0, range(1, 13)) + gen.AddHealthPill(13381338, 42, 'MatMul', 1, range(42, 54)) + acc.Reload() + self.assertItemsEqual(['Add', 'MatMul'], acc.GetOpsWithHealthPills()) + def testHistograms(self): gen = _EventGenerator(self) acc = ea.EventAccumulator(gen) diff --git a/tensorflow/tensorboard/backend/event_processing/event_multiplexer.py b/tensorflow/tensorboard/backend/event_processing/event_multiplexer.py index bbf958820a0..08e6dbb57d6 100644 --- a/tensorflow/tensorboard/backend/event_processing/event_multiplexer.py +++ b/tensorflow/tensorboard/backend/event_processing/event_multiplexer.py @@ -287,6 +287,21 @@ class EventMultiplexer(object): accumulator = self._GetAccumulator(run) return accumulator.HealthPills(node_name) + def GetOpsWithHealthPills(self, run): + """Determines which ops have at least 1 health pill event for a given run. + + Args: + run: The name of the run. + + Raises: + KeyError: If the run is not found, or the node name is not available for + the given run. + + Returns: + The list of names of ops with health pill events. + """ + return self._GetAccumulator(run).GetOpsWithHealthPills() + def Graph(self, run): """Retrieve the graph associated with the provided run. diff --git a/tensorflow/tensorboard/backend/event_processing/event_multiplexer_test.py b/tensorflow/tensorboard/backend/event_processing/event_multiplexer_test.py index ed5cac4014f..ded1856d7e3 100644 --- a/tensorflow/tensorboard/backend/event_processing/event_multiplexer_test.py +++ b/tensorflow/tensorboard/backend/event_processing/event_multiplexer_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import os import os.path import shutil @@ -45,10 +46,16 @@ def _CreateCleanDirectory(path): class _FakeAccumulator(object): - def __init__(self, path): + def __init__(self, path, health_pill_mapping=None): + """Constructs a fake accumulator with some fake events. + + Args: + path: The path for the run that this accumulator is for. + health_pill_mapping: An optional mapping from Op to health pill strings. + """ self._path = path self.reload_called = False - self._node_names_to_health_pills = {'Add': ['hp1', 'hp2']} + self._node_names_to_health_pills = health_pill_mapping or {} def Tags(self): return {event_accumulator.IMAGES: ['im1', 'im2'], @@ -74,6 +81,9 @@ class _FakeAccumulator(object): health_pills = self._node_names_to_health_pills[node_name] return [self._path + '/' + health_pill for health_pill in health_pills] + def GetOpsWithHealthPills(self): + return self._node_names_to_health_pills.keys() + def Histograms(self, tag_name): return self._TagHelper(tag_name, event_accumulator.HISTOGRAMS) @@ -93,14 +103,13 @@ class _FakeAccumulator(object): self.reload_called = True -# pylint: disable=unused-argument -def _GetFakeAccumulator( - path, - size_guidance=None, - compression_bps=None, - purge_orphaned_data=None): - return _FakeAccumulator(path) -# pylint: enable=unused-argument +def _GetFakeAccumulator(path, + size_guidance=None, + compression_bps=None, + purge_orphaned_data=None, + health_pill_mapping=None): + del size_guidance, compression_bps, purge_orphaned_data # Unused. + return _FakeAccumulator(path, health_pill_mapping=health_pill_mapping) class EventMultiplexerTest(test_util.TensorFlowTestCase): @@ -141,9 +150,27 @@ class EventMultiplexerTest(test_util.TensorFlowTestCase): self.assertEqual(run1_expected, run1_actual) def testHealthPills(self): + self.stubs.Set(event_accumulator, 'EventAccumulator', + functools.partial( + _GetFakeAccumulator, + health_pill_mapping={'Add': ['hp1', 'hp2']})) x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'}) self.assertEqual(['path1/hp1', 'path1/hp2'], x.HealthPills('run1', 'Add')) + def testGetOpsWithHealthPillsWhenHealthPillsAreNotAvailable(self): + # The event accumulator lacks health pills for the run. + x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'}) + self.assertItemsEqual([], x.GetOpsWithHealthPills('run1')) + + def testGetOpsWithHealthPillsWhenHealthPillsAreAvailable(self): + # The event accumulator has health pills for the run. + self.stubs.Set(event_accumulator, 'EventAccumulator', + functools.partial( + _GetFakeAccumulator, + health_pill_mapping={'Add': ['hp1', 'hp2']})) + x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'}) + self.assertItemsEqual(['Add'], x.GetOpsWithHealthPills('run1')) + def testExceptions(self): x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'}) with self.assertRaises(KeyError): diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-no-data-warning.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-no-data-warning.html index dbc1dc5c5fa..c90efac1d6b 100644 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-no-data-warning.html +++ b/tensorflow/tensorboard/components/tf_dashboard_common/tf-no-data-warning.html @@ -34,10 +34,9 @@ Display a warning when there is no data found. and pass the graph either via the constructor, or by calling its add_graph() method. You may want to check out the - + graph visualizer tutorial - - . + .