diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index eb3da111a24..7ad33c8947c 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; namespace xla { @@ -54,11 +55,18 @@ StatusOr HloPassPipeline::Run(HloModule* module) { << tensorflow::str_util::Join(disabled_passes, ", "); } - auto run_invariant_checkers = [this, module]() -> Status { + auto run_invariant_checkers = [this, + module](const string& message) -> Status { for (auto& invariant_checker : invariant_checkers_) { VLOG(1) << " Invariant checker " << invariant_checker->name(); - TF_ASSIGN_OR_RETURN(bool changed, invariant_checker->Run(module)); - TF_RET_CHECK(!changed) << "invariant checkers must not change the graph"; + StatusOr changed_status = invariant_checker->Run(module); + if (!changed_status.ok()) { + return Status(changed_status.status().code(), + StrCat(changed_status.status().error_message(), + "\n\nFailed ", message)); + } + TF_RET_CHECK(!changed_status.ValueOrDie()) + << "invariant checkers must not change the graph"; } return Status::OK(); }; @@ -66,6 +74,8 @@ StatusOr HloPassPipeline::Run(HloModule* module) { string prefix = name().ToString() + ": pipeline start"; bool changed = false; string message; + TF_RETURN_IF_ERROR( + run_invariant_checkers(StrCat("before running pipeline: ", name()))); for (auto& pass : passes_) { if (disabled_passes.count(pass->name().ToString()) > 0) { VLOG(1) << " Skipping HLO pass " << pass->name() @@ -80,14 +90,14 @@ StatusOr HloPassPipeline::Run(HloModule* module) { StrAppend(&message, prefix, ", before ", pass->name()); DumpModule(*module, message); - TF_RETURN_IF_ERROR(run_invariant_checkers()); TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module)); + TF_RETURN_IF_ERROR( + run_invariant_checkers(StrCat("after running pass: ", pass->name()))); changed |= changed_this_pass; prefix.clear(); StrAppend(&prefix, name(), ": after ", pass->name()); } - TF_RETURN_IF_ERROR(run_invariant_checkers()); DumpModule(*module, prefix + ", pipeline end"); return changed; } diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 20152cf0cef..6e5d7bca75c 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -1202,7 +1202,7 @@ StatusOr HloRematerialization::RematerializeComputation( StatusOr HloRematerialization::Run( HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence, - int64 memory_limit_bytes) { + int64 memory_limit_bytes, RematerializationSizes* sizes) { // The sequence is constructed entirely by this method. TF_RET_CHECK(sequence->empty()); @@ -1319,13 +1319,20 @@ StatusOr HloRematerialization::Run( << HumanReadableNumBytes(reduced_peak_memory) << " (" << reduced_peak_memory << " bytes)"; + if (sizes != nullptr) { + sizes->before_bytes = before_peak_memory; + sizes->after_bytes = current_peak_memory; + } + XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString()); if (current_peak_memory > memory_limit_bytes) { - LOG(WARNING) << "Can't reduce memory use below " - << HumanReadableNumBytes(memory_limit_bytes) - << " by rematerialization (only reduced to " - << HumanReadableNumBytes(current_peak_memory) << ")"; + LOG(WARNING) << tensorflow::strings::Printf( + "Can't reduce memory use below %s (%lld bytes) by rematerialization; " + "only reduced to %s (%lld bytes)", + HumanReadableNumBytes(memory_limit_bytes).c_str(), memory_limit_bytes, + HumanReadableNumBytes(current_peak_memory).c_str(), + current_peak_memory); } return changed; @@ -1334,9 +1341,10 @@ StatusOr HloRematerialization::Run( /* static */ StatusOr HloRematerialization::RematerializeAndSchedule( const HloRematerialization::ShapeSizeFunction& size_function, int64 memory_limit_bytes, HloModule* hlo_module, - SequentialHloOrdering::HloModuleSequence* sequence) { + SequentialHloOrdering::HloModuleSequence* sequence, + RematerializationSizes* sizes) { HloRematerialization remat(size_function); - return remat.Run(hlo_module, sequence, memory_limit_bytes); + return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 42c279d440b..11f79a6d415 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -28,6 +28,13 @@ class HloRematerialization { public: using ShapeSizeFunction = std::function; + // Helper struct that communicates the before / after sizes for the + // rematerialization process. + struct RematerializationSizes { + int64 before_bytes; + int64 after_bytes; + }; + // Rematerialize HLO instructions in the given module to reduce peak memory // use below memory_limit_bytes where memory use is defined as the total size // of all live HLO instruction values. Parameters and constants are included @@ -46,6 +53,9 @@ class HloRematerialization { // rematerialization. This is the order in which HLO instructions should // be emitted to minimize memory use. // + // sizes: Optional outparam that indicates the peak memory usage of the HLO + // module before/after rematerialization. + // // Returns whether any instructions were rematerialized. If memory use is // already below the given limit then no instructions are rematerialized and // false is returned. @@ -55,8 +65,8 @@ class HloRematerialization { // code generation. static StatusOr RematerializeAndSchedule( const ShapeSizeFunction& size_function, int64 memory_limit_bytes, - HloModule* hlo_module, - SequentialHloOrdering::HloModuleSequence* sequence); + HloModule* hlo_module, SequentialHloOrdering::HloModuleSequence* sequence, + RematerializationSizes* sizes = nullptr); protected: HloRematerialization(const ShapeSizeFunction& size_function) @@ -69,7 +79,7 @@ class HloRematerialization { // contains the memory-minimizing order in which to emit the HLO instructions. StatusOr Run(HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence, - int64 memory_limit); + int64 memory_limit, RematerializationSizes* sizes); // Rematerializes instructions within the given computation. 'order' is the // order in which the computation's instructions will be emitted in the diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc index 33327dc60fb..8275531111c 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc @@ -123,7 +123,8 @@ StatusOr ReducePrecisionInsertion::insert_on_inputs( } if (instruction->opcode() == HloOpcode::kFusion && - instruction->fusion_kind() == HloInstruction::FusionKind::kLoop) { + (instruction->fusion_kind() == HloInstruction::FusionKind::kLoop || + instruction->fusion_kind() == HloInstruction::FusionKind::kInput)) { // Insert the reduce-precision operation inside the fusion computation, // after the corresponding parameter instruction. TF_ASSIGN_OR_RETURN( @@ -173,7 +174,8 @@ StatusOr ReducePrecisionInsertion::insert_on_outputs( } if (instruction->opcode() == HloOpcode::kFusion && - instruction->fusion_kind() == HloInstruction::FusionKind::kLoop) { + (instruction->fusion_kind() == HloInstruction::FusionKind::kLoop || + instruction->fusion_kind() == HloInstruction::FusionKind::kOutput)) { // Insert the reduce-precision operation as the last operation inside // the fusion computation. HloInstruction* fusion_root = instruction->fused_expression_root(); diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index 15850bf0a4e..eb02f20457e 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -240,6 +240,8 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/training/quantize_training_test.py" # Needs quantization ops to be included in windows. "${tensorflow_source_dir}/tensorflow/python/training/supervisor_test.py" # Flaky I/O error on rename. "${tensorflow_source_dir}/tensorflow/python/training/sync_replicas_optimizer_test.py" # Needs portpicker. + "${tensorflow_source_dir}/tensorflow/python/training/server_lib_test.py" # Test occasionally deadlocks. + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" # depends on python/framework/test_ops # Broken tensorboard test due to cmake issues. "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" diff --git a/tensorflow/core/framework/tensor_testutil.h b/tensorflow/core/framework/tensor_testutil.h index ab224aa7188..4c216a84f04 100644 --- a/tensorflow/core/framework/tensor_testutil.h +++ b/tensorflow/core/framework/tensor_testutil.h @@ -166,10 +166,11 @@ struct Expector { static void Equal(const Tensor& x, const Tensor& y) { ASSERT_EQ(x.dtype(), DataTypeToEnum::v()); AssertSameTypeDims(x, y); - auto a = x.flat(); - auto b = y.flat(); - for (int i = 0; i < a.size(); ++i) { - ExpectEqual(a(i), b(i)); + const auto size = x.NumElements(); + const T* a = x.flat().data(); + const T* b = y.flat().data(); + for (int i = 0; i < size; ++i) { + ExpectEqual(a[i], b[i]); } } }; @@ -182,10 +183,11 @@ struct Expector { static void Equal(const Tensor& x, const Tensor& y) { ASSERT_EQ(x.dtype(), DataTypeToEnum::v()); AssertSameTypeDims(x, y); - auto a = x.flat(); - auto b = y.flat(); - for (int i = 0; i < a.size(); ++i) { - ExpectEqual(a(i), b(i)); + const auto size = x.NumElements(); + const T* a = x.flat().data(); + const T* b = y.flat().data(); + for (int i = 0; i < size; ++i) { + ExpectEqual(a[i], b[i]); } } @@ -199,10 +201,11 @@ struct Expector { static void Near(const Tensor& x, const Tensor& y, const double abs_err) { ASSERT_EQ(x.dtype(), DataTypeToEnum::v()); AssertSameTypeDims(x, y); - auto a = x.flat(); - auto b = y.flat(); - for (int i = 0; i < a.size(); ++i) { - Near(a(i), b(i), abs_err, i); + const auto size = x.NumElements(); + const T* a = x.flat().data(); + const T* b = y.flat().data(); + for (int i = 0; i < size; ++i) { + Near(a[i], b[i], abs_err, i); } } }; diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index 9de8642d417..bbb9e36fc9d 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -529,7 +529,7 @@ void LaunchConv2DOp::operator()( const int64 out_depths = GetTensorDim(*output, data_format, 'C'); const int64 patch_rows = filter.dim_size(0); const int64 patch_cols = filter.dim_size(1); - if (padding == Eigen::PADDING_SAME) { + if (padding == SAME) { // Total padding on rows and cols is // Pr = (R' - 1) * S + Kr - R // Pc = (C' - 1) * S + Kc - C diff --git a/tensorflow/core/kernels/ctc_loss_op.cc b/tensorflow/core/kernels/ctc_loss_op.cc index a1f60019141..fb03adb7a53 100644 --- a/tensorflow/core/kernels/ctc_loss_op.cc +++ b/tensorflow/core/kernels/ctc_loss_op.cc @@ -91,7 +91,14 @@ class CTCLossOp : public OpKernel { OP_REQUIRES(ctx, batch_size != 0, errors::InvalidArgument("batch_size must not be 0")); - TensorShape labels_shape({batch_size, max_time}); + // Figure out the maximum label length to use as sparse tensor dimension. + auto labels_indices_t = labels_indices->matrix(); + int64 max_label_len = 0; + for (int i = 0; i < labels_indices->dim_size(0); i++) { + max_label_len = std::max(max_label_len, labels_indices_t(i, 1) + 1); + } + + TensorShape labels_shape({batch_size, max_label_len}); std::vector order{0, 1}; sparse::SparseTensor labels_sp(*labels_indices, *labels_values, labels_shape, order); diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 73b7f821c82..04c7554a580 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -392,6 +392,7 @@ class TensorFlowTestCase(googletest.TestCase): self._cached_session = None def setUp(self): + logging.info("SET UP: %s" % str(self)) self._ClearCachedSession() random.seed(random_seed.DEFAULT_GRAPH_SEED) np.random.seed(random_seed.DEFAULT_GRAPH_SEED) @@ -406,6 +407,7 @@ class TensorFlowTestCase(googletest.TestCase): ops.get_default_graph().seed = random_seed.DEFAULT_GRAPH_SEED def tearDown(self): + logging.info("TEAR DOWN: %s" % str(self)) for thread in self._threads: self.assertFalse(thread.is_alive(), "A checkedThread did not terminate") diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 1471b5909eb..2cae16f44cc 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -505,6 +505,8 @@ class ResourceVariable(variables.Variable): def sparse_read(self, indices, name=None): """Reads the value of this variable sparsely, using `gather`.""" with ops.name_scope("Gather" if name is None else name) as name: + if self._trainable: + tape.watch(self._handle) value = resource_gather( self._handle, indices, dtype=self._dtype, name=name) return array_ops.identity(value) diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 176719fabb4..ef342fe1272 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -699,9 +699,9 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.http_archive( name = "bazel_toolchains", urls = [ - "http://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/bccee4855c049d34bac481083b4c68e2fab8cc50.tar.gz", - "https://github.com/bazelbuild/bazel-toolchains/archive/bccee4855c049d34bac481083b4c68e2fab8cc50.tar.gz", + "http://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/9dbd803ad3b9447430a296810197b09b3a710956.tar.gz", + "https://github.com/bazelbuild/bazel-toolchains/archive/9dbd803ad3b9447430a296810197b09b3a710956.tar.gz", ], - sha256 = "3903fd93b96b42067e00b7973a2c16c34e761ad7a0b55e1557d408f352849e41", - strip_prefix = "bazel-toolchains-bccee4855c049d34bac481083b4c68e2fab8cc50", + sha256 = "0799aa12db5260a499beb40f81744e760c59d055bfc5d271dd2c2ed4d5419faa", + strip_prefix = "bazel-toolchains-9dbd803ad3b9447430a296810197b09b3a710956", )