diff --git a/tensorflow/c/generate-pc.sh b/tensorflow/c/generate-pc.sh new file mode 100755 index 00000000000..40b3a60be9f --- /dev/null +++ b/tensorflow/c/generate-pc.sh @@ -0,0 +1,63 @@ +#!/usr/bin/env bash +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +TF_PREFIX='/usr/local' + +usage() { + echo "Usage: $0 OPTIONS" + echo -e "-p, --prefix\tset installation prefix (default: /usr/local)" + echo -e "-v, --version\tset TensorFlow version" + echo -e "-h, --help\tdisplay this message" +} + +# read the options +ARGS=`getopt -o p:v:h --long prefix:,version:,help -n $0 -- "$@"` +eval set -- "$ARGS" + +# extract options and their arguments into variables. +while true ; do + case "$1" in + -h|--help) usage ; exit ;; + -p|--prefix) + case "$2" in + "") shift 2 ;; + *) TF_PREFIX=$2 ; shift 2 ;; + esac ;; + -v|--version) + case "$2" in + "") shift 2 ;; + *) TF_VERSION=$2 ; shift 2 ;; + esac ;; + --) shift ; echo "Try '$0 --help' for more information."; exit 1 ;; + *) echo "Internal error! Try '$0 --help' for more information." ; exit 1 ;; + esac +done + +echo "Generating pkgconfig file for TensorFlow $TF_VERSION in $TF_PREFIX" + +cat << EOF > tensorflow.pc +prefix=${TF_PREFIX} +exec_prefix=\${prefix} +libdir=\${exec_prefix}/lib +includedir=\${prefix}/include + +Name: TensorFlow +Version: ${TF_VERSION} +Description: Library for computation using data flow graphs for scalable machine learning +Requires: +Libs: -L\${libdir} -ltensorflow +Cflags: -I\${includedir} +EOF diff --git a/tensorflow/cc/framework/gradients_test.cc b/tensorflow/cc/framework/gradients_test.cc index 7783bdce3a7..6a249825812 100644 --- a/tensorflow/cc/framework/gradients_test.cc +++ b/tensorflow/cc/framework/gradients_test.cc @@ -260,7 +260,7 @@ TEST_F(GradientsTest, StackUnstack_StopBackprop) { } TEST_F(GradientsTest, DependentGradOutputs) { - // Tests that dependant gradients (in this case the gradients w.r.t to the + // Tests that dependent gradients (in this case the gradients w.r.t to the // output and one input of MatMul) are computed properly. // Create two chained MatMul ops. diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index b144bfc33e4..908aa01a347 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -36,7 +36,7 @@ auto* load_attempt_count = monitoring::Counter<2>::New( "status"); auto* load_latency = monitoring::Counter<1>::New( "/tensorflow/cc/saved_model/load_latency", - "Latency in microseconds for SavedModels that were succesfully loaded.", + "Latency in microseconds for SavedModels that were successfully loaded.", "model_path"); constexpr char kLoadAttemptFail[] = "fail"; constexpr char kLoadAttemptSuccess[] = "success"; diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index a53e82d34ba..bbdb342a623 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -365,7 +365,7 @@ Status GenerateHeader(const HeaderOpts& opts, const Config& config, #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -namespace Eigen { class ThreadPoolDevice; } +namespace Eigen { struct ThreadPoolDevice; } // (Implementation detail) Entry point to the function in the object file. extern "C" void {{ENTRY}}( diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 46d7c03006a..01963c6df46 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -15,7 +15,7 @@ #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -namespace Eigen { class ThreadPoolDevice; } +namespace Eigen { struct ThreadPoolDevice; } // (Implementation detail) Entry point to the function in the object file. extern "C" void entry_point( diff --git a/tensorflow/compiler/aot/runtime.cc b/tensorflow/compiler/aot/runtime.cc index 208de5498db..57727766661 100644 --- a/tensorflow/compiler/aot/runtime.cc +++ b/tensorflow/compiler/aot/runtime.cc @@ -31,6 +31,8 @@ namespace { inline void* aligned_malloc(size_t size, int minimum_alignment) { #if defined(__ANDROID__) || defined(OS_ANDROID) || defined(OS_CYGWIN) return memalign(minimum_alignment, size); +#elif defined(COMPILER_MSVC) + return _aligned_malloc(size, minimum_alignment); #else // !__ANDROID__ && !OS_ANDROID && !OS_CYGWIN void* ptr = nullptr; // posix_memalign requires that the requested alignment be at least @@ -45,7 +47,13 @@ inline void* aligned_malloc(size_t size, int minimum_alignment) { #endif } -inline void aligned_free(void* aligned_memory) { free(aligned_memory); } +inline void aligned_free(void* aligned_memory) { +#if defined(COMPILER_MSVC) + _aligned_free(aligned_memory); +#else + free(aligned_memory); +#endif +} size_t align_to(size_t n, size_t align) { return (((n - 1) / align) + 1) * align; diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 4fd68a94a58..19f7ff83545 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -170,6 +170,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "slice_ops_test", + size = "small", + srcs = ["slice_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "function_test", size = "small", diff --git a/tensorflow/compiler/tests/slice_ops_test.py b/tensorflow/compiler/tests/slice_ops_test.py new file mode 100644 index 00000000000..de91b7b4252 --- /dev/null +++ b/tensorflow/compiler/tests/slice_ops_test.py @@ -0,0 +1,132 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for slicing.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import googletest + + +class SliceTest(XLATestCase): + + def test1D(self): + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[10]) + with self.test_scope(): + o = array_ops.slice(i, [2], [4]) + params = { + i: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([2, 3, 4, 5], result) + + def test3D(self): + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[3, 3, 10]) + with self.test_scope(): + o = array_ops.slice(i, [1, 2, 2], [1, 1, 4]) + params = { + i: [[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [9, 8, 7, 6, 5, 4, 3, 2, 1, 0], + [5, 3, 1, 7, 9, 2, 4, 6, 8, 0]], + [[5, 5, 5, 5, 5, 5, 5, 5, 5, 5], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [8, 7, 6, 5, 4, 3, 2, 1, 8, 7]], + [[7, 5, 7, 5, 7, 5, 7, 5, 7, 5], [1, 2, 1, 2, 1, 2, 1, 2, 1, 2], + [9, 8, 7, 9, 8, 7, 9, 8, 7, 9]]] + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([[[6, 5, 4, 3]]], result) + + +class StridedSliceTest(XLATestCase): + + def test1D(self): + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[10]) + with self.test_scope(): + o = array_ops.strided_slice(i, [2], [6], [2]) + params = { + i: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([2, 4], result) + + def test1DNegtiveStride(self): + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[10]) + with self.test_scope(): + o = array_ops.strided_slice(i, [6], [2], [-2]) + params = { + i: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([6, 4], result) + + def test3D(self): + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[3, 3, 10]) + with self.test_scope(): + o = array_ops.strided_slice(i, [0, 2, 2], [2, 3, 6], [1, 1, 2]) + params = { + i: [[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [9, 8, 7, 6, 5, 4, 3, 2, 1, 0], + [5, 3, 1, 7, 9, 2, 4, 6, 8, 0]], + [[5, 5, 5, 5, 5, 5, 5, 5, 5, 5], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [8, 7, 6, 5, 4, 3, 2, 1, 8, 7]], + [[7, 5, 7, 5, 7, 5, 7, 5, 7, 5], [1, 2, 1, 2, 1, 2, 1, 2, 1, 2], + [9, 8, 7, 9, 8, 7, 9, 8, 7, 9]]] + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([[[1, 9]], [[6, 4]]], result) + + def test3DNegativeStride(self): + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[3, 4, 10]) + with self.test_scope(): + o = array_ops.strided_slice(i, [2, 2, 6], [0, 0, 2], [-1, -1, -2]) + params = { + i: [[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [9, 8, 7, 6, 5, 4, 3, 2, 1, 0], + [5, 3, 1, 7, 9, 2, 4, 6, 8, 0], [4, 5, 2, 4, 3, 7, 6, 8, 9, + 4]], + [[5, 5, 5, 5, 5, 5, 5, 5, 5, 5], [4, 3, 4, 5, 7, 6, 5, 3, 4, 5], + [8, 7, 6, 5, 4, 3, 2, 1, 8, 7], [7, 1, 7, 1, 8, 1, 8, 1, 3, + 1]], + [[7, 5, 7, 5, 7, 5, 7, 5, 7, 5], [1, 2, 1, 2, 1, 2, 1, 2, 1, 2], + [9, 8, 7, 9, 8, 7, 9, 8, 7, 9], [9, 9, 5, 5, 6, 6, 3, 3, 6, + 6]]] + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([[[9, 8], [1, 1]], [[2, 4], [5, 7]]], result) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 7a18c1e3750..12537b97654 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -45,6 +45,7 @@ cc_library( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation", diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc index eff23bd77d2..691a0b972d5 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/gather_functor.h" #include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" namespace tensorflow { @@ -63,7 +64,6 @@ EIGEN_STRONG_INLINE void gather_float_int32_xla_impl(float* out, void** data) { // Implements gather on CPU. This is called by an XLA custom call, set up by // gather_op.cc. -extern "C" void __attribute__((visibility("default"))) -gather_float_int32_xla_impl(float* out, void** data) { +extern "C" void TF_EXPORT gather_float_int32_xla_impl(float* out, void** data) { tensorflow::gather_float_int32_xla_impl(out, data); } diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc index ae31f6f2006..3dff6e2737b 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/gather_functor.h" #include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" namespace tensorflow { @@ -63,7 +64,6 @@ EIGEN_STRONG_INLINE void gather_float_int64_xla_impl(float* out, void** data) { // Implements gather on CPU. This is called by an XLA custom call, set up by // gather_op.cc. -extern "C" void __attribute__((visibility("default"))) -gather_float_int64_xla_impl(float* out, void** data) { +extern "C" void TF_EXPORT gather_float_int64_xla_impl(float* out, void** data) { tensorflow::gather_float_int64_xla_impl(out, data); } diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc index 0033a949a37..afbd64ca503 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc @@ -18,6 +18,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -43,7 +44,6 @@ EIGEN_STRONG_INLINE void argmax_float_1d_xla_impl(void* out, void** data) { // Implements argmax on CPU. This is called by an XLA custom call, set up by // index_ops.cc. -extern "C" void __attribute__((visibility("default"))) -argmax_float_1d_xla_impl(void* out, void** data) { +extern "C" void TF_EXPORT argmax_float_1d_xla_impl(void* out, void** data) { tensorflow::argmax_float_1d_xla_impl(out, data); } diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc index be8ad2317c9..841ff2f4df7 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc @@ -18,6 +18,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -45,7 +46,6 @@ EIGEN_STRONG_INLINE void argmax_float_2d_xla_impl(void* out, void** data) { // Implements argmax on CPU. This is called by an XLA custom call, set up by // index_ops.cc. -extern "C" void __attribute__((visibility("default"))) -argmax_float_2d_xla_impl(void* out, void** data) { +extern "C" void TF_EXPORT argmax_float_2d_xla_impl(void* out, void** data) { tensorflow::argmax_float_2d_xla_impl(out, data); } diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 211412d463d..a6cac62ca4b 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -77,11 +77,9 @@ class StridedSliceOp : public XlaOpKernel { gtl::InlinedVector<int64, 4> dimensions_to_reverse; gtl::InlinedVector<int64, 4> slice_begin, slice_end; + bool simple_strides = true; for (int i = 0; i < begin.size(); ++i) { - // TODO(phawkins): implement strides != 1 when b/30878775 is fixed. - OP_REQUIRES( - ctx, strides[i] == 1 || strides[i] == -1, - errors::Unimplemented("Strides != 1 or -1 are not yet implemented")); + simple_strides &= (std::abs(strides[i]) == 1); if (strides[i] > 0) { slice_begin.push_back(begin[i]); slice_end.push_back(end[i]); @@ -99,6 +97,35 @@ class StridedSliceOp : public XlaOpKernel { slice = ctx->builder()->Rev(slice, dimensions_to_reverse); } + // If at least one of the strides is > 1 (or < -1) then use Slice + // to pull out each of the strided slices, and Concat to put them + // together again. + if (!simple_strides) { + // Re-adjust the begin and end now that the periphery has been + // sliced away. + for (int d = 0; d < strides.size(); ++d) { + slice_end[d] -= slice_begin[d]; + slice_begin[d] = 0; + } + + for (int d = 0; d < strides.size(); ++d) { + int64 stride = std::abs(strides[d]); + if (stride > 1) { + std::vector<xla::ComputationDataHandle> to_concat; + int64 end = slice_end[d]; + for (int64 i = 0; i < end; i += stride) { + slice_begin[d] = i; + slice_end[d] = i + 1; + to_concat.push_back( + ctx->builder()->Slice(slice, slice_begin, slice_end)); + } + slice = ctx->builder()->ConcatInDim(to_concat, d); + slice_begin[d] = 0; + slice_end[d] = to_concat.size(); + } + } + } + slice = ctx->builder()->Reshape(slice, final_shape.dim_sizes()); ctx->SetOutput(0, slice); } diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 10d8b67bbd2..f060f8f2f17 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -89,7 +90,9 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral( case xla::U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case xla::F16: - LOG(FATAL) << "f16 literals not yet implemented"; + literal = + *xla::LiteralUtil::CreateR0<xla::half>(static_cast<xla::half>(value)); + break; case xla::TUPLE: LOG(FATAL) << "tuple element type is not integral"; case xla::OPAQUE: @@ -107,6 +110,9 @@ xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b, xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); switch (type) { + case xla::F16: + return b->ConstantR0<xla::half>(static_cast<xla::half>(value)); + break; case xla::F32: return b->ConstantR0<float>(static_cast<float>(value)); break; diff --git a/tensorflow/compiler/tf2xla/xla_local_runtime_context.h b/tensorflow/compiler/tf2xla/xla_local_runtime_context.h index cd773d64ed4..dca420d6ee3 100644 --- a/tensorflow/compiler/tf2xla/xla_local_runtime_context.h +++ b/tensorflow/compiler/tf2xla/xla_local_runtime_context.h @@ -23,7 +23,7 @@ limitations under the License. // actually used. E.g. some ahead-of-time compiled computations don't need a // thread pool. namespace Eigen { -class ThreadPoolDevice; +struct ThreadPoolDevice; } namespace tensorflow { diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index de09d4b23f8..96685919e91 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -59,7 +59,10 @@ cc_library( name = "types", hdrs = ["types.h"], visibility = [":friends"], - deps = ["//tensorflow/core:lib"], + deps = [ + "//tensorflow/core:lib", + "//third_party/eigen3", + ], ) cc_library( diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index e3bc856fc01..ec4012a7036 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -148,6 +148,9 @@ template <typename T> case S64: return CopyRange<int64>(src_literal, src_base, dest_literal, dest_base, copy_size); + case F16: + return CopyRange<half>(src_literal, src_base, dest_literal, dest_base, + copy_size); case F32: return CopyRange<float>(src_literal, src_base, dest_literal, dest_base, copy_size); @@ -178,6 +181,8 @@ template <typename T> return *LiteralUtil::CreateR0<int32>(0); case S64: return *LiteralUtil::CreateR0<int64>(0); + case F16: + return *LiteralUtil::CreateR0<half>(static_cast<half>(0.0f)); case F32: return *LiteralUtil::CreateR0<float>(0); case F64: @@ -187,8 +192,6 @@ template <typename T> case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; - case F16: - LOG(FATAL) << "f16 literals not yet implemented"; case TUPLE: LOG(FATAL) << "tuple element type cannot take on value of 0"; case OPAQUE: @@ -222,7 +225,7 @@ template <typename T> case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - LOG(FATAL) << "f16 literals not yet implemented"; + return *LiteralUtil::CreateR0<half>(static_cast<half>(1.0f)); case TUPLE: LOG(FATAL) << "tuple element type cannot take on value of 1"; case OPAQUE: @@ -258,7 +261,8 @@ template <typename T> case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - LOG(FATAL) << "f16 literals not yet implemented"; + return *LiteralUtil::CreateR0<half>( + static_cast<half>(-std::numeric_limits<float>::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no minimum value"; case OPAQUE: @@ -294,7 +298,8 @@ template <typename T> case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - LOG(FATAL) << "f16 literals not yet implemented"; + return *LiteralUtil::CreateR0<half>( + static_cast<half>(std::numeric_limits<float>::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no maximum value"; case OPAQUE: @@ -498,6 +503,8 @@ template <typename T> return tensorflow::strings::StrCat(Get<float>(literal, multi_index)); case F64: return tensorflow::strings::StrCat(Get<double>(literal, multi_index)); + case F16: + return tensorflow::strings::StrCat(Get<half>(literal, multi_index)); default: return tensorflow::strings::StrCat( "[", PrimitiveType_Name(literal.shape().element_type()), "]"); @@ -652,6 +659,8 @@ template <typename T> return reinterpret_cast<const void*>(literal.f32s().data()); case F64: return reinterpret_cast<const void*>(literal.f64s().data()); + case F16: + return reinterpret_cast<const void*>(literal.f16s().data()); default: LOG(FATAL) << "primitive type not supported in literals: " << PrimitiveType_Name(literal.shape().element_type()); @@ -692,6 +701,9 @@ template <typename T> case F64: Resize<double>(num_elements, 0, literal); break; + case F16: + Resize<half>(num_elements, static_cast<half>(0.0f), literal); + break; default: LOG(FATAL) << "primitive type not supported in literals: " << PrimitiveType_Name(literal->shape().element_type()); @@ -728,6 +740,9 @@ template <typename T> case F64: actual = literal.f64s_size(); break; + case F16: + actual = literal.f16s().size() / sizeof(half); + break; default: return tensorflow::errors::Unimplemented( "unhandled element type for literal validation: " + @@ -818,6 +833,8 @@ bool EqualElements(const Literal& literal1, const Literal& literal2, return EqualElements<float>(literal1, literal2, 0, &multi_index); case F64: return EqualElements<double>(literal1, literal2, 0, &multi_index); + case F16: + return EqualElements<half>(literal1, literal2, 0, &multi_index); default: LOG(FATAL) << "Unimplemented: LiteralUtil::Equal for type " << PrimitiveType_Name(literal1.shape().element_type()); @@ -916,6 +933,18 @@ LiteralUtil::GetMutableArraySlice(Literal* literal) { values->size()); } +template <> +/* static */ tensorflow::gtl::MutableArraySlice<half> +LiteralUtil::GetMutableArraySlice<half>(Literal* literal) { + // C++11 standard, basic_string 21.4.1.5, values should be stored + // contiguously. From C++17 a mutable data() member will be provided. + // TODO - there is an endianess problem here. fix it, or wait for uint16 + // support in protobuf + auto values = literal->mutable_f16s(); + return tensorflow::gtl::MutableArraySlice<half>( + reinterpret_cast<half*>(&(*values)[0]), values->size() / sizeof(half)); +} + template <> /* static */ tensorflow::gtl::ArraySlice<bool> LiteralUtil::GetArraySlice<bool>( const Literal& literal) { @@ -976,6 +1005,15 @@ LiteralUtil::GetArraySlice<double>(const Literal& literal) { return literal.f64s(); } +template <> +/* static */ tensorflow::gtl::ArraySlice<half> LiteralUtil::GetArraySlice<half>( + const Literal& literal) { + CHECK_EQ(literal.shape().element_type(), F16); + return tensorflow::gtl::ArraySlice<half>( + reinterpret_cast<const half*>(literal.f16s().data()), + literal.f16s().size() / sizeof(half)); +} + template <typename NativeT> static bool AllElementsEqualValue(const Literal& literal, NativeT value) { for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { @@ -1015,6 +1053,8 @@ static bool AllElementsEqualValue(const Literal& literal, NativeT value) { return AllElementsEqualValue<float>(literal, value); case F64: return AllElementsEqualValue<double>(literal, value); + case F16: + return AllElementsEqualValue<half>(literal, static_cast<half>(value)); case PRED: if (value == 0) { return AllElementsEqualValue<bool>(literal, false); @@ -1034,6 +1074,8 @@ static bool AllElementsEqualValue(const Literal& literal, NativeT value) { return AllElementsEqualValue<float>(literal, value); case F64: return AllElementsEqualValue<double>(literal, value); + case F16: + return AllElementsEqualValue<half>(literal, static_cast<half>(value)); default: return false; } @@ -1058,6 +1100,8 @@ static bool AllElementsEqualValue(const Literal& literal, NativeT value) { return Get<float>(literal, indices) == 0.0f; case F64: return Get<double>(literal, indices) == 0.0; + case F16: + return Get<half>(literal, indices) == static_cast<half>(0.0f); case PRED: return Get<bool>(literal, indices) == false; default: @@ -1128,4 +1172,15 @@ template <> literal->mutable_f64s()->Resize(num_elements, value); } +template <> +/* static */ void LiteralUtil::Resize<half>(int64 num_elements, half value, + Literal* literal) { + CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); + literal->mutable_f16s()->resize(num_elements * sizeof(half)); + auto data = GetMutableArraySlice<half>(literal); + for (int i = 0; i < num_elements; i++) { + data[i] = value; + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 0ea71860408..a05dc968ee5 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -505,6 +505,10 @@ template <> /* static */ tensorflow::gtl::ArraySlice<double> LiteralUtil::GetArraySlice<double>(const Literal& literal); +template <> +/* static */ tensorflow::gtl::ArraySlice<half> LiteralUtil::GetArraySlice<half>( + const Literal& literal); + template <> /* static */ tensorflow::gtl::MutableArraySlice<bool> LiteralUtil::GetMutableArraySlice(Literal* literal); @@ -541,6 +545,50 @@ template <> /* static */ tensorflow::gtl::MutableArraySlice<double> LiteralUtil::GetMutableArraySlice(Literal* literal); +template <> +/* static */ tensorflow::gtl::MutableArraySlice<half> +LiteralUtil::GetMutableArraySlice(Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize<bool>(int64 num_elements, bool value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize<int8>(int64 num_elements, int8 value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize<uint8>(int64 num_elements, uint8 value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize<int32>(int64 num_elements, int32 value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize<uint32>(int64 num_elements, uint32 value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize<int64>(int64 num_elements, int64 value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize<uint64>(int64 num_elements, uint64 value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize<float>(int64 num_elements, float value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize<double>(int64 num_elements, double value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize<half>(int64 num_elements, half value, + Literal* literal); + template <typename NativeT> /* static */ std::unique_ptr<Literal> LiteralUtil::CreateR0(NativeT value) { auto literal = MakeUnique<Literal>(); @@ -770,6 +818,14 @@ template <> return literal.u8s()[linear_index]; } +template <> +/* static */ inline half LiteralUtil::Get<half>( + const Literal& literal, tensorflow::gtl::ArraySlice<int64> multi_index) { + CHECK(literal.shape().element_type() == F16); + int64 linear_index = LinearIndex(literal, multi_index); + return GetArraySlice<half>(literal)[linear_index]; +} + template <typename NativeT> /* static */ void LiteralUtil::Set( Literal* literal, tensorflow::gtl::ArraySlice<int64> multi_index, @@ -834,76 +890,12 @@ template <typename NativeT> } while (IndexUtil::BumpIndices(literal.shape(), &indices)); } -template <> -/* static */ inline void LiteralUtil::PopulateR0<bool>(bool value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<bool>(), {}); - literal->mutable_preds()->Add(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0<uint8>(uint8 value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<uint8>(), {}); - literal->mutable_u8s()->push_back(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0<int8>(int8 value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<int8>(), {}); - literal->mutable_u8s()->push_back(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0<uint32>(uint32 value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<uint32>(), {}); - literal->mutable_u32s()->Add(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0<int32>(int32 value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<int32>(), {}); - literal->mutable_s32s()->Add(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0<uint64>(uint64 value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<uint64>(), {}); - literal->mutable_u64s()->Add(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0<int64>(int64 value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<int64>(), {}); - literal->mutable_s64s()->Add(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0<float>(float value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<float>(), {}); - literal->mutable_f32s()->Add(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0<double>(double value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<double>(), {}); - literal->mutable_f64s()->Add(value); +template <typename NativeT> +/* static */ inline void LiteralUtil::PopulateR0(NativeT value, + Literal* literal) { + *literal->mutable_shape() = ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType<NativeT>(), {}); + Resize<NativeT>(1, value, literal); } template <typename NativeT> @@ -1116,42 +1108,6 @@ template <typename NativeSrcT, typename NativeDestT> return result_literal; } -template <> -/* static */ void LiteralUtil::Resize<bool>(int64 num_elements, bool value, - Literal* literal); - -template <> -/* static */ void LiteralUtil::Resize<int8>(int64 num_elements, int8 value, - Literal* literal); - -template <> -/* static */ void LiteralUtil::Resize<uint8>(int64 num_elements, uint8 value, - Literal* literal); - -template <> -/* static */ void LiteralUtil::Resize<int32>(int64 num_elements, int32 value, - Literal* literal); - -template <> -/* static */ void LiteralUtil::Resize<uint32>(int64 num_elements, uint32 value, - Literal* literal); - -template <> -/* static */ void LiteralUtil::Resize<int64>(int64 num_elements, int64 value, - Literal* literal); - -template <> -/* static */ void LiteralUtil::Resize<uint64>(int64 num_elements, uint64 value, - Literal* literal); - -template <> -/* static */ void LiteralUtil::Resize<float>(int64 num_elements, float value, - Literal* literal); - -template <> -/* static */ void LiteralUtil::Resize<double>(int64 num_elements, double value, - Literal* literal); - template <typename NativeT> /* static */ std::unique_ptr<Literal> LiteralUtil::CreateFullWithMonotonicDim0MajorLayout( diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 7acb9933daa..9a09822174d 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -105,6 +105,9 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { auto f32_lit = LiteralUtil::CreateR0<float>(3.14f); ASSERT_EQ("3.14", LiteralUtil::ToString(*f32_lit)); + + auto f16_lit = LiteralUtil::CreateR0<half>(static_cast<half>(0.5f)); + ASSERT_EQ("0.5", LiteralUtil::ToString(*f16_lit)); } TEST_F(LiteralUtilTest, LiteralVectorToString) { @@ -373,6 +376,15 @@ TEST_F(LiteralUtilTest, IsAll) { EXPECT_FALSE( LiteralUtil::IsAll(*LiteralUtil::CreateR2<uint64>({{9, 8}, {8, 8}}), 8)); + half h8(8.0f); + half h9(9.0f); + EXPECT_TRUE( + LiteralUtil::IsAll(*LiteralUtil::CreateR2<half>({{h8}, {h8}}), 8)); + EXPECT_FALSE( + LiteralUtil::IsAll(*LiteralUtil::CreateR2<half>({{h8}, {h9}}), 8)); + EXPECT_FALSE( + LiteralUtil::IsAll(*LiteralUtil::CreateR2<half>({{h9}, {h8}}), 8)); + auto uint64_max = std::numeric_limits<uint64>::max(); EXPECT_FALSE(LiteralUtil::IsAll( *LiteralUtil::CreateR2<uint64>( @@ -659,6 +671,30 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2U64) { EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); } +TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { + Literal output; + half h(0.25f); + LiteralUtil::PopulateWithValue<half>(h, {}, &output); + auto expected = LiteralUtil::CreateR0<half>(h); + EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); +} + +TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { + Literal output; + half h(0.5f); + LiteralUtil::PopulateWithValue<half>(h, {3}, &output); + auto expected = LiteralUtil::CreateR1<half>({h, h, h}); + EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); +} + +TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { + Literal output; + half h(2.0f); + LiteralUtil::PopulateWithValue<half>(h, {2, 2}, &output); + auto expected = LiteralUtil::CreateR2<half>({{h, h}, {h, h}}); + EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); +} + TEST_F(LiteralUtilTest, ReplicateR2U32) { auto input = LiteralUtil::CreateR2<uint32>( {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); @@ -730,6 +766,41 @@ TEST_F(LiteralUtilTest, CopyScalars) { EXPECT_EQ(LiteralUtil::Get<uint32>(*vect, {4}), 17); } +TEST_F(LiteralUtilTest, F16) { + // Verify that the internal data views are consistent and that they + // are in little endian format + // TODO - modify if we make the data format machine endianess dependent + auto m1 = LiteralUtil::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2})); + Literal* l1 = m1.get(); + const char* d1 = (const char*)LiteralUtil::InternalData(*l1); + EXPECT_EQ(d1[0], 0); + EXPECT_EQ(d1[1], 0); + EXPECT_EQ(d1[2], 0); + EXPECT_EQ(d1[3], 0); + EXPECT_EQ(d1[4], 0); + EXPECT_EQ(d1[5], 0); + EXPECT_EQ(d1[6], 0); + EXPECT_EQ(d1[7], 0); + EXPECT_EQ(LiteralUtil::InternalData(*l1), + LiteralUtil::MutableInternalData(l1)); + + half h1(1.0f); + half h2(2.0f); + auto m2 = LiteralUtil::CreateR2<half>({{h1, h2}, {h2, h1}}); + Literal* l2 = m2.get(); + const char* d2 = (const char*)LiteralUtil::InternalData(*l2); + EXPECT_EQ(d2[0], 0); + EXPECT_EQ(d2[1], 0x3C); + EXPECT_EQ(d2[2], 0); + EXPECT_EQ(d2[3], 0x40); + EXPECT_EQ(d2[4], 0); + EXPECT_EQ(d2[5], 0x40); + EXPECT_EQ(d2[6], 0); + EXPECT_EQ(d2[7], 0x3C); + EXPECT_EQ(LiteralUtil::InternalData(*l2), + LiteralUtil::MutableInternalData(l2)); +} + TEST_F(LiteralUtilTest, Populate) { struct PopulateData { std::vector<int64> dimensions; diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index e3909ae8e97..e4e37177a2d 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -78,6 +78,11 @@ PrimitiveType NativeToPrimitiveType<double>() { return F64; } +template <> +PrimitiveType NativeToPrimitiveType<half>() { + return F16; +} + bool IsFloatingPointType(PrimitiveType type) { return type == F16 || type == F32 || type == F64; } diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index 78f0ee6f592..162a11c7d29 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -75,6 +75,8 @@ template <> PrimitiveType NativeToPrimitiveType<float>(); template <> PrimitiveType NativeToPrimitiveType<double>(); +template <> +PrimitiveType NativeToPrimitiveType<half>(); bool IsFloatingPointType(PrimitiveType type); @@ -150,6 +152,10 @@ template <> struct PrimitiveTypeToNative<F64> { using type = double; }; +template <> +struct PrimitiveTypeToNative<F16> { + using type = half; +}; } // namespace primitive_util } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 8e06f0520ed..253de20f251 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" -#include <sched.h> #include <functional> #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc index 677080a8623..ee772f5c396 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc @@ -54,7 +54,7 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, int lhs_contract_dim = transpose_lhs ? 0 : 1; int rhs_contract_dim = transpose_rhs ? 1 : 0; const Eigen::array<DimPair, 1> dims( - DimPair(lhs_contract_dim, rhs_contract_dim)); + {DimPair(lhs_contract_dim, rhs_contract_dim)}); // Matrix multiply is a special case of the "contract" operation where // the contraction is performed along dimension 1 of the lhs and dimension diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc index 384a978873d..6f1c97a2334 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc @@ -48,7 +48,7 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, int lhs_contract_dim = transpose_lhs ? 0 : 1; int rhs_contract_dim = transpose_rhs ? 1 : 0; const Eigen::array<DimPair, 1> dims( - DimPair(lhs_contract_dim, rhs_contract_dim)); + {DimPair(lhs_contract_dim, rhs_contract_dim)}); // Matrix multiply is a special case of the "contract" operation where // the contraction is performed along dimension 1 of the lhs and dimension diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 1667ab36792..e57eb0bdee6 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -113,7 +113,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall( tensorflow::gtl::ArraySlice<llvm::Value*> operands, tensorflow::gtl::ArraySlice<PrimitiveType> input_types, PrimitiveType output_type) const { - // Binary math functions tranform are of type [T] -> T. + // Binary math functions transform are of type [T] -> T. for (PrimitiveType input_type : input_types) { if (output_type != input_type) { return Unimplemented("Input type ≠ output type: %s ≠ %s", diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index e60978df0a2..36619a84541 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -399,7 +399,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot, llvm::Type* accum_type = target_array.GetElementLlvmType(); llvm::Value* accum_address = llvm_ir::EmitAllocaAtFunctionEntry( accum_type, // The pointee type of the alloca instruction. - "accum_address", // The name of the alloca instuction. + "accum_address", // The name of the alloca instruction. &ir_builder_); // Initialize the accumulator in the preheader to zero. diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 485216837dc..383729185df 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -396,7 +396,7 @@ StatusOr<string> CompileModuleToPtx(llvm::Module* module, // The LLVM IR verifier performs sanity checking on the IR. This helps // discover problems and report them in a meaningful manner, rather than let - // later passes report obscure assertions becasue of unfulfilled invariants. + // later passes report obscure assertions because of unfulfilled invariants. module_passes.add(llvm::createVerifierPass()); // Create the function-level pass manager. It needs data layout information diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc index 46a5d303b74..61bc6f60557 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc @@ -168,7 +168,7 @@ class MatcherBase { virtual ~MatcherBase() {} // Attempts to match each ExprTree in 'expr_trees_'. - // Returns OK on the first succesful match, error status otherwise. + // Returns OK on the first successful match, error status otherwise. virtual tensorflow::Status Run() { Status status; for (const ExprTree& expr_tree : expr_trees_) { diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 810eed72f92..a8366ae7949 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -1081,7 +1081,7 @@ StatusOr<Layout> InferArrayLayout( *first_buffer_layout)) { // The points-to set is ambiguous for this index and the different source // buffers have different layouts. This case is possible in valid XLA - // computations because we do not propagate BufferLayoutConstaints to all + // computations because we do not propagate BufferLayoutConstraints to all // LogicalBuffers which may alias the constrained LogicalBuffer at some // point in the computation. return FailedPrecondition( @@ -1294,7 +1294,7 @@ Status LayoutAssignment::RunOnComputation( TF_ASSIGN_OR_RETURN(auto points_to_analysis, TuplePointsToAnalysis::Run(computation->parent())); - // Construct LayoutConstaints with all layout constraints of the computation. + // Construct LayoutConstraints with all layout constraints of the computation. LayoutConstraints constraints(*points_to_analysis, computation); // Add constraints required for correctness on all backends (eg, entry diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index dc54c9defec..f31b703b006 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -29,23 +29,21 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/test.h" -extern "C" void __attribute__((visibility("default"))) -R0F32Add2(float* out, float** in) { +extern "C" void TF_EXPORT R0F32Add2(float* out, float** in) { TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float*)); *out = **in + 2.0f; } -extern "C" void __attribute__((visibility("default"))) -R2F32ReduceSum(float* out, float** in) { +extern "C" void TF_EXPORT R2F32ReduceSum(float* out, float** in) { TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4); float* array = in[0]; *out = array[0] + array[1] + array[2] + array[3]; } -extern "C" void __attribute__((visibility("default"))) -Add1ToValues(float* out, float** in) { +extern "C" void TF_EXPORT Add1ToValues(float* out, float** in) { TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4); float* array = in[0]; out[0] = array[0] + 1; diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc index 94f34f753b7..cc3c4a2a5e1 100644 --- a/tensorflow/compiler/xla/tests/deallocation_test.cc +++ b/tensorflow/compiler/xla/tests/deallocation_test.cc @@ -52,7 +52,7 @@ TEST_F(DeallocationTest, DeallocateScalar) { builder.ConstantR0<float>(42.0); auto global_data = ExecuteAndCheckTransfer(&builder, {}); - // A result can be transfered an arbitrary number of times. Add an extra + // A result can be transferred an arbitrary number of times. Add an extra // transfer here so we're not just testing that a second call to Transfer // fails. ASSERT_IS_OK(client_->Transfer(*global_data).status()); diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index ef81db6fd66..23453db57bc 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -314,7 +314,7 @@ class NearComparator { private: // EXPECTs that the two given scalar values are within the error bound. Keeps - // track of how many mismatches have occured to keep the size of the output + // track of how many mismatches have occurred to keep the size of the output // manageable. template <typename NativeT> bool ExpectValuesNear(NativeT expected, NativeT actual) { diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index d00a3175344..feb2b465fca 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -61,7 +61,7 @@ namespace { class ReduceTest : public ClientLibraryTestBase { protected: ReduceTest() { - // Implementation note: layed out z >> y >> x by default. + // Implementation note: laid out z >> y >> x by default. // clang-format off literal_2d_ = LiteralUtil::CreateR2<float>({ // x0 x1 x2 diff --git a/tensorflow/compiler/xla/types.h b/tensorflow/compiler/xla/types.h index 8258031a2c5..4935648f983 100644 --- a/tensorflow/compiler/xla/types.h +++ b/tensorflow/compiler/xla/types.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_TYPES_H_ #define TENSORFLOW_COMPILER_XLA_TYPES_H_ +#include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/platform/types.h" namespace xla { @@ -32,6 +33,8 @@ using ::tensorflow::uint16; using ::tensorflow::uint32; using ::tensorflow::uint64; +using ::Eigen::half; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TYPES_H_ diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index e5b94fcefe3..1239816c50e 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -286,6 +286,7 @@ message Literal { repeated float f32s = 8; repeated double f64s = 9; repeated Literal tuple_literals = 10; + bytes f16s = 11; // Note: the F16s are encoded in little endian byte order } message WindowDimension { diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary_test.cc b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary_test.cc index 0bdfb406641..8de154483e6 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary_test.cc @@ -75,7 +75,7 @@ TEST_F(WeightedQuantilesSummaryTest, BuildFromBuffer) { Summary summary; summary.BuildFromBufferEntries(buffer1_->GenerateEntryList()); - // We expect no approximation error because no compress operation occured. + // We expect no approximation error because no compress operation occurred. EXPECT_EQ(summary.ApproximationError(), 0); // Check first and last elements in the summary. diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index 62afa9481b0..bade45e96a3 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -61,15 +61,18 @@ add_definitions(-DEIGEN_AVOID_STL_ARRAY) if(WIN32) add_definitions(-DNOMINMAX -D_WIN32_WINNT=0x0A00 -DLANG_CXX11 -DCOMPILER_MSVC) add_definitions(-DWIN32 -DOS_WIN -D_MBCS -DWIN64 -DWIN32_LEAN_AND_MEAN -DNOGDI -DPLATFORM_WINDOWS) - add_definitions(-DTENSORFLOW_USE_EIGEN_THREADPOOL -DEIGEN_HAS_C99_MATH -D_ITERATOR_DEBUG_LEVEL=0) + add_definitions(-DTENSORFLOW_USE_EIGEN_THREADPOOL -DEIGEN_HAS_C99_MATH) add_definitions(-DTF_COMPILE_LIBRARY) - add_definitions(-DNDEBUG /O2) # Equivalent of -c opt in Bazel. add_definitions(/bigobj /nologo /EHsc /GF /FC /MP /Gm-) # Suppress warnings to reduce build log size. add_definitions(/wd4267 /wd4244 /wd4800 /wd4503 /wd4554 /wd4996 /wd4348 /wd4018) add_definitions(/wd4099 /wd4146 /wd4267 /wd4305 /wd4307) add_definitions(/wd4715 /wd4722 /wd4723 /wd4838 /wd4309 /wd4334) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP") + set(CMAKE_CXX_FLAGS_DEBUG "/D_DEBUG /MDd /Ob0") + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /D_ITERATOR_DEBUG_LEVEL=0") + set(CMAKE_CXX_FLAGS_MINSIZEREL "${CMAKE_CXX_FLAGS_MINSIZEREL} /D_ITERATOR_DEBUG_LEVEL=0") + set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} /D_ITERATOR_DEBUG_LEVEL=0") endif() if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") diff --git a/tensorflow/contrib/cmake/external/googletest.cmake b/tensorflow/contrib/cmake/external/googletest.cmake index c370f46d2a0..d09bb02890f 100644 --- a/tensorflow/contrib/cmake/external/googletest.cmake +++ b/tensorflow/contrib/cmake/external/googletest.cmake @@ -21,7 +21,7 @@ set(googletest_TAG ec44c6c1675c25b9827aacd08c02433cccde7780) if(WIN32) set(googletest_STATIC_LIBRARIES - ${CMAKE_CURRENT_BINARY_DIR}/googletest/src/googletest/googletest/${CMAKE_BUILD_TYPE}/gtest.lib) + ${CMAKE_CURRENT_BINARY_DIR}/googletest/src/googletest/googletest/$(Configuration)/gtest.lib) else() set(googletest_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/googletest/src/googletest/googletest/${CMAKE_BUILD_TYPE}/gtest.a) diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake index 46ea519f09d..d7201680ceb 100644 --- a/tensorflow/contrib/cmake/external/grpc.cmake +++ b/tensorflow/contrib/cmake/external/grpc.cmake @@ -21,9 +21,9 @@ set(GRPC_TAG 3bc78cd0b5bd784a235c01612d634b1ec5f8fb97) if(WIN32) set(grpc_STATIC_LIBRARIES - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/${CMAKE_BUILD_TYPE}/grpc++_unsecure.lib - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/${CMAKE_BUILD_TYPE}/grpc_unsecure.lib - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/${CMAKE_BUILD_TYPE}/gpr.lib) + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/$(Configuration)/grpc++_unsecure.lib + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/$(Configuration)/grpc_unsecure.lib + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/$(Configuration)/gpr.lib) else() set(grpc_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++_unsecure.a diff --git a/tensorflow/contrib/cmake/external/jsoncpp.cmake b/tensorflow/contrib/cmake/external/jsoncpp.cmake index 1ef49681a66..5127d7e8f79 100644 --- a/tensorflow/contrib/cmake/external/jsoncpp.cmake +++ b/tensorflow/contrib/cmake/external/jsoncpp.cmake @@ -23,7 +23,7 @@ set(jsoncpp_LIBRARIES ${jsoncpp_BUILD}/obj/so/libjsoncpp.so) set(jsoncpp_INCLUDES ${jsoncpp_BUILD}) if(WIN32) - set(jsoncpp_STATIC_LIBRARIES ${jsoncpp_BUILD}/${CMAKE_BUILD_TYPE}/jsoncpp.lib) + set(jsoncpp_STATIC_LIBRARIES ${jsoncpp_BUILD}/$(Configuration)/jsoncpp.lib) else() set(jsoncpp_STATIC_LIBRARIES ${jsoncpp_BUILD}/libjsoncpp.a) endif() diff --git a/tensorflow/contrib/cmake/external/png.cmake b/tensorflow/contrib/cmake/external/png.cmake index 05e9688d1f0..2b2bd47d1c9 100644 --- a/tensorflow/contrib/cmake/external/png.cmake +++ b/tensorflow/contrib/cmake/external/png.cmake @@ -21,7 +21,9 @@ set(png_BUILD ${CMAKE_BINARY_DIR}/png/src/png) set(png_INSTALL ${CMAKE_BINARY_DIR}/png/install) if(WIN32) - set(png_STATIC_LIBRARIES ${CMAKE_BINARY_DIR}/png/install/lib/libpng12_static.lib) + set(png_STATIC_LIBRARIES + debug ${CMAKE_BINARY_DIR}/png/install/lib/libpng12_staticd.lib + optimized ${CMAKE_BINARY_DIR}/png/install/lib/libpng12_static.lib) else() set(png_STATIC_LIBRARIES ${CMAKE_BINARY_DIR}/png/install/lib/libpng12.a) endif() diff --git a/tensorflow/contrib/cmake/external/protobuf.cmake b/tensorflow/contrib/cmake/external/protobuf.cmake index 6cd9c11750d..d600d8c3c0d 100644 --- a/tensorflow/contrib/cmake/external/protobuf.cmake +++ b/tensorflow/contrib/cmake/external/protobuf.cmake @@ -19,8 +19,10 @@ set(PROTOBUF_URL https://github.com/mrry/protobuf.git) # Includes MSVC fix. set(PROTOBUF_TAG 1d2c7b6c7376f396c8c7dd9b6afd2d4f83f3cb05) if(WIN32) - set(protobuf_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/${CMAKE_BUILD_TYPE}/libprotobuf.lib) - set(PROTOBUF_PROTOC_EXECUTABLE ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/${CMAKE_BUILD_TYPE}/protoc.exe) + set(protobuf_STATIC_LIBRARIES + debug ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/$(Configuration)/libprotobufd.lib + optimized ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/$(Configuration)/libprotobuf.lib) + set(PROTOBUF_PROTOC_EXECUTABLE ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/$(Configuration)/protoc.exe) set(PROTOBUF_ADDITIONAL_CMAKE_OPTIONS -Dprotobuf_MSVC_STATIC_RUNTIME:BOOL=OFF -A x64) else() set(protobuf_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/libprotobuf.a) diff --git a/tensorflow/contrib/cmake/external/zlib.cmake b/tensorflow/contrib/cmake/external/zlib.cmake index 8cfde90438c..c8af611e1ea 100644 --- a/tensorflow/contrib/cmake/external/zlib.cmake +++ b/tensorflow/contrib/cmake/external/zlib.cmake @@ -22,7 +22,8 @@ set(ZLIB_TAG 50893291621658f355bc5b4d450a8d06a563053d) if(WIN32) set(zlib_STATIC_LIBRARIES - ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstatic.lib) + debug ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstaticd.lib + optimized ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstatic.lib) else() set(zlib_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/libz.a) diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index baba1bbb037..7789edf8098 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -861,9 +861,9 @@ add_custom_command(TARGET tf_python_build_pip_package POST_BUILD ${CMAKE_CURRENT_BINARY_DIR}/tf_python/) if(WIN32) add_custom_command(TARGET tf_python_build_pip_package POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/pywrap_tensorflow_internal.dll + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/$(Configuration)/pywrap_tensorflow_internal.dll ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.pyd - COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/pywrap_tensorflow_internal.lib + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/$(Configuration)/pywrap_tensorflow_internal.lib ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/) else() add_custom_command(TARGET tf_python_build_pip_package POST_BUILD diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index c8a827f29db..4abf3b0a645 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -208,6 +208,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/tensorboard/lib/python/http_util_test.py" # Broken tensorboard test due to cmake issues. "${tensorflow_source_dir}/tensorflow/tensorboard/plugins/debugger/plugin_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" # tensor_forest tests (also note that we exclude the hybrid tests for now) "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order. "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py" # Results in wrong order. diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index 8dc78da6ba3..60e7c8f160a 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -138,6 +138,7 @@ tf_py_test( "//tensorflow/python:platform_test", ], tags = [ + "no_pip", # b/38283730 "notsan", # Flaky: b/30756419 ], ) diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc index e520139e659..a4dd3a642fd 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc @@ -70,8 +70,7 @@ bool IsBinaryInstalled(const string& binary_name) { const string binary_path = io::JoinPath(dir, binary_name); char absolute_path[PATH_MAX + 1]; if (::realpath(binary_path.c_str(), absolute_path) == NULL) { - LOG(ERROR) << "Invalid binary path: " << binary_path; - return false; + continue; } struct stat statinfo; int result = ::stat(absolute_path, &statinfo); diff --git a/tensorflow/contrib/keras/api/keras/callbacks/__init__.py b/tensorflow/contrib/keras/api/keras/callbacks/__init__.py index 2f579f2d281..36db34f592d 100644 --- a/tensorflow/contrib/keras/api/keras/callbacks/__init__.py +++ b/tensorflow/contrib/keras/api/keras/callbacks/__init__.py @@ -29,6 +29,7 @@ from tensorflow.contrib.keras.python.keras.callbacks import ModelCheckpoint from tensorflow.contrib.keras.python.keras.callbacks import ProgbarLogger from tensorflow.contrib.keras.python.keras.callbacks import ReduceLROnPlateau from tensorflow.contrib.keras.python.keras.callbacks import RemoteMonitor +from tensorflow.contrib.keras.python.keras.callbacks import TensorBoard del absolute_import del division diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py index fc092fccd78..1724d7599d0 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py @@ -597,8 +597,8 @@ class DynamicRnnEstimator(estimator.Estimator): `ProblemType.CLASSIFICATION` or `ProblemType.LINEAR_REGRESSION`. prediction_type: whether the `Estimator` should return a value for each step in the sequence, or just a single value for the final time step. - Must be one of `ProblemType.SINGLE_VALUE` or - `ProblemType.MULTIPLE_VALUE`. + Must be one of `PredictionType.SINGLE_VALUE` or + `PredictionType.MULTIPLE_VALUE`. sequence_feature_columns: An iterable containing all the feature columns describing sequence features. All items in the iterable should be instances of classes derived from `FeatureColumn`. diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 1af31b933be..d7ba2209ada 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -331,14 +331,21 @@ def _write_dict_to_summary(output_dir, for key in dictionary: if dictionary[key] is None: continue + if key == 'global_step': + continue value = summary_proto.value.add() value.tag = key if (isinstance(dictionary[key], np.float32) or isinstance(dictionary[key], float)): value.simple_value = float(dictionary[key]) + elif (isinstance(dictionary[key], np.int64) or + isinstance(dictionary[key], np.int32) or + isinstance(dictionary[key], int)): + value.simple_value = int(dictionary[key]) else: - logging.warn('Skipping summary for %s, must be a float or np.float32.', - key) + logging.warn( + 'Skipping summary for %s, must be a float, np.float32, np.int64, np.int32 or int.', + key) summary_writer.add_summary(summary_proto, current_global_step) summary_writer.flush() diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index 2143f3b9252..b9cd91e519d 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -110,6 +110,7 @@ tensorflow/core/kernels/fake_quant_ops.cc tensorflow/core/kernels/example_parsing_ops.cc tensorflow/core/kernels/dynamic_stitch_op.cc tensorflow/core/kernels/dynamic_partition_op.cc +tensorflow/core/kernels/decode_bmp_op.cc tensorflow/core/kernels/depthtospace_op.cc tensorflow/core/kernels/spacetodepth_op.cc tensorflow/core/kernels/dense_update_ops.cc diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index 2a8714644c4..a7e910975fb 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -18,6 +18,7 @@ py_library( "python/training/external_optimizer.py", "python/training/lazy_adam_optimizer.py", "python/training/moving_average_optimizer.py", + "python/training/nadam_optimizer.py", "python/training/variable_clipping_optimizer.py", ], srcs_version = "PY2AND3", @@ -106,6 +107,23 @@ py_test( ], ) +py_test( + name = "nadam_optimizer_test", + srcs = ["python/training/nadam_optimizer_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + tf_py_test( name = "drop_stale_gradient_optimizer_test", srcs = ["python/training/drop_stale_gradient_optimizer_test.py"], diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py index 6cd68f29a70..656a548cfd5 100644 --- a/tensorflow/contrib/opt/__init__.py +++ b/tensorflow/contrib/opt/__init__.py @@ -23,16 +23,16 @@ from tensorflow.contrib.opt.python.training.drop_stale_gradient_optimizer import from tensorflow.contrib.opt.python.training.external_optimizer import * from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import * from tensorflow.contrib.opt.python.training.moving_average_optimizer import * +from tensorflow.contrib.opt.python.training.nadam_optimizer import * from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['DropStaleGradientOptimizer', - 'ExternalOptimizerInterface', - 'LazyAdamOptimizer', - 'MovingAverageOptimizer', - 'ScipyOptimizerInterface', - 'VariableClippingOptimizer'] +_allowed_symbols = [ + 'DropStaleGradientOptimizer', 'ExternalOptimizerInterface', + 'LazyAdamOptimizer', 'NadamOptimizer', 'MovingAverageOptimizer', + 'ScipyOptimizerInterface', 'VariableClippingOptimizer' +] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer_test.py b/tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer_test.py index 4f0bc0ce2fd..53232082e16 100644 --- a/tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer_test.py @@ -184,7 +184,7 @@ class DropStaleGradientOptimizerTest(test.TestCase): thread_0.join() thread_1.join() - # With 2 workers and max staleness set to 0, only cheif worker will update + # With 2 workers and max staleness set to 0, only chief worker will update # var_0 and var_1. self.assertAllEqual(1, sessions[0].run(global_step)) self.assertAllEqual(1.0, sessions[0].run(stale_counter)) @@ -250,7 +250,7 @@ class DropStaleGradientOptimizerTest(test.TestCase): thread_1.join() thread_2.join() - # With 3 workers and max staleness set to 0, only cheif worker will update + # With 3 workers and max staleness set to 0, only chief worker will update # var_0 and var_1. self.assertAllEqual(1, sessions[0].run(global_step)) self.assertAllEqual(2.0, sessions[0].run(stale_counter)) diff --git a/tensorflow/contrib/opt/python/training/nadam_optimizer.py b/tensorflow/contrib/opt/python/training/nadam_optimizer.py new file mode 100644 index 00000000000..a4421ecfe6b --- /dev/null +++ b/tensorflow/contrib/opt/python/training/nadam_optimizer.py @@ -0,0 +1,93 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Nadam for TensorFlow.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.training import adam +from tensorflow.python.training import training_ops + + +class NadamOptimizer(adam.AdamOptimizer): + """Optimizer that implements the Nadam algorithm. + + See [Dozat, T., 2015](http://cs229.stanford.edu/proj2015/054_report.pdf). + """ + + def _apply_dense(self, grad, var): + m = self.get_slot(var, "m") + v = self.get_slot(var, "v") + return training_ops.apply_adam( + var, + m, + v, + math_ops.cast(self._beta1_power, var.dtype.base_dtype), + math_ops.cast(self._beta2_power, var.dtype.base_dtype), + math_ops.cast(self._lr_t, var.dtype.base_dtype), + math_ops.cast(self._beta1_t, var.dtype.base_dtype), + math_ops.cast(self._beta2_t, var.dtype.base_dtype), + math_ops.cast(self._epsilon_t, var.dtype.base_dtype), + grad, + use_locking=self._use_locking, + use_nesterov=True).op + + def _resource_apply_dense(self, grad, var): + m = self.get_slot(var, "m") + v = self.get_slot(var, "v") + return training_ops.resource_apply_adam( + var.handle, + m.handle, + v.handle, + math_ops.cast(self._beta1_power, grad.dtype.base_dtype), + math_ops.cast(self._beta2_power, grad.dtype.base_dtype), + math_ops.cast(self._lr_t, grad.dtype.base_dtype), + math_ops.cast(self._beta1_t, grad.dtype.base_dtype), + math_ops.cast(self._beta2_t, grad.dtype.base_dtype), + math_ops.cast(self._epsilon_t, grad.dtype.base_dtype), + grad, + use_locking=self._use_locking, + use_nesterov=True) + + def _apply_sparse_shared(self, grad, var, indices, scatter_add): + beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype) + lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) + beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) + beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) + epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) + lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) + # m_t = beta1 * m + (1 - beta1) * g_t + m = self.get_slot(var, "m") + m_scaled_g_values = grad * (1 - beta1_t) + m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking) + with ops.control_dependencies([m_t]): + m_t = scatter_add(m, indices, m_scaled_g_values) + # m_bar = (1 - beta1) * g_t + beta1 * m_t + m_bar = m_scaled_g_values + beta1_t * m_t + # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) + v = self.get_slot(var, "v") + v_scaled_g_values = (grad * grad) * (1 - beta2_t) + v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking) + with ops.control_dependencies([v_t]): + v_t = scatter_add(v, indices, v_scaled_g_values) + v_sqrt = math_ops.sqrt(v_t) + var_update = state_ops.assign_sub( + var, lr * m_bar / (v_sqrt + epsilon_t), use_locking=self._use_locking) + return control_flow_ops.group(*[var_update, m_bar, v_t]) diff --git a/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py b/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py new file mode 100644 index 00000000000..b0a257d264f --- /dev/null +++ b/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py @@ -0,0 +1,159 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Nadam.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.opt.python.training import nadam_optimizer +from tensorflow.python.client import session +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def nadam_update_numpy(param, + g_t, + t, + m, + v, + alpha=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8): + alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t) + + m_t = beta1 * m + (1 - beta1) * g_t + v_t = beta2 * v + (1 - beta2) * g_t * g_t + + m_bar = (1 - beta1) * g_t + beta1 * m_t + + param_t = param - alpha_t * m_bar / (np.sqrt(v_t) + epsilon) + return param_t, m_t, v_t + + +class NadamOptimizerTest(test.TestCase): + + def doTestSparse(self, use_resource=False): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + else: + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0 = ops.IndexedSlices( + constant_op.constant(grads0_np), + constant_op.constant(grads0_np_indices), constant_op.constant([2])) + grads1_np_indices = np.array([0, 1], dtype=np.int32) + grads1 = ops.IndexedSlices( + constant_op.constant(grads1_np), + constant_op.constant(grads1_np_indices), constant_op.constant([2])) + opt = nadam_optimizer.NadamOptimizer() + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Nadam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + update.run() + + var0_np, m0, v0 = nadam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = nadam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testSparse(self): + self.doTestSparse(use_resource=False) + + def testResourceSparse(self): + self.doTestSparse(use_resource=True) + + def doTestBasic(self, use_resource=False): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + else: + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = nadam_optimizer.NadamOptimizer() + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Nadam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + update.run() + + var0_np, m0, v0 = nadam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = nadam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testBasic(self): + self.doTestBasic(use_resource=False) + + def testResourceBasic(self): + self.doTestBasic(use_resource=True) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py index db1f939e33f..e0616e06784 100644 --- a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py +++ b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py @@ -88,12 +88,20 @@ class BasicRNNCell(RNNCell): class GRUCell(RNNCell): """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).""" - def __init__(self, num_units, input_size=None, activation=tanh, reuse=None): + def __init__(self, + num_units, + input_size=None, + activation=tanh, + reuse=None, + kernel_initializer=None, + bias_initializer=None): super(GRUCell, self).__init__(_reuse=reuse) if input_size is not None: logging.warn("%s: The input_size parameter is deprecated.", self) self._num_units = num_units self._activation = activation + self._kernel_initializer = kernel_initializer + self._bias_initializer = bias_initializer @property def state_size(self): @@ -107,10 +115,18 @@ class GRUCell(RNNCell): """Gated recurrent unit (GRU) with nunits cells.""" with vs.variable_scope("gates"): # Reset gate and update gate. # We start with bias of 1.0 to not reset and not update. - value = sigmoid(_linear([inputs, state], 2 * self._num_units, True, 1.0)) + bias_ones = self._bias_initializer + if self._bias_initializer is None: + dtype = [a.dtype for a in [inputs, state]][0] + bias_ones = init_ops.constant_initializer(1.0, dtype=dtype) + value = sigmoid( + _linear([inputs, state], 2 * self._num_units, True, bias_ones, + self._kernel_initializer)) r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1) with vs.variable_scope("candidate"): - c = self._activation(_linear([inputs, r * state], self._num_units, True)) + c = self._activation( + _linear([inputs, r * state], self._num_units, True, + self._bias_initializer, self._kernel_initializer)) new_h = u * state + (1 - u) * c return new_h, new_h @@ -968,14 +984,19 @@ class _SlimRNNCell(RNNCell): return output, state -def _linear(args, output_size, bias, bias_start=0.0): +def _linear(args, + output_size, + bias, + bias_initializer=None, + kernel_initializer=None): """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. Args: args: a 2D Tensor or a list of 2D, batch x n, Tensors. output_size: int, second dimension of W[i]. bias: boolean, whether to add a bias term or not. - bias_start: starting value to initialize the bias; 0 by default. + bias_initializer: starting value to initialize the bias; None by default. + kernel_initializer: starting value to initialize the weight; None by default. Returns: A 2D Tensor with shape [batch x output_size] equal to @@ -1007,7 +1028,9 @@ def _linear(args, output_size, bias, bias_start=0.0): scope = vs.get_variable_scope() with vs.variable_scope(scope) as outer_scope: weights = vs.get_variable( - _WEIGHTS_VARIABLE_NAME, [total_arg_size, output_size], dtype=dtype) + _WEIGHTS_VARIABLE_NAME, [total_arg_size, output_size], + dtype=dtype, + initializer=kernel_initializer) if len(args) == 1: res = math_ops.matmul(args[0], weights) else: @@ -1016,8 +1039,10 @@ def _linear(args, output_size, bias, bias_start=0.0): return res with vs.variable_scope(outer_scope) as inner_scope: inner_scope.set_partitioner(None) + if bias_initializer is None: + bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype) biases = vs.get_variable( _BIAS_VARIABLE_NAME, [output_size], dtype=dtype, - initializer=init_ops.constant_initializer(bias_start, dtype=dtype)) + initializer=bias_initializer) return nn_ops.bias_add(res, biases) diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index eb494bda4b5..1e8455d89e9 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -497,13 +497,20 @@ def _beam_search_step(time, logits, beam_state, batch_size, beam_width, time = ops.convert_to_tensor(time, name="time") # During the first time step we only consider the initial beam + scores_shape = array_ops.shape(scores) scores_flat = control_flow_ops.cond( time > 0, lambda: array_ops.reshape(scores, [batch_size, -1]), lambda: scores[:, 0]) + num_available_beam = control_flow_ops.cond( + time > 0, lambda: math_ops.reduce_prod(scores_shape[1:]), + lambda: math_ops.reduce_prod(scores_shape[2:])) # Pick the next beams according to the specified successors function - next_beam_scores, word_indices = nn_ops.top_k(scores_flat, k=beam_width) + next_beam_size = math_ops.minimum( + ops.convert_to_tensor(beam_width, dtype=dtypes.int32, name="beam_width"), + num_available_beam) + next_beam_scores, word_indices = nn_ops.top_k(scores_flat, k=next_beam_size) next_beam_scores.set_shape([static_batch_size, beam_width]) word_indices.set_shape([static_batch_size, beam_width]) @@ -561,7 +568,8 @@ def _get_scores(log_probs, sequence_lengths, length_penalty_weight): """Calculates scores for beam search hypotheses. Args: - log_probs: The log probabilities with shape [batch_size, beam_width]. + log_probs: The log probabilities with shape + `[batch_size, beam_width, vocab_size]`. sequence_lengths: The array of sequence lengths. length_penalty_weight: Float weight to penalize length. Disabled with 0.0. diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc index 05df05de353..bc687be0abb 100644 --- a/tensorflow/contrib/verbs/rdma.cc +++ b/tensorflow/contrib/verbs/rdma.cc @@ -476,7 +476,11 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) { struct ibv_qp_attr attr; memset(&attr, 0, sizeof(ibv_qp_attr)); attr.qp_state = IBV_QPS_RTR; - attr.path_mtu = IBV_MTU_4096; + struct ibv_port_attr port_attr; + CHECK(!ibv_query_port(adapter_->context_, (uint8_t)1, &port_attr)) + << "Query port failed"; + // This assumes both QP's ports are configured with the same MTU + attr.path_mtu = port_attr.active_mtu; attr.dest_qp_num = remoteAddr.qpn; attr.rq_psn = remoteAddr.psn; attr.max_dest_rd_atomic = 1; @@ -778,11 +782,8 @@ void RdmaTensorBuffer::SendNextItem() { EnqueueItem(key_with_step_id); } }; - // Use default session (legacy_session_) - // TODO use WorkerSessionForSession - // need to pass in session handle - channel_->adapter_->worker_env_->session_mgr->LegacySession() - ->rendezvous_mgr->RecvLocalAsync(step_id, parsed, cb); + channel_->adapter_->worker_env_->rendezvous_mgr->RecvLocalAsync(step_id, + parsed, cb); } } diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc index 8cbdfaa9439..5871400f26a 100644 --- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc +++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc @@ -29,10 +29,8 @@ namespace tensorflow { class RdmaRemoteRendezvous : public BaseRemoteRendezvous { public: - RdmaRemoteRendezvous(const WorkerEnv* env, const string& worker_name, - int64 step_id, RdmaMgr* rdma_mgr) - : BaseRemoteRendezvous(env, worker_name, step_id, true), - rdma_mgr_(rdma_mgr) {} + RdmaRemoteRendezvous(const WorkerEnv* env, int64 step_id, RdmaMgr* rdma_mgr) + : BaseRemoteRendezvous(env, step_id, true), rdma_mgr_(rdma_mgr) {} protected: void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, @@ -133,15 +131,12 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync( rb->SendNextItem(); } -RdmaRendezvousMgr::RdmaRendezvousMgr(const WorkerEnv* env, - const string& worker_name, - WorkerCacheInterface* worker_cache) - : BaseRendezvousMgr(env, worker_name) {} +RdmaRendezvousMgr::RdmaRendezvousMgr(const WorkerEnv* env) + : BaseRendezvousMgr(env) {} BaseRemoteRendezvous* RdmaRendezvousMgr::Create(int64 step_id, - const WorkerEnv* worker_env, - const string& worker_name) { - return new RdmaRemoteRendezvous(worker_env, worker_name, step_id, rdma_mgr_); + const WorkerEnv* worker_env) { + return new RdmaRemoteRendezvous(worker_env, step_id, rdma_mgr_); } } // end namespace tensorflow diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h index 57cd4bf5e4e..2dedd6c48f9 100644 --- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h +++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h @@ -45,13 +45,12 @@ namespace tensorflow { // RendezvousMgr must have keys generated by Rendezvous::CreateKey. class RdmaRendezvousMgr : public BaseRendezvousMgr { public: - explicit RdmaRendezvousMgr(const WorkerEnv* env, const string& worker_name, - WorkerCacheInterface* worker_cache); + explicit RdmaRendezvousMgr(const WorkerEnv* env); void SetRdmaMgr(RdmaMgr* rdma_mgr) { rdma_mgr_ = rdma_mgr; } protected: - BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env, - const string& worker_name) override; + BaseRemoteRendezvous* Create(int64 step_id, + const WorkerEnv* worker_env) override; private: RdmaMgr* rdma_mgr_; diff --git a/tensorflow/contrib/verbs/verbs_server_lib.cc b/tensorflow/contrib/verbs/verbs_server_lib.cc index b061c81d2d8..c3597249354 100644 --- a/tensorflow/contrib/verbs/verbs_server_lib.cc +++ b/tensorflow/contrib/verbs/verbs_server_lib.cc @@ -27,10 +27,8 @@ namespace tensorflow { namespace { // static utility function -RendezvousMgrInterface* NewRdmaRendezvousMgr( - const WorkerEnv* env, const string& worker_name, - WorkerCacheInterface* worker_cache) { - return new RdmaRendezvousMgr(env, worker_name, worker_cache); +RendezvousMgrInterface* NewRdmaRendezvousMgr(const WorkerEnv* env) { + return new RdmaRendezvousMgr(env); } } // namespace @@ -56,7 +54,7 @@ Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def, TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec)); *channel_cache = - NewGrpcChannelCache(channel_spec, GetChannelCreationFunction(server_def)); + NewGrpcChannelCache(channel_spec, GetChannelCreationFunction()); const string host_port = (*channel_cache)->TranslateTask(name_prefix); int requested_port; @@ -86,11 +84,7 @@ Status VerbsServer::Init(ServiceInitFunction service_func, rdma_mgr_ = new RdmaMgr(worker_env(), channel_cache_); // set rdma_mgr for verbs_service and rdma_rendezvous_mgr verbs_service_->SetRdmaMgr(rdma_mgr_); - // hardcoded to default session (legacy_session_) - // TODO: use WorkerSessionForSession - // need to pass in session handle - dynamic_cast<RdmaRendezvousMgr*>( - worker_env()->session_mgr->LegacySession()->rendezvous_mgr.get()) + dynamic_cast<RdmaRendezvousMgr*>(worker_env()->rendezvous_mgr) ->SetRdmaMgr(rdma_mgr_); } return s; diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 803b5291815..ed3eeb4b215 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -2878,6 +2878,8 @@ filegroup( "lib/gif/testdata/scan.gif", # GIF data with optimization "lib/gif/testdata/optimized.gif", + # BMP data + "lib/bmp/testdata/lena.bmp", ], visibility = ["//visibility:public"], ) diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.cc b/tensorflow/core/common_runtime/gpu/gpu_util.cc index 522db80d7fa..71f82ec9a1b 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_util.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_util.cc @@ -227,7 +227,7 @@ void GPUUtil::DeviceToDeviceCopy(DeviceContext* send_dev_context, } // Since we want to use the memory from recv_stream in the // send_device_to_device_stream, add a dependency to make sure the memory is - // truely free. + // truly free. // TODO(zhengxq): remove this dependency when we switch to a better way // to make sure the memory is free. send_device_to_device_stream->ThenWaitFor(recv_stream); @@ -322,7 +322,7 @@ void GPUUtil::CopyCPUTensorToGPU(const Tensor* cpu_tensor, done(errors::Internal("No send gpu copy-out-stream is available.")); return; } - // Wait for the recv-stream to make sure the buffer is truely available. + // Wait for the recv-stream to make sure the buffer is truly available. recv_host_to_device_stream->ThenWaitFor(recv_stream); const int64 total_bytes = cpu_tensor->TotalBytes(); diff --git a/tensorflow/core/distributed_runtime/README.md b/tensorflow/core/distributed_runtime/README.md index ab1771e2942..d22cd2a45bc 100644 --- a/tensorflow/core/distributed_runtime/README.md +++ b/tensorflow/core/distributed_runtime/README.md @@ -5,6 +5,4 @@ distributed TensorFlow runtime, using [gRPC](http://grpc.io) for inter-process communication. To learn how to use the distributed runtime to create a TensorFlow cluster, -see the "Distributed TensorFlow" How To, which is available [in this -repository](../../g3doc/how_tos/distributed/index.md), and will be available -on the TensorFlow website after the next version is released. +see the [Distributed TensorFlow](https://www.tensorflow.org/deploy/distributed) How-To. diff --git a/tensorflow/core/distributed_runtime/message_wrappers.cc b/tensorflow/core/distributed_runtime/message_wrappers.cc index b077975ea50..f3bab589a19 100644 --- a/tensorflow/core/distributed_runtime/message_wrappers.cc +++ b/tensorflow/core/distributed_runtime/message_wrappers.cc @@ -516,6 +516,7 @@ CostGraphDef* InMemoryRunGraphResponse::mutable_cost_graph() { RunGraphResponse* InMemoryRunGraphResponse::get_proto() { LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunGraphResponse"; + return NULL; } size_t OwnedProtoRunGraphResponse::num_recvs() const { @@ -634,6 +635,7 @@ RunMetadata* InMemoryRunStepResponse::mutable_metadata() { return &metadata_; } RunStepResponse* InMemoryRunStepResponse::get_proto() { LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunStepResponse"; + return NULL; } size_t OwnedProtoRunStepResponse::num_tensors() const { diff --git a/tensorflow/core/grappler/costs/measuring_cost_estimator.cc b/tensorflow/core/grappler/costs/measuring_cost_estimator.cc index 6179dc05c1e..9262883b2a7 100644 --- a/tensorflow/core/grappler/costs/measuring_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/measuring_cost_estimator.cc @@ -73,7 +73,7 @@ Status MeasuringCostEstimator::PredictCosts(const GraphDef& optimized_graph, return; } if (!local_status.ok()) { - // Discard the data if the run wasn't sucessful. + // Discard the data if the run wasn't successful. barrier.DecrementCount(); return; } diff --git a/tensorflow/core/grappler/costs/robust_stats.cc b/tensorflow/core/grappler/costs/robust_stats.cc index dba6efae0fd..9866bc86887 100644 --- a/tensorflow/core/grappler/costs/robust_stats.cc +++ b/tensorflow/core/grappler/costs/robust_stats.cc @@ -47,7 +47,7 @@ static double Median(std::vector<double> &&values) { // nth_element. const auto lower_middle = std::max_element(values.begin(), middle); // Preventing overflow. We know that '*lower_middle <= *middle'. - // If both are on oposite sides of zero, the sum won't overflow, otherwise + // If both are on opposite sides of zero, the sum won't overflow, otherwise // the difference won't overflow. if (*lower_middle <= 0 && *middle >= 0) { return (*lower_middle + *middle) / 2; diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 590d334edce..fd6d5ffa56b 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1625,6 +1625,7 @@ cc_library( ":attention_ops", ":colorspace_op", ":crop_and_resize_op", + ":decode_bmp_op", ":decode_image_op", ":draw_bounding_box_op", ":encode_jpeg_op", @@ -1689,6 +1690,12 @@ tf_kernel_library( deps = IMAGE_DEPS, ) +tf_kernel_library( + name = "decode_bmp_op", + prefix = "decode_bmp_op", + deps = IMAGE_DEPS, +) + tf_kernel_library( name = "decode_image_op", prefix = "decode_image_op", @@ -4166,6 +4173,7 @@ filegroup( srcs = [ "batchtospace_op.cc", "ctc_decoder_ops.cc", + "decode_bmp_op.cc", "depthtospace_op.cc", "dynamic_stitch_op.cc", "in_topk_op.cc", @@ -4295,6 +4303,8 @@ filegroup( "decode_image_op.*", "encode_png_op.*", "encode_jpeg_op.*", + "decode_jpeg_op.*", + "decode_gif_op.*", "identity_reader_op.*", "remote_fused_graph_execute_op.*", "remote_fused_graph_rewriter_transform.*", diff --git a/tensorflow/core/kernels/basic_ops_benchmark_test.cc b/tensorflow/core/kernels/basic_ops_benchmark_test.cc index 54532318cec..5726062938b 100644 --- a/tensorflow/core/kernels/basic_ops_benchmark_test.cc +++ b/tensorflow/core/kernels/basic_ops_benchmark_test.cc @@ -22,7 +22,7 @@ limitations under the License. namespace tensorflow { -// We focus on the single thread performance of runing ops. +// We focus on the single thread performance of running ops. static SessionOptions InitOptions() { SessionOptions opts; opts.config.set_intra_op_parallelism_threads(1); diff --git a/tensorflow/core/kernels/decode_bmp_op.cc b/tensorflow/core/kernels/decode_bmp_op.cc new file mode 100644 index 00000000000..086369a9f12 --- /dev/null +++ b/tensorflow/core/kernels/decode_bmp_op.cc @@ -0,0 +1,143 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/image_ops.cc + +#include <memory> +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +// Decode the contents of a BMP file +class DecodeBmpOp : public OpKernel { + public: + explicit DecodeBmpOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("channels", &channels_)); + OP_REQUIRES( + context, channels_ == 0 || channels_ == 3 || channels_ == 4, + errors::InvalidArgument("channels must be 0, 3 or 4, got ", channels_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& contents = context->input(0); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents.shape()), + errors::InvalidArgument("contents must be scalar, got shape ", + contents.shape().DebugString())); + + // Start decoding image to get shape details + const StringPiece input = contents.scalar<string>()(); + + const uint8* img_bytes = reinterpret_cast<const uint8*>(input.data()); + const int32 header_size = internal::SubtleMustCopy( + *(reinterpret_cast<const int32*>(img_bytes + 10))); + const int32 width = internal::SubtleMustCopy( + *(reinterpret_cast<const int32*>(img_bytes + 18))); + const int32 height = internal::SubtleMustCopy( + *(reinterpret_cast<const int32*>(img_bytes + 22))); + const int32 bpp = internal::SubtleMustCopy( + *(reinterpret_cast<const int32*>(img_bytes + 28))); + + if (channels_) { + OP_REQUIRES(context, (channels_ == bpp / 8), + errors::InvalidArgument( + "channels attribute ", channels_, + " does not match bits per pixel from file ", bpp / 8)); + } else { + channels_ = bpp / 8; + } + + // Current implementation only supports 3 or 4 channel + // bitmaps. + OP_REQUIRES(context, (channels_ == 3 || channels_ == 4), + errors::InvalidArgument( + "Number of channels must be 3 or 4, was ", channels_)); + + // if height is negative, data layout is top down + // otherwise, it's bottom up + bool top_down = (height < 0); + + // Decode image, allocating tensor once the image size is known + Tensor* output = nullptr; + OP_REQUIRES_OK( + context, context->allocate_output( + 0, TensorShape({abs(height), width, channels_}), &output)); + + const uint8* bmp_pixels = &img_bytes[header_size]; + + Decode(bmp_pixels, output->flat<uint8>().data(), width, abs(height), + channels_, top_down); + } + + uint8* Decode(const uint8* input, uint8* const output, const int width, + const int height, const int channles, bool top_down); + + private: + int channels_; +}; +REGISTER_KERNEL_BUILDER(Name("DecodeBmp").Device(DEVICE_CPU), DecodeBmpOp); + +uint8* DecodeBmpOp::Decode(const uint8* input, uint8* const output, + const int width, const int height, + const int channels, bool top_down) { + // there may be padding bytes when the width is not a multiple of 4 bytes + // 8 * channels == bits per pixel + int row_size = (8 * channels * width + 31) / 32 * 4; + + for (int i = 0; i < height; i++) { + int src_pos; + int dst_pos; + + for (int j = 0; j < width; j++) { + if (!top_down) { + src_pos = ((height - 1 - i) * row_size) + j * channels; + } else { + src_pos = i * row_size + j * channels; + } + + dst_pos = (i * width + j) * channels; + + switch (channels) { + case 3: + // BGR -> RGB + output[dst_pos] = input[src_pos + 2]; + output[dst_pos + 1] = input[src_pos + 1]; + output[dst_pos + 2] = input[src_pos]; + break; + case 4: + // BGRA -> RGBA + output[dst_pos] = input[src_pos + 2]; + output[dst_pos + 1] = input[src_pos + 1]; + output[dst_pos + 2] = input[src_pos]; + output[dst_pos + 3] = input[src_pos + 3]; + break; + default: + LOG(FATAL) << "Unexpected number of channels: " << channels; + break; + } + } + } + + return output; +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/deep_conv2d.cc b/tensorflow/core/kernels/deep_conv2d.cc index 9e6d8e42a47..a4814014798 100644 --- a/tensorflow/core/kernels/deep_conv2d.cc +++ b/tensorflow/core/kernels/deep_conv2d.cc @@ -1069,7 +1069,7 @@ struct DeepConv2D<CPUDevice, T> { // Allocate temporary buffer 'buffer2', which is first used for // transformed input tiles, then re-used for transformed output tiles. // Calculate required buffer size for 'buffer2' as max required buffer - // between input and output tranform buffer sizes. + // between input and output transform buffer sizes. const int64 buffer2_tile_transform_size = tile_spatial_size * num_tiles * in_depth; const int64 buffer2_out_transform_size = diff --git a/tensorflow/core/kernels/depthwise_conv_op.cc b/tensorflow/core/kernels/depthwise_conv_op.cc index 26d45f79d82..2e7213f9568 100644 --- a/tensorflow/core/kernels/depthwise_conv_op.cc +++ b/tensorflow/core/kernels/depthwise_conv_op.cc @@ -441,7 +441,9 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> { DepthwiseConv2dNativeOp<CPUDevice, T>); TF_CALL_float(REGISTER_CPU_KERNEL); +#if defined(PLATFORM_WINDOWS) && !defined(_DEBUG) TF_CALL_double(REGISTER_CPU_KERNEL); +#endif #if GOOGLE_CUDA REGISTER_KERNEL_BUILDER( diff --git a/tensorflow/core/kernels/fft_ops.cc b/tensorflow/core/kernels/fft_ops.cc index c4cfd514c3a..21e6c694642 100644 --- a/tensorflow/core/kernels/fft_ops.cc +++ b/tensorflow/core/kernels/fft_ops.cc @@ -15,7 +15,7 @@ limitations under the License. #define EIGEN_USE_THREADS -// See docs in ../ops/fft_ops.cc. +// See docs in ../ops/spectral_ops.cc. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op.h" @@ -29,22 +29,13 @@ limitations under the License. #if GOOGLE_CUDA #include "tensorflow/core/platform/stream_executor.h" +#endif namespace tensorflow { -namespace { -// TODO(vrv/zhifengc): Refactor AsDeviceMemory() into GPUUtil. -template <typename T> -perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) { - perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory)); - perftools::gputools::DeviceMemory<T> typed(wrapped); - return typed; -} -} // end namespace - -class FFTGPUBase : public OpKernel { +class FFTBase : public OpKernel { public: - explicit FFTGPUBase(OpKernelConstruction* ctx) : OpKernel(ctx) {} + explicit FFTBase(OpKernelConstruction* ctx) : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { const Tensor& in = ctx->input(0); @@ -97,9 +88,110 @@ class FFTGPUBase : public OpKernel { virtual bool IsForward() const = 0; virtual bool IsReal() const = 0; - private: + // The function that actually computes the FFT. + virtual void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape, + Tensor* out) = 0; +}; + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template <bool Forward, bool _Real, int FFTRank> +class FFTCPU : public FFTBase { + public: + using FFTBase::FFTBase; + + protected: + int Rank() const override { return FFTRank; } + bool IsForward() const override { return Forward; } + bool IsReal() const override { return _Real; } + void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape, - Tensor* out) { + Tensor* out) override { + // Create the axes (which are always trailing). + auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank); + auto device = ctx->eigen_device<CPUDevice>(); + + if (!IsReal()) { + auto input = ((Tensor)in).flat_inner_dims<complex64, FFTRank + 1>(); + // Compute the FFT using eigen. + auto output = out->flat_inner_dims<complex64, FFTRank + 1>(); + output.device(device) = input.template fft < Eigen::BothParts, + Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE > (axes); + } else { + if (IsForward()) { + auto input = ((Tensor)in).flat_inner_dims<float, FFTRank + 1>(); + auto output = out->flat_inner_dims<complex64, FFTRank + 1>(); + Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> startIndices; + + // Compute the full FFT using a temporary tensor. + Tensor temp; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<complex64>::v(), + in.shape(), &temp)); + auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>(); + full_fft.device(device) = + input.template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes); + + // Slice away the negative frequency components. + output.device(device) = + full_fft.slice(startIndices, output.dimensions()); + } else { + // TODO: reconstruct the full fft and take the inverse. + ctx->CtxFailureWithWarning( + errors::Unimplemented("IRFFT is not implemented as a CPU kernel")); + } + } + } +}; + +// Use labels to distinguish between internal and open source versions +// of these kernels. +#ifdef PLATFORM_GOOGLE +#define FFT_LABEL "eigen" +#else +#define FFT_LABEL "" +#endif + +REGISTER_KERNEL_BUILDER(Name("FFT").Device(DEVICE_CPU).Label(FFT_LABEL), + FFTCPU<true, false, 1>); +REGISTER_KERNEL_BUILDER(Name("IFFT").Device(DEVICE_CPU).Label(FFT_LABEL), + FFTCPU<false, false, 1>); +REGISTER_KERNEL_BUILDER(Name("FFT2D").Device(DEVICE_CPU).Label(FFT_LABEL), + FFTCPU<true, false, 2>); +REGISTER_KERNEL_BUILDER(Name("IFFT2D").Device(DEVICE_CPU).Label(FFT_LABEL), + FFTCPU<false, false, 2>); +REGISTER_KERNEL_BUILDER(Name("FFT3D").Device(DEVICE_CPU).Label(FFT_LABEL), + FFTCPU<true, false, 3>); +REGISTER_KERNEL_BUILDER(Name("IFFT3D").Device(DEVICE_CPU).Label(FFT_LABEL), + FFTCPU<false, false, 3>); + +REGISTER_KERNEL_BUILDER(Name("RFFT").Device(DEVICE_CPU).Label(FFT_LABEL), + FFTCPU<true, true, 1>); +REGISTER_KERNEL_BUILDER(Name("RFFT2D").Device(DEVICE_CPU).Label(FFT_LABEL), + FFTCPU<true, true, 2>); +REGISTER_KERNEL_BUILDER(Name("RFFT3D").Device(DEVICE_CPU).Label(FFT_LABEL), + FFTCPU<true, true, 3>); + +#undef FFT_LABEL + +#if GOOGLE_CUDA + +namespace { +// TODO(vrv/zhifengc): Refactor AsDeviceMemory() into GPUUtil. +template <typename T> +perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) { + perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory)); + perftools::gputools::DeviceMemory<T> typed(wrapped); + return typed; +} +} // end namespace + +class FFTGPUBase : public FFTBase { + public: + using FFTBase::FFTBase; + + protected: + void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape, + Tensor* out) override { auto* stream = ctx->op_device_context()->stream(); OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available.")); @@ -238,7 +330,6 @@ REGISTER_KERNEL_BUILDER(Name("BatchFFT3D").Device(DEVICE_GPU), FFTGPU<true, false, 3>); REGISTER_KERNEL_BUILDER(Name("BatchIFFT3D").Device(DEVICE_GPU), FFTGPU<false, false, 3>); +#endif // GOOGLE_CUDA } // end namespace tensorflow - -#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc b/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc index aff78837a32..b672db2016e 100644 --- a/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc +++ b/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc @@ -114,13 +114,13 @@ GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo( for (const std::pair<string, Tensor>& input_node_info : inputs) { const Scope& scope = root.WithOpName(input_node_info.first); Node* ret; - const auto unique_name = scope.GetUniqueNameForOp("PlaceholderV2"); - auto builder = NodeBuilder(unique_name, "PlaceholderV2") + const auto unique_name = scope.GetUniqueNameForOp("Placeholder"); + auto builder = NodeBuilder(unique_name, "Placeholder") .Attr("dtype", input_node_info.second.dtype()) .Attr("shape", input_node_info.second.shape()); scope.UpdateBuilder(&builder); scope.UpdateStatus(builder.Finalize(scope.graph(), &ret)); - CHECK(scope.ok()); + TF_CHECK_OK(scope.status()); output_list.emplace_back(Output(ret, 0)); input_types.push_back(input_node_info.second.dtype()); } diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc index 4d4851c70cb..9ffe71e031e 100644 --- a/tensorflow/core/kernels/non_max_suppression_op.cc +++ b/tensorflow/core/kernels/non_max_suppression_op.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" namespace tensorflow { +namespace { typedef Eigen::ThreadPoolDevice CPUDevice; @@ -89,6 +90,59 @@ static inline float ComputeIOU(typename TTypes<float, 2>::ConstTensor boxes, return intersection_area / (area_i + area_j - intersection_area); } +void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes, + const Tensor& scores, const Tensor& max_output_size, + const float iou_threshold) { + OP_REQUIRES(context, iou_threshold >= 0 && iou_threshold <= 1, + errors::InvalidArgument("iou_threshold must be in [0, 1]")); + + int num_boxes = 0; + ParseAndCheckBoxSizes(context, boxes, scores, &num_boxes); + if (!context->status().ok()) { + return; + } + + const int output_size = std::min(max_output_size.scalar<int>()(), num_boxes); + typename TTypes<float, 2>::ConstTensor boxes_data = boxes.tensor<float, 2>(); + + std::vector<float> scores_data(num_boxes); + std::copy_n(scores.flat<float>().data(), num_boxes, scores_data.begin()); + std::vector<int> sorted_indices; + DecreasingArgSort(scores_data, &sorted_indices); + + std::vector<bool> active(num_boxes, true); + std::vector<int> selected; + int num_active = active.size(); + for (int i = 0; i < num_boxes; ++i) { + if (num_active == 0 || selected.size() >= output_size) break; + if (active[i]) { + selected.push_back(sorted_indices[i]); + } else { + continue; + } + for (int j = i + 1; j < num_boxes; ++j) { + if (active[j]) { + float iou = + ComputeIOU(boxes_data, sorted_indices[i], sorted_indices[j]); + if (iou > iou_threshold) { + active[j] = false; + num_active--; + } + } + } + } + + // Allocate output tensor + Tensor* output = nullptr; + TensorShape output_shape({static_cast<int>(selected.size())}); + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + typename TTypes<int, 1>::Tensor selected_indices_data = + output->tensor<int, 1>(); + std::copy_n(selected.begin(), selected.size(), selected_indices_data.data()); +} + +} // namespace + template <typename Device> class NonMaxSuppressionOp : public OpKernel { public: @@ -98,9 +152,6 @@ class NonMaxSuppressionOp : public OpKernel { } void Compute(OpKernelContext* context) override { - OP_REQUIRES(context, iou_threshold_ >= 0 && iou_threshold_ <= 1, - errors::InvalidArgument("iou_threshold must be in [0, 1]")); - // boxes: [num_boxes, 4] const Tensor& boxes = context->input(0); // scores: [num_boxes] @@ -112,59 +163,48 @@ class NonMaxSuppressionOp : public OpKernel { errors::InvalidArgument("max_output_size must be 0-D, got shape ", max_output_size.shape().DebugString())); - int num_boxes = 0; - ParseAndCheckBoxSizes(context, boxes, scores, &num_boxes); - if (!context->status().ok()) { - return; - } - - const int output_size = - std::min(max_output_size.scalar<int>()(), num_boxes); - typename TTypes<float, 2>::ConstTensor boxes_data = - boxes.tensor<float, 2>(); - - std::vector<float> scores_data(num_boxes); - std::copy_n(scores.flat<float>().data(), num_boxes, scores_data.begin()); - std::vector<int> sorted_indices; - DecreasingArgSort(scores_data, &sorted_indices); - - std::vector<bool> active(num_boxes, true); - std::vector<int> selected; - int num_active = active.size(); - for (int i = 0; i < num_boxes; ++i) { - if (num_active == 0 || selected.size() >= output_size) break; - if (active[i]) { - selected.push_back(sorted_indices[i]); - } else { - continue; - } - for (int j = i + 1; j < num_boxes; ++j) { - if (active[j]) { - float iou = - ComputeIOU(boxes_data, sorted_indices[i], sorted_indices[j]); - if (iou > iou_threshold_) { - active[j] = false; - num_active--; - } - } - } - } - - // Allocate output tensor - Tensor* output = nullptr; - TensorShape output_shape({static_cast<int>(selected.size())}); - OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); - typename TTypes<int, 1>::Tensor selected_indices_data = - output->tensor<int, 1>(); - std::copy_n(selected.begin(), selected.size(), - selected_indices_data.data()); + DoNonMaxSuppressionOp(context, boxes, scores, max_output_size, + iou_threshold_); } private: float iou_threshold_; }; +template <typename Device> +class NonMaxSuppressionV2Op : public OpKernel { + public: + explicit NonMaxSuppressionV2Op(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // boxes: [num_boxes, 4] + const Tensor& boxes = context->input(0); + // scores: [num_boxes] + const Tensor& scores = context->input(1); + // max_output_size: scalar + const Tensor& max_output_size = context->input(2); + OP_REQUIRES( + context, TensorShapeUtils::IsScalar(max_output_size.shape()), + errors::InvalidArgument("max_output_size must be 0-D, got shape ", + max_output_size.shape().DebugString())); + // iou_threshold: scalar + const Tensor& iou_threshold = context->input(3); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()), + errors::InvalidArgument("iou_threshold must be 0-D, got shape ", + iou_threshold.shape().DebugString())); + + const float iou_threshold_val = iou_threshold.scalar<float>()(); + + DoNonMaxSuppressionOp(context, boxes, scores, max_output_size, + iou_threshold_val); + } +}; + REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU), NonMaxSuppressionOp<CPUDevice>); +REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").Device(DEVICE_CPU), + NonMaxSuppressionV2Op<CPUDevice>); + } // namespace tensorflow diff --git a/tensorflow/core/kernels/non_max_suppression_op_test.cc b/tensorflow/core/kernels/non_max_suppression_op_test.cc index 72e368db773..a72c220cbfa 100644 --- a/tensorflow/core/kernels/non_max_suppression_op_test.cc +++ b/tensorflow/core/kernels/non_max_suppression_op_test.cc @@ -173,4 +173,167 @@ TEST_F(NonMaxSuppressionOpTest, TestEmptyInput) { test::ExpectTensorEqual<int>(expected, *GetOutput(0)); } +// +// NonMaxSuppressionV2Op Tests +// + +class NonMaxSuppressionV2OpTest : public OpsTestBase { + protected: + void MakeOp() { + TF_EXPECT_OK(NodeDefBuilder("non_max_suppression_op", "NonMaxSuppressionV2") + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_FLOAT)) + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + } +}; + +TEST_F(NonMaxSuppressionV2OpTest, TestSelectFromThreeClusters) { + MakeOp(); + AddInputFromArray<float>( + TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f}); + AddInputFromArray<int>(TensorShape({}), {3}); + AddInputFromArray<float>(TensorShape({}), {.5f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({3})); + test::FillValues<int>(&expected, {3, 0, 5}); + test::ExpectTensorEqual<int>(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV2OpTest, + TestSelectFromThreeClustersFlippedCoordinates) { + MakeOp(); + AddInputFromArray<float>(TensorShape({6, 4}), + {1, 1, 0, 0, 0, 0.1f, 1, 1.1f, 0, .9f, 1, -0.1f, + 0, 10, 1, 11, 1, 10.1f, 0, 11.1f, 1, 101, 0, 100}); + AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f}); + AddInputFromArray<int>(TensorShape({}), {3}); + AddInputFromArray<float>(TensorShape({}), {.5f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({3})); + test::FillValues<int>(&expected, {3, 0, 5}); + test::ExpectTensorEqual<int>(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV2OpTest, TestSelectAtMostTwoBoxesFromThreeClusters) { + MakeOp(); + AddInputFromArray<float>( + TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f}); + AddInputFromArray<int>(TensorShape({}), {2}); + AddInputFromArray<float>(TensorShape({}), {.5f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({2})); + test::FillValues<int>(&expected, {3, 0}); + test::ExpectTensorEqual<int>(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV2OpTest, + TestSelectAtMostThirtyBoxesFromThreeClusters) { + MakeOp(); + AddInputFromArray<float>( + TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f}); + AddInputFromArray<int>(TensorShape({}), {30}); + AddInputFromArray<float>(TensorShape({}), {.5f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({3})); + test::FillValues<int>(&expected, {3, 0, 5}); + test::ExpectTensorEqual<int>(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV2OpTest, TestSelectSingleBox) { + MakeOp(); + AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1}); + AddInputFromArray<float>(TensorShape({1}), {.9f}); + AddInputFromArray<int>(TensorShape({}), {3}); + AddInputFromArray<float>(TensorShape({}), {.5f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({1})); + test::FillValues<int>(&expected, {0}); + test::ExpectTensorEqual<int>(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV2OpTest, TestSelectFromTenIdenticalBoxes) { + MakeOp(); + + int num_boxes = 10; + std::vector<float> corners(num_boxes * 4); + std::vector<float> scores(num_boxes); + for (int i = 0; i < num_boxes; ++i) { + corners[i * 4 + 0] = 0; + corners[i * 4 + 1] = 0; + corners[i * 4 + 2] = 1; + corners[i * 4 + 3] = 1; + scores[i] = .9; + } + AddInputFromArray<float>(TensorShape({num_boxes, 4}), corners); + AddInputFromArray<float>(TensorShape({num_boxes}), scores); + AddInputFromArray<int>(TensorShape({}), {3}); + AddInputFromArray<float>(TensorShape({}), {.5f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({1})); + test::FillValues<int>(&expected, {0}); + test::ExpectTensorEqual<int>(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV2OpTest, TestInconsistentBoxAndScoreShapes) { + MakeOp(); + AddInputFromArray<float>( + TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray<float>(TensorShape({5}), {.9f, .75f, .6f, .95f, .5f}); + AddInputFromArray<int>(TensorShape({}), {30}); + AddInputFromArray<float>(TensorShape({}), {.5f}); + Status s = RunOpKernel(); + + ASSERT_FALSE(s.ok()); + EXPECT_TRUE( + StringPiece(s.ToString()).contains("scores has incompatible shape")) + << s; +} + +TEST_F(NonMaxSuppressionV2OpTest, TestInvalidIOUThreshold) { + MakeOp(); + AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1}); + AddInputFromArray<float>(TensorShape({1}), {.9f}); + AddInputFromArray<int>(TensorShape({}), {3}); + AddInputFromArray<float>(TensorShape({}), {1.2f}); + Status s = RunOpKernel(); + + ASSERT_FALSE(s.ok()); + EXPECT_TRUE( + StringPiece(s.ToString()).contains("iou_threshold must be in [0, 1]")) + << s; +} + +TEST_F(NonMaxSuppressionV2OpTest, TestEmptyInput) { + MakeOp(); + AddInputFromArray<float>(TensorShape({0, 4}), {}); + AddInputFromArray<float>(TensorShape({0}), {}); + AddInputFromArray<int>(TensorShape({}), {30}); + AddInputFromArray<float>(TensorShape({}), {.5f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({0})); + test::FillValues<int>(&expected, {}); + test::ExpectTensorEqual<int>(expected, *GetOutput(0)); +} + } // namespace tensorflow diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc b/tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc index 112168b1959..655de2f98f3 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc @@ -89,8 +89,8 @@ static Output BuildPlaceHolderOp(const string& name, const DataType dt, const TensorShape& tensor_shape, Scope* root) { const Scope& scope = root->WithOpName(name); Node* ret; - const string unique_name = scope.GetUniqueNameForOp("PlaceholderV2"); - NodeBuilder builder = NodeBuilder(unique_name, "PlaceholderV2") + const string unique_name = scope.GetUniqueNameForOp("Placeholder"); + NodeBuilder builder = NodeBuilder(unique_name, "Placeholder") .Attr("dtype", dt) .Attr("shape", tensor_shape); scope.UpdateBuilder(&builder); diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index 3b2fa296934..d331a8debf0 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -245,12 +245,22 @@ struct ApplyAdamNonCuda { typename TTypes<T>::ConstScalar beta1, typename TTypes<T>::ConstScalar beta2, typename TTypes<T>::ConstScalar epsilon, - typename TTypes<T>::ConstFlat grad) { + typename TTypes<T>::ConstFlat grad, bool use_nesterov) { const T alpha = lr() * Eigen::numext::sqrt(T(1) - beta2_power()) / (T(1) - beta1_power()); + // beta1 == μ + // beta2 == ν + // v == n + // var == θ + m.device(d) += (grad - m) * (T(1) - beta1()); v.device(d) += (grad.square() - v) * (T(1) - beta2()); - var.device(d) -= (m * alpha) / (v.sqrt() + epsilon()); + if (use_nesterov) { + var.device(d) -= ((grad * (T(1) - beta1()) + beta1() * m) * alpha) / + (v.sqrt() + epsilon()); + } else { + var.device(d) -= (m * alpha) / (v.sqrt() + epsilon()); + } } }; @@ -2248,6 +2258,7 @@ class ApplyAdamOp : public OpKernel { public: explicit ApplyAdamOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_)); } void Compute(OpKernelContext* ctx) override { @@ -2318,17 +2329,18 @@ class ApplyAdamOp : public OpKernel { grad.shape().DebugString())); const Device& device = ctx->template eigen_device<Device>(); - functor::ApplyAdam<Device, T>()(device, var.flat<T>(), m.flat<T>(), - v.flat<T>(), beta1_power.scalar<T>(), - beta2_power.scalar<T>(), lr.scalar<T>(), - beta1.scalar<T>(), beta2.scalar<T>(), - epsilon.scalar<T>(), grad.flat<T>()); + functor::ApplyAdam<Device, T>()( + device, var.flat<T>(), m.flat<T>(), v.flat<T>(), + beta1_power.scalar<T>(), beta2_power.scalar<T>(), lr.scalar<T>(), + beta1.scalar<T>(), beta2.scalar<T>(), epsilon.scalar<T>(), + grad.flat<T>(), use_nesterov_); MaybeForwardRefInputToRefOutput(ctx, 0, 0); } private: bool use_exclusive_lock_; + bool use_nesterov_; }; using CPUDevice = Eigen::ThreadPoolDevice; @@ -2372,7 +2384,7 @@ namespace functor { typename TTypes<T>::ConstScalar beta1, \ typename TTypes<T>::ConstScalar beta2, \ typename TTypes<T>::ConstScalar epsilon, \ - typename TTypes<T>::ConstFlat grad); \ + typename TTypes<T>::ConstFlat grad, bool use_nesterov); \ extern template struct ApplyAdam<GPUDevice, T>; DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); diff --git a/tensorflow/core/kernels/training_ops.h b/tensorflow/core/kernels/training_ops.h index c96468b2709..11c9faa4ecd 100644 --- a/tensorflow/core/kernels/training_ops.h +++ b/tensorflow/core/kernels/training_ops.h @@ -123,7 +123,7 @@ struct ApplyAdam { typename TTypes<T>::ConstScalar beta1, typename TTypes<T>::ConstScalar beta2, typename TTypes<T>::ConstScalar epsilon, - typename TTypes<T>::ConstFlat grad); + typename TTypes<T>::ConstFlat grad, bool use_nesterov); }; template <typename Device, typename T> diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc index f6acdf2422c..3678b96e98f 100644 --- a/tensorflow/core/kernels/training_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc @@ -109,7 +109,7 @@ struct ApplyAdam<GPUDevice, T> { typename TTypes<T>::ConstScalar beta1, typename TTypes<T>::ConstScalar beta2, typename TTypes<T>::ConstScalar epsilon, - typename TTypes<T>::ConstFlat grad) { + typename TTypes<T>::ConstFlat grad, bool use_nesterov) { Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast; bcast[0] = grad.dimension(0); Eigen::Sizes<1> single; @@ -122,11 +122,25 @@ struct ApplyAdam<GPUDevice, T> { v + (beta2.constant(one) - beta2).reshape(single).broadcast(bcast) * (grad.square() - v); - var.device(d) -= (lr * (beta2_power.constant(one) - beta2_power).sqrt() / - (beta1_power.constant(one) - beta1_power)) - .reshape(single) - .broadcast(bcast) * - m / (epsilon.reshape(single).broadcast(bcast) + v.sqrt()); + + if (use_nesterov) { + var.device(d) -= + (lr * (beta2_power.constant(one) - beta2_power).sqrt() / + (beta1_power.constant(one) - beta1_power)) + .reshape(single) + .broadcast(bcast) * + (m * beta1.reshape(single).broadcast(bcast) + + (beta1.constant(one) - beta1).reshape(single).broadcast(bcast) * + grad) / + (epsilon.reshape(single).broadcast(bcast) + v.sqrt()); + } else { + var.device(d) -= (lr * (beta2_power.constant(one) - beta2_power).sqrt() / + (beta1_power.constant(one) - beta1_power)) + .reshape(single) + .broadcast(bcast) * + m / + (epsilon.reshape(single).broadcast(bcast) + v.sqrt()); + } } }; diff --git a/tensorflow/core/lib/bmp/testdata/lena.bmp b/tensorflow/core/lib/bmp/testdata/lena.bmp new file mode 100644 index 00000000000..8c4882de4a7 Binary files /dev/null and b/tensorflow/core/lib/bmp/testdata/lena.bmp differ diff --git a/tensorflow/core/lib/gtl/inlined_vector_test.cc b/tensorflow/core/lib/gtl/inlined_vector_test.cc index b957fedc4a0..6e3c083f58a 100644 --- a/tensorflow/core/lib/gtl/inlined_vector_test.cc +++ b/tensorflow/core/lib/gtl/inlined_vector_test.cc @@ -816,7 +816,7 @@ static void BM_StdVectorFillString(int iters, int len) { } testing::ItemsProcessed(int64{iters} * len); // The purpose of the benchmark is to verify that inlined vector is - // efficient when moving is more efficent than copying. To do so, we + // efficient when moving is more efficient than copying. To do so, we // use strings that are larger than the small string optimization. CHECK(!StringRepresentedInline(strings[0])); } diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index cfcf2d1ff93..f2a8d8718d7 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -563,6 +563,28 @@ compression: Compression level. contents: 0-D. PNG-encoded image. )doc"); +// -------------------------------------------------------------------------- +REGISTER_OP("DecodeBmp") + .Input("contents: string") + .Output("image: uint8") + .Attr("channels: int = 0") + .SetShapeFn(DecodeImageShapeFn) + .Doc(R"doc( +Decode the first frame of a BMP-encoded image to a uint8 tensor. + +The attr `channels` indicates the desired number of color channels for the +decoded image. + +Accepted values are: + +* 0: Use the number of channels in the BMP-encoded image. +* 3: output an RGB image. +* 4: output an RGBA image. + +contents: 0-D. The BMP-encoded image. +image: 3-D with shape `[height, width, channels]`. RGB order +)doc"); + // -------------------------------------------------------------------------- REGISTER_OP("DecodeGif") .Input("contents: string") @@ -992,16 +1014,13 @@ is agnostic to where the origin is in the coordinate system. Note that this algorithm is invariant to orthogonal transformations and translations of the coordinate system; thus translating or reflections of the coordinate system result in the same boxes being selected by the algorithm. - The output of this operation is a set of integers indexing into the input collection of bounding boxes representing the selected boxes. The bounding box coordinates corresponding to the selected indices can then be obtained using the `tf.gather operation`. For example: - selected_indices = tf.image.non_max_suppression( boxes, scores, max_output_size, iou_threshold) selected_boxes = tf.gather(boxes, selected_indices) - boxes: A 2-D float tensor of shape `[num_boxes, 4]`. scores: A 1-D float tensor of shape `[num_boxes]` representing a single score corresponding to each box (each row of boxes). @@ -1013,4 +1032,46 @@ selected_indices: A 1-D integer tensor of shape `[M]` representing the selected indices from the boxes tensor, where `M <= max_output_size`. )doc"); +REGISTER_OP("NonMaxSuppressionV2") + .Input("boxes: float") + .Input("scores: float") + .Input("max_output_size: int32") + .Input("iou_threshold: float") + .Output("selected_indices: int32") + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->Vector(c->UnknownDim())); + return Status::OK(); + }) + .Doc(R"doc( +Greedily selects a subset of bounding boxes in descending order of score, +pruning away boxes that have high intersection-over-union (IOU) overlap +with previously selected boxes. Bounding boxes are supplied as +[y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any +diagonal pair of box corners and the coordinates can be provided as normalized +(i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm +is agnostic to where the origin is in the coordinate system. Note that this +algorithm is invariant to orthogonal transformations and translations +of the coordinate system; thus translating or reflections of the coordinate +system result in the same boxes being selected by the algorithm. + +The output of this operation is a set of integers indexing into the input +collection of bounding boxes representing the selected boxes. The bounding +box coordinates corresponding to the selected indices can then be obtained +using the `tf.gather operation`. For example: + + selected_indices = tf.image.non_max_suppression_v2( + boxes, scores, max_output_size, iou_threshold) + selected_boxes = tf.gather(boxes, selected_indices) + +boxes: A 2-D float tensor of shape `[num_boxes, 4]`. +scores: A 1-D float tensor of shape `[num_boxes]` representing a single + score corresponding to each box (each row of boxes). +max_output_size: A scalar integer tensor representing the maximum number of + boxes to be selected by non max suppression. +iou_threshold: A 0-D float tensor representing the threshold for deciding whether + boxes overlap too much with respect to IOU. +selected_indices: A 1-D integer tensor of shape `[M]` representing the selected + indices from the boxes tensor, where `M <= max_output_size`. +)doc"); + } // namespace tensorflow diff --git a/tensorflow/core/ops/sdca_ops.cc b/tensorflow/core/ops/sdca_ops.cc index b9a37119b74..dea75a1af83 100644 --- a/tensorflow/core/ops/sdca_ops.cc +++ b/tensorflow/core/ops/sdca_ops.cc @@ -105,7 +105,7 @@ example_weights: a vector which contains the weight associated with each example_labels: a vector which contains the label/target associated with each example. sparse_indices: a list of vectors where each value is the indices which has - corresponding weights in sparse_weights. This field maybe ommitted for the + corresponding weights in sparse_weights. This field maybe omitted for the dense approach. sparse_weights: a list of vectors where each value is the weight associated with a sparse feature group. diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc index 2027bf4603d..6f7a007f2c7 100644 --- a/tensorflow/core/ops/training_ops.cc +++ b/tensorflow/core/ops/training_ops.cc @@ -1004,7 +1004,7 @@ out: Same as "var". use_locking: If `True`, updating of the var and accum tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. -use_nesterov: If `True`, the tensor passed to compute grad will be +use_nesterov: If `True`, the tensor passed to compute grad will be var - lr * momentum * accum, so in the end, the var you get is actually var - lr * momentum * accum. )doc"); @@ -1043,7 +1043,7 @@ out: Same as "var". use_locking: If `True`, updating of the var and accum tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. -use_nesterov: If `True`, the tensor passed to compute grad will be +use_nesterov: If `True`, the tensor passed to compute grad will be var - lr * momentum * accum, so in the end, the var you get is actually var - lr * momentum * accum. )doc"); @@ -1075,7 +1075,7 @@ momentum: Momentum. Must be a scalar. use_locking: If `True`, updating of the var and accum tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. -use_nesterov: If `True`, the tensor passed to compute grad will be +use_nesterov: If `True`, the tensor passed to compute grad will be var - lr * momentum * accum, so in the end, the var you get is actually var - lr * momentum * accum. )doc"); @@ -1112,7 +1112,7 @@ momentum: Momentum. Must be a scalar. use_locking: If `True`, updating of the var and accum tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. -use_nesterov: If `True`, the tensor passed to compute grad will be +use_nesterov: If `True`, the tensor passed to compute grad will be var - lr * momentum * accum, so in the end, the var you get is actually var - lr * momentum * accum. )doc"); @@ -1150,6 +1150,7 @@ REGISTER_OP("ApplyAdam") .Output("out: Ref(T)") .Attr("T: numbertype") .Attr("use_locking: bool = false") + .Attr("use_nesterov: bool = false") .SetShapeFn([](InferenceContext* c) { return ApplyAdamShapeFn(c, false /* sparse */); }) @@ -1175,6 +1176,7 @@ out: Same as "var". use_locking: If `True`, updating of the var, m, and v tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. +use_nesterov: If `True`, uses the nesterov update. )doc"); REGISTER_OP("ResourceApplyAdam") @@ -1190,6 +1192,7 @@ REGISTER_OP("ResourceApplyAdam") .Input("grad: T") .Attr("T: numbertype") .Attr("use_locking: bool = false") + .Attr("use_nesterov: bool = false") .SetShapeFn([](InferenceContext* c) { return ApplyAdamShapeFn(c, false /* sparse */); }) @@ -1214,6 +1217,7 @@ grad: The gradient. use_locking: If `True`, updating of the var, m, and v tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. +use_nesterov: If `True`, uses the nesterov update. )doc"); static Status ApplyRMSPropShapeFn(InferenceContext* c, bool sparse) { diff --git a/tensorflow/core/platform/cloud/google_auth_provider.cc b/tensorflow/core/platform/cloud/google_auth_provider.cc index 6f29d4597f1..f70b431b652 100644 --- a/tensorflow/core/platform/cloud/google_auth_provider.cc +++ b/tensorflow/core/platform/cloud/google_auth_provider.cc @@ -67,7 +67,7 @@ constexpr char kGceTokenUrl[] = // The authentication token scope to request. constexpr char kOAuthScope[] = "https://www.googleapis.com/auth/cloud-platform"; -// The default intial delay between retries with exponential backoff. +// The default initial delay between retries with exponential backoff. constexpr int kInitialRetryDelayUsec = 500000; // 0.5 sec /// Returns whether the given path points to a readable file. diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index df33cf38c97..566d9aa9084 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -87,10 +87,11 @@ limitations under the License. // 22. Placeholder now can specify and enforce scalar and partial // shapes, particularly when restoring a graph from GraphDef // produced at version 22 or later. (04/10/2016) +// 23. Remove NonMaxSuppression in favor of NonMaxSuppressionV2. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 22 +#define TF_GRAPH_DEF_VERSION 23 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // diff --git a/tensorflow/docs_src/extend/adding_an_op.md b/tensorflow/docs_src/extend/adding_an_op.md index 35867944181..a8c28e98c9b 100644 --- a/tensorflow/docs_src/extend/adding_an_op.md +++ b/tensorflow/docs_src/extend/adding_an_op.md @@ -317,7 +317,7 @@ or `clang` available on your system. The binary PIP package installs the header files and the library that you need to compile your op in locations that are system specific. However, the TensorFlow python library provides the `get_include` function to get the header directory. -Here is the output of this function on a Ubuntu machine. +Here is the output of this function on an Ubuntu machine. ```bash $ python @@ -1215,6 +1215,8 @@ you'll need to specify the path explicitly in the second (g++) command above. For example, add `-L /usr/local/cuda-8.0/lib64/` if your CUDA is installed in `/usr/local/cuda-8.0`. +> Note in some linux settings, additional options to `nvcc` compiling step are needed. Add `-D_MWAITXINTRIN_H_INCLUDED` to the `nvcc` command line to avoid errors from `mwaitxintrin.h`. + ### Implement the gradient in Python {#implement-gradient} Given a graph of ops, TensorFlow uses automatic differentiation diff --git a/tensorflow/docs_src/performance/benchmarks.md b/tensorflow/docs_src/performance/benchmarks.md index 6bbc98ac0d7..47ab028e205 100644 --- a/tensorflow/docs_src/performance/benchmarks.md +++ b/tensorflow/docs_src/performance/benchmarks.md @@ -4,7 +4,7 @@ A selection of image classification models were tested across multiple platforms to create a point of reference for the TensorFlow community. The -[Methodology](#methodology) section details how the test were executed and has +[Methodology](#methodology) section details how the tests were executed and has links to the scripts used. ## Results for image classification models @@ -355,7 +355,7 @@ ResNet-50 | distributed_replicated | n/a | True ResNet-152 | distributed_replicated | n/a | True To simplify server setup, EC2 instances (p2.8xlarge) running worker servers also -ran parameter servers. Equal numbers of parameter servers and work servers were +ran parameter servers. Equal numbers of parameter servers and worker servers were used with the following exceptions: * InceptionV3: 8 instances / 6 parameter servers diff --git a/tensorflow/docs_src/performance/performance_models.md b/tensorflow/docs_src/performance/performance_models.md index d48e008a3ad..d48431eaa08 100644 --- a/tensorflow/docs_src/performance/performance_models.md +++ b/tensorflow/docs_src/performance/performance_models.md @@ -62,12 +62,12 @@ and executed in parallel. The image preprocessing ops include operations such as image decoding, distortion, and resizing. Once the images are through preprocessing, they are concatenated together into 8 -batch size 32 tensors. Rather than use @{tf.concat} for this purpose, which is -implemented as a single op that waits for all the inputs to be ready before -concatenating them together, @{tf.parallel_stack} is used. @{tf.parallel_stack} -allocates an uninitialized tensor as an output, and each input tensor is written -to its designated portion of the output tensor as soon as the input is -available. +tensors each with a batch-size of 32. Rather than using @{tf.concat} for this +purpose, which is implemented as a single op that waits for all the inputs to be +ready before concatenating them together, @{tf.parallel_stack} is used. +@{tf.parallel_stack} allocates an uninitialized tensor as an output, and each +input tensor is written to its designated portion of the output tensor as soon +as the input is available. When all the input tensors are finished, the output tensor is passed along in the graph. This effectively hides all the memory latency with the long tail of @@ -142,7 +142,7 @@ On GPU, NCHW is faster. But on CPU, NHWC is sometimes faster. Building a model to support both data formats keeps the model flexible and capable of operating optimally regardless of platform. Most TensorFlow -operations used by a CNN support both NHWC and NCHW data format. The benchmark +operations used by a CNN support both NHWC and NCHW data formats. The benchmark script was written to support both NCHW and NHWC. NCHW should always be used when training with GPUs. NHWC is sometimes faster on CPU. A flexible model can be trained on GPUs using NCHW with inference done on CPU using NHWC with the diff --git a/tensorflow/docs_src/tutorials/word2vec.md b/tensorflow/docs_src/tutorials/word2vec.md index 348e069ed6d..dfb21334f8e 100644 --- a/tensorflow/docs_src/tutorials/word2vec.md +++ b/tensorflow/docs_src/tutorials/word2vec.md @@ -23,7 +23,7 @@ straight in, feel free to look at the minimalistic implementation in This basic example contains the code needed to download some data, train on it a bit and visualize the result. Once you get comfortable with reading and running the basic version, you can graduate to -[tensorflow_models/tutorials/embedding/word2vec.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec.py) +[models/tutorials/embedding/word2vec.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec.py) which is a more serious implementation that showcases some more advanced TensorFlow principles about how to efficiently use threads to move data into a text model, how to checkpoint during training, etc. @@ -108,7 +108,7 @@ $$ where \\(\text{score}(w_t, h)\\) computes the compatibility of word \\(w_t\\) with the context \\(h\\) (a dot product is commonly used). We train this model -by maximizing its [log-likelihood](https://en.wikipedia.org/wiki/Likelihood_function) +by maximizing its [log-likelihood](https://en.wikipedia.org/wiki/Likelihood_function) on the training set, i.e. by maximizing $$ @@ -130,7 +130,7 @@ context \\(h\\), *at every training step*. On the other hand, for feature learning in word2vec we do not need a full probabilistic model. The CBOW and skip-gram models are instead trained using a -binary classification objective ([logistic regression](https://en.wikipedia.org/wiki/Logistic_regression)) +binary classification objective ([logistic regression](https://en.wikipedia.org/wiki/Logistic_regression)) to discriminate the real target words \\(w_t\\) from \\(k\\) imaginary (noise) words \\(\tilde w\\), in the same context. We illustrate this below for a CBOW model. For skip-gram the direction is simply inverted. @@ -341,7 +341,7 @@ t-SNE. Et voila! As expected, words that are similar end up clustering nearby each other. For a more heavyweight implementation of word2vec that showcases more of the advanced features of TensorFlow, see the implementation in -[tensorflow_models/tutorials/embedding/word2vec.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec.py). +[models/tutorials/embedding/word2vec.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec.py). ## Evaluating Embeddings: Analogical Reasoning @@ -357,7 +357,7 @@ Download the dataset for this task from To see how we do this evaluation, have a look at the `build_eval_graph()` and `eval()` functions in -[tensorflow_models/tutorials/embedding/word2vec.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec.py). +[models/tutorials/embedding/word2vec.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec.py). The choice of hyperparameters can strongly influence the accuracy on this task. To achieve state-of-the-art performance on this task requires training over a @@ -385,13 +385,13 @@ your model is seriously bottlenecked on input data, you may want to implement a custom data reader for your problem, as described in @{$new_data_formats$New Data Formats}. For the case of Skip-Gram modeling, we've actually already done this for you as an example in -[tensorflow_models/tutorials/embedding/word2vec.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec.py). +[models/tutorials/embedding/word2vec.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec.py). If your model is no longer I/O bound but you want still more performance, you can take things further by writing your own TensorFlow Ops, as described in @{$adding_an_op$Adding a New Op}. Again we've provided an example of this for the Skip-Gram case -[tensorflow_models/tutorials/embedding/word2vec_optimized.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec_optimized.py). +[models/tutorials/embedding/word2vec_optimized.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec_optimized.py). Feel free to benchmark these against each other to measure performance improvements at each stage. diff --git a/tensorflow/examples/android/README.md b/tensorflow/examples/android/README.md index 60ba702c10b..270f654ed72 100644 --- a/tensorflow/examples/android/README.md +++ b/tensorflow/examples/android/README.md @@ -28,9 +28,9 @@ on API >= 14 devices. using Deep Neural Networks](https://arxiv.org/abs/1312.2249) to localize and track people in the camera preview in real-time. 3. [TF Stylize](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java): - Uses a model based on [A Learned Representation For Artistic Style] - (https://arxiv.org/abs/1610.07629) to restyle the camera preview image - to that of a number of different artists. + Uses a model based on [A Learned Representation For Artistic + Style](https://arxiv.org/abs/1610.07629) to restyle the camera preview + image to that of a number of different artists. <img src="sample_images/classify1.jpg" width="30%"><img src="sample_images/stylize1.jpg" width="30%"><img src="sample_images/detect1.jpg" width="30%"> diff --git a/tensorflow/examples/label_image/BUILD b/tensorflow/examples/label_image/BUILD index 021372fa7b8..d677e58ac32 100644 --- a/tensorflow/examples/label_image/BUILD +++ b/tensorflow/examples/label_image/BUILD @@ -12,12 +12,32 @@ cc_binary( srcs = [ "main.cc", ], - linkopts = ["-lm"], - deps = [ - "//tensorflow/cc:cc_ops", - "//tensorflow/core:framework_internal", - "//tensorflow/core:tensorflow", - ], + linkopts = select({ + "//tensorflow:android": [ + "-pie", + "-landroid", + "-ljnigraphics", + "-llog", + "-lm", + "-z defs", + "-s", + "-Wl,--exclude-libs,ALL", + ], + "//conditions:default": ["-lm"], + }), + deps = select({ + "//tensorflow:android": [ + # cc:cc_ops is used to include image ops (for label_image) + # Jpg, gif, and png related code won't be included + "//tensorflow/cc:cc_ops", + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/cc:cc_ops", + "//tensorflow/core:framework_internal", + "//tensorflow/core:tensorflow", + ], + }), ) filegroup( diff --git a/tensorflow/examples/label_image/main.cc b/tensorflow/examples/label_image/main.cc index 3351109f454..90454bd7ac1 100644 --- a/tensorflow/examples/label_image/main.cc +++ b/tensorflow/examples/label_image/main.cc @@ -30,6 +30,9 @@ limitations under the License. // the top of the main() function. // // The googlenet_graph.pb file included by default is created from Inception. +// +// Note that, for GIF inputs, to reuse existing code, only single-frame ones +// are supported. #include <fstream> #include <utility> @@ -103,7 +106,12 @@ Status ReadTensorFromImageFile(const string& file_name, const int input_height, image_reader = DecodePng(root.WithOpName("png_reader"), file_reader, DecodePng::Channels(wanted_channels)); } else if (tensorflow::StringPiece(file_name).ends_with(".gif")) { - image_reader = DecodeGif(root.WithOpName("gif_reader"), file_reader); + // gif decoder returns 4-D tensor, remove the first dim + image_reader = + Squeeze(root.WithOpName("squeeze_first_dim"), + DecodeGif(root.WithOpName("gif_reader"), file_reader)); + } else if (tensorflow::StringPiece(file_name).ends_with(".bmp")) { + image_reader = DecodeBmp(root.WithOpName("bmp_reader"), file_reader); } else { // Assume if it's neither a PNG nor a GIF then it must be a JPEG. image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader, diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java index 42d7f484644..c08fa9b1457 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java @@ -20,7 +20,7 @@ package org.tensorflow; * * <p>Instances of a Graph are thread-safe. * - * <p><b>WARNING:</b> Resources consumed by the Graph object msut be explicitly freed by invoking + * <p><b>WARNING:</b> Resources consumed by the Graph object must be explicitly freed by invoking * the {@link #close()} method then the Graph object is no longer needed. */ public final class Graph implements AutoCloseable { diff --git a/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java b/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java index cd59cf504a7..38ffa2a8e19 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java +++ b/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java @@ -73,6 +73,29 @@ public final class OperationBuilder { return this; } + /** + * Ensure that the operation does not execute before the control operation does. + * + * <p>A control input is an Operation that must be executed before running the operation currently + * being built. + * + * <p>For example, an Assert operation may be added as a control input for this operation. The + * Assert now behaves as a pre-condition that will always verify itself before running the + * operation. + * + * @param control operation that must be executed before running this operation. + * @return the OperationBuilder instance for chaining. + */ + public OperationBuilder addControlInput(Operation control) { + Graph.Reference r = graph.ref(); + try { + addControlInput(unsafeNativeHandle, control.getUnsafeNativeHandle()); + } finally { + r.close(); + } + return this; + } + public OperationBuilder addInputList(Output[] inputs) { Graph.Reference r = graph.ref(); try { @@ -244,6 +267,8 @@ public final class OperationBuilder { private static native void addInputList(long handle, long[] opHandles, int[] indices); + private static native void addControlInput(long handle, long opHandle); + private static native void setDevice(long handle, String device); // The names of all the setAttr* family functions below correspond to the C library types, not the diff --git a/tensorflow/java/src/main/native/operation_builder_jni.cc b/tensorflow/java/src/main/native/operation_builder_jni.cc index 5724c54f911..4c54eecd9b5 100644 --- a/tensorflow/java/src/main/native/operation_builder_jni.cc +++ b/tensorflow/java/src/main/native/operation_builder_jni.cc @@ -115,6 +115,20 @@ JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_addInputList( TF_AddInputList(d, o.get(), n); } +JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_addControlInput( + JNIEnv* env, jclass clazz, jlong handle, jlong op_handle) { + if (op_handle == 0) { + throwException(env, kIllegalStateException, + "control input is not valid, " + "perhaps the Graph containing it has been closed()?"); + return; + } + TF_Operation* control = reinterpret_cast<TF_Operation*>(op_handle); + TF_OperationDescription* d = requireHandle(env, handle); + if (d == nullptr) return; + TF_AddControlInput(d, control); +} + JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setDevice( JNIEnv* env, jclass clazz, jlong handle, jstring device) { TF_OperationDescription* d = requireHandle(env, handle); diff --git a/tensorflow/java/src/main/native/operation_builder_jni.h b/tensorflow/java/src/main/native/operation_builder_jni.h index ae953c0fd63..9b64c328203 100644 --- a/tensorflow/java/src/main/native/operation_builder_jni.h +++ b/tensorflow/java/src/main/native/operation_builder_jni.h @@ -55,6 +55,14 @@ JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_addInput( JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_addInputList( JNIEnv *, jclass, jlong, jlongArray, jintArray); +/* + * Class: org_tensorflow_OperationBuilder + * Method: addControlInput + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_addControlInput( + JNIEnv *, jclass, jlong, jlong); + /* * Class: org_tensorflow_OperationBuilder * Method: setDevice diff --git a/tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java b/tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java index 9dc400a3d3a..b3bc3aaef9c 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java @@ -149,6 +149,33 @@ public class OperationBuilderTest { } } + @Test + public void addControlInput() { + try (Graph g = new Graph(); + Session s = new Session(g); + Tensor yes = Tensor.create(true); + Tensor no = Tensor.create(false)) { + Output placeholder = TestUtil.placeholder(g, "boolean", DataType.BOOL); + Operation check = + g.opBuilder("Assert", "assert") + .addInput(placeholder) + .addInputList(new Output[] {placeholder}) + .build(); + Operation noop = g.opBuilder("NoOp", "noop").addControlInput(check).build(); + + // No problems when the Assert check succeeds + s.runner().feed(placeholder, yes).addTarget(noop).run(); + + // Exception thrown by the execution of the Assert node + try { + s.runner().feed(placeholder, no).addTarget(noop).run(); + fail("Did not run control operation."); + } catch (IllegalArgumentException e) { + // expected + } + } + } + private static boolean hasNode(Graph g, String name) { return g.operation(name) != null; } diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index c97efc127af..33ffe3d81ed 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -739,7 +739,7 @@ def _model_fn_args(fn): def _verify_model_fn_args(model_fn, params): """Verifies model fn arguments.""" - args = _model_fn_args(model_fn) + args = set(_model_fn_args(model_fn)) if 'features' not in args: raise ValueError('model_fn (%s) must include features argument.' % model_fn) if 'labels' not in args: @@ -752,7 +752,10 @@ def _verify_model_fn_args(model_fn, params): logging.warning('Estimator\'s model_fn (%s) includes params ' 'argument, but params are not passed to Estimator.', model_fn) - non_valid_args = list(set(args) - _VALID_MODEL_FN_ARGS) + if tf_inspect.ismethod(model_fn): + if 'self' in args: + args.remove('self') + non_valid_args = list(args - _VALID_MODEL_FN_ARGS) if non_valid_args: raise ValueError('model_fn (%s) has following not expected args: %s' % (model_fn, non_valid_args)) @@ -814,13 +817,20 @@ def _write_dict_to_summary(output_dir, for key in dictionary: if dictionary[key] is None: continue + if key == 'global_step': + continue value = summary_proto.value.add() value.tag = key if (isinstance(dictionary[key], np.float32) or isinstance(dictionary[key], float)): value.simple_value = float(dictionary[key]) + elif (isinstance(dictionary[key], np.int64) or + isinstance(dictionary[key], np.int32) or + isinstance(dictionary[key], int)): + value.simple_value = int(dictionary[key]) else: - logging.warn('Skipping summary for %s, must be a float or np.float32.', - key) + logging.warn( + 'Skipping summary for %s, must be a float, np.float32, np.int64, np.int32 or int.', + key) summary_writer.add_summary(summary_proto, current_global_step) summary_writer.flush() diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 5c8c3a26267..b25c3ba93ab 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -256,6 +256,18 @@ class EstimatorConstructorTest(test.TestCase): features, labels, 'something') estimator.Estimator(model_fn=new_model_fn) + def test_if_model_fn_is_a_member_function_of_a_class(self): + + class ModelFnClass(object): + + def __init__(self): + estimator.Estimator(model_fn=self.model_fn) + + def model_fn(self, features, labels, mode): + _, _, _ = features, labels, mode + + ModelFnClass() + def dummy_input_fn(): return ({'x': constant_op.constant([[1], [1]])}, diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py index 73c810711f4..d72700ea2c7 100644 --- a/tensorflow/python/framework/tensor_shape.py +++ b/tensorflow/python/framework/tensor_shape.py @@ -66,6 +66,9 @@ class Dimension(object): def __int__(self): return self._value + def __long__(self): + return self._value + def __index__(self): # Allow use in Python 3 range return self._value diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index 1e445591b14..10811100010 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -368,7 +368,9 @@ def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False): np_dt = dtype.as_numpy_dtype else: np_dt = None - if np.prod(shape) == 0: + # If shape is None, numpy.prod returns None when dtype is not set, but raises + # exception when dtype is set to np.int64 + if shape is not None and np.prod(shape, dtype=np.int64) == 0: nparray = np.empty(shape, dtype=np_dt) else: _AssertCompatible(values, dtype) @@ -414,7 +416,7 @@ def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False): shape_size = nparray.size else: shape = [int(dim) for dim in shape] - shape_size = np.prod(shape) + shape_size = np.prod(shape, dtype=np.int64) is_same_size = shape_size == nparray.size if verify_shape: @@ -491,7 +493,7 @@ def MakeNdarray(tensor): """ shape = [d.size for d in tensor.tensor_shape.dim] - num_elements = np.prod(shape) + num_elements = np.prod(shape, dtype=np.int64) tensor_dtype = dtypes.as_dtype(tensor.dtype) dtype = tensor_dtype.as_numpy_dtype diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 349787d135c..b83937ab264 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -214,6 +214,19 @@ tf_py_test( ], ) +tf_py_test( + name = "decode_bmp_op_test", + size = "small", + srcs = ["decode_bmp_op_test.py"], + additional_deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:image_ops", + "//tensorflow/python:nn_grad", + ], +) + tf_py_test( name = "decode_image_op_test", size = "small", diff --git a/tensorflow/python/kernel_tests/decode_bmp_op_test.py b/tensorflow/python/kernel_tests/decode_bmp_op_test.py new file mode 100644 index 00000000000..e7a8ac3af6c --- /dev/null +++ b/tensorflow/python/kernel_tests/decode_bmp_op_test.py @@ -0,0 +1,116 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for DecodeBmpOp.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import image_ops +from tensorflow.python.platform import test + + +class DecodeBmpOpTest(test.TestCase): + + def testex1(self): + img_bytes = [[[0, 0, 255], [0, 255, 0]], [[255, 0, 0], [255, 255, 255]]] + # Encoded BMP bytes from Wikipedia + encoded_bytes = [ + 0x42, + 0x40, + 0x46, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0x36, + 0, + 0, + 0, + 0x28, + 0, + 0, + 0, + 0x2, + 0, + 0, + 0, + 0x2, + 0, + 0, + 0, + 0x1, + 0, + 0x18, + 0, + 0, + 0, + 0, + 0, + 0x10, + 0, + 0, + 0, + 0x13, + 0xb, + 0, + 0, + 0x13, + 0xb, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0xff, + 0xff, + 0xff, + 0xff, + 0, + 0, + 0xff, + 0, + 0, + 0, + 0xff, + 0, + 0, + 0, + ] + + byte_string = bytes(bytearray(encoded_bytes)) + img_in = constant_op.constant(byte_string, dtype=dtypes.string) + decode = array_ops.squeeze(image_ops.decode_bmp(img_in)) + + with self.test_session(): + decoded = decode.eval() + self.assertAllEqual(decoded, img_bytes) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/kernel_tests/decode_image_op_test.py b/tensorflow/python/kernel_tests/decode_image_op_test.py index b457b5cc866..58280432d63 100644 --- a/tensorflow/python/kernel_tests/decode_image_op_test.py +++ b/tensorflow/python/kernel_tests/decode_image_op_test.py @@ -33,6 +33,17 @@ prefix_path = "tensorflow/core/lib" class DecodeImageOpTest(test.TestCase): + def testBmp(self): + # Read a real bmp and verify shape + path = os.path.join(prefix_path, "bmp", "testdata", "lena.bmp") + with self.test_session(use_gpu=True) as sess: + bmp0 = io_ops.read_file(path) + image0 = image_ops.decode_image(bmp0) + image1 = image_ops.decode_bmp(bmp0) + bmp0, image0, image1 = sess.run([bmp0, image0, image1]) + self.assertEqual(len(bmp0), 4194) + self.assertAllEqual(image0, image1) + def testGif(self): # Read some real GIFs path = os.path.join(prefix_path, "gif", "testdata", "scan.gif") diff --git a/tensorflow/python/kernel_tests/fft_ops_test.py b/tensorflow/python/kernel_tests/fft_ops_test.py index f6699649489..84928bd2e1f 100644 --- a/tensorflow/python/kernel_tests/fft_ops_test.py +++ b/tensorflow/python/kernel_tests/fft_ops_test.py @@ -39,30 +39,26 @@ class BaseFFTOpsTest(test.TestCase): self._CompareBackward(x, rank, fft_length, use_placeholder) def _CompareForward(self, x, rank, fft_length=None, use_placeholder=False): - if test.is_gpu_available(cuda_only=True): - x_np = self._npFFT(x, rank, fft_length) - if use_placeholder: - x_ph = array_ops.placeholder(dtype=dtypes.as_dtype(x.dtype)) - x_tf = self._tfFFT(x_ph, rank, fft_length, use_gpu=True, - feed_dict={x_ph: x}) - else: - x_tf = self._tfFFT(x, rank, fft_length, use_gpu=True) + x_np = self._npFFT(x, rank, fft_length) + if use_placeholder: + x_ph = array_ops.placeholder(dtype=dtypes.as_dtype(x.dtype)) + x_tf = self._tfFFT( + x_ph, rank, fft_length, use_gpu=True, feed_dict={x_ph: x}) + else: + x_tf = self._tfFFT(x, rank, fft_length, use_gpu=True) - # GPU/Forward - self.assertAllClose(x_np, x_tf, rtol=1e-4, atol=1e-4) + self.assertAllClose(x_np, x_tf, rtol=1e-4, atol=1e-4) def _CompareBackward(self, x, rank, fft_length=None, use_placeholder=False): - if test.is_gpu_available(cuda_only=True): - x_np = self._npIFFT(x, rank, fft_length) - if use_placeholder: - x_ph = array_ops.placeholder(dtype=dtypes.as_dtype(x.dtype)) - x_tf = self._tfIFFT(x_ph, rank, fft_length, use_gpu=True, - feed_dict={x_ph: x}) - else: - x_tf = self._tfIFFT(x, rank, fft_length, use_gpu=True) + x_np = self._npIFFT(x, rank, fft_length) + if use_placeholder: + x_ph = array_ops.placeholder(dtype=dtypes.as_dtype(x.dtype)) + x_tf = self._tfIFFT( + x_ph, rank, fft_length, use_gpu=True, feed_dict={x_ph: x}) + else: + x_tf = self._tfIFFT(x, rank, fft_length, use_gpu=True) - # GPU/Backward - self.assertAllClose(x_np, x_tf, rtol=1e-4, atol=1e-4) + self.assertAllClose(x_np, x_tf, rtol=1e-4, atol=1e-4) def _checkGradComplex(self, func, x, y, result_is_complex=True, use_gpu=False): @@ -151,12 +147,11 @@ class FFTOpsTest(BaseFFTOpsTest): raise ValueError("invalid rank") def testEmpty(self): - if test.is_gpu_available(cuda_only=True): - for rank in VALID_FFT_RANKS: - for dims in xrange(rank, rank + 3): - x = np.zeros((0,) * dims).astype(np.complex64) - self.assertEqual(x.shape, self._tfFFT(x, rank).shape) - self.assertEqual(x.shape, self._tfIFFT(x, rank).shape) + for rank in VALID_FFT_RANKS: + for dims in xrange(rank, rank + 3): + x = np.zeros((0,) * dims).astype(np.complex64) + self.assertEqual(x.shape, self._tfFFT(x, rank).shape) + self.assertEqual(x.shape, self._tfIFFT(x, rank).shape) def testBasic(self): for rank in VALID_FFT_RANKS: @@ -184,41 +179,41 @@ class FFTOpsTest(BaseFFTOpsTest): self._Compare(gen((4,) * dims), rank) def testError(self): - if test.is_gpu_available(cuda_only=True): - for rank in VALID_FFT_RANKS: - for dims in xrange(0, rank): - x = np.zeros((1,) * dims).astype(np.complex64) - with self.assertRaisesWithPredicateMatch( - ValueError, "Shape must be .*rank {}.*".format(rank)): - self._tfFFT(x, rank) - with self.assertRaisesWithPredicateMatch( - ValueError, "Shape must be .*rank {}.*".format(rank)): - self._tfIFFT(x, rank) + for rank in VALID_FFT_RANKS: + for dims in xrange(0, rank): + x = np.zeros((1,) * dims).astype(np.complex64) + with self.assertRaisesWithPredicateMatch( + ValueError, "Shape must be .*rank {}.*".format(rank)): + self._tfFFT(x, rank) + with self.assertRaisesWithPredicateMatch( + ValueError, "Shape must be .*rank {}.*".format(rank)): + self._tfIFFT(x, rank) def testGrad_Simple(self): - if test.is_gpu_available(cuda_only=True): - for rank in VALID_FFT_RANKS: - for dims in xrange(rank, rank + 2): - re = np.ones(shape=(4,) * dims, dtype=np.float32) / 10.0 - im = np.zeros(shape=(4,) * dims, dtype=np.float32) - self._checkGradComplex(self._tfFFTForRank(rank), re, im, use_gpu=True) - self._checkGradComplex( - self._tfIFFTForRank(rank), re, im, use_gpu=True) + for rank in VALID_FFT_RANKS: + for dims in xrange(rank, rank + 2): + re = np.ones(shape=(4,) * dims, dtype=np.float32) / 10.0 + im = np.zeros(shape=(4,) * dims, dtype=np.float32) + self._checkGradComplex(self._tfFFTForRank(rank), re, im, use_gpu=True) + self._checkGradComplex(self._tfIFFTForRank(rank), re, im, use_gpu=True) def testGrad_Random(self): - if test.is_gpu_available(cuda_only=True): - np.random.seed(54321) - for rank in VALID_FFT_RANKS: - for dims in xrange(rank, rank + 2): - re = np.random.rand(*((3,) * dims)).astype(np.float32) * 2 - 1 - im = np.random.rand(*((3,) * dims)).astype(np.float32) * 2 - 1 - self._checkGradComplex(self._tfFFTForRank(rank), re, im, use_gpu=True) - self._checkGradComplex( - self._tfIFFTForRank(rank), re, im, use_gpu=True) + np.random.seed(54321) + for rank in VALID_FFT_RANKS: + for dims in xrange(rank, rank + 2): + re = np.random.rand(*((3,) * dims)).astype(np.float32) * 2 - 1 + im = np.random.rand(*((3,) * dims)).astype(np.float32) * 2 - 1 + self._checkGradComplex(self._tfFFTForRank(rank), re, im, use_gpu=True) + self._checkGradComplex(self._tfIFFTForRank(rank), re, im, use_gpu=True) class RFFTOpsTest(BaseFFTOpsTest): + def _CompareBackward(self, x, rank, fft_length=None, use_placeholder=False): + if test.is_gpu_available(cuda_only=True): + super(RFFTOpsTest, self)._CompareBackward(x, rank, fft_length, + use_placeholder) + def _tfFFT(self, x, rank, fft_length=None, use_gpu=False, feed_dict=None): with self.test_session(use_gpu=use_gpu): return self._tfFFTForRank(rank)(x, fft_length).eval(feed_dict=feed_dict) @@ -268,12 +263,12 @@ class RFFTOpsTest(BaseFFTOpsTest): raise ValueError("invalid rank") def testEmpty(self): - if test.is_gpu_available(cuda_only=True): - for rank in VALID_FFT_RANKS: - for dims in xrange(rank, rank + 3): - x = np.zeros((0,) * dims).astype(np.float32) - self.assertEqual(x.shape, self._tfFFT(x, rank).shape) - x = np.zeros((0,) * dims).astype(np.complex64) + for rank in VALID_FFT_RANKS: + for dims in xrange(rank, rank + 3): + x = np.zeros((0,) * dims).astype(np.float32) + self.assertEqual(x.shape, self._tfFFT(x, rank).shape) + x = np.zeros((0,) * dims).astype(np.complex64) + if test.is_gpu_available(cuda_only=True): self.assertEqual(x.shape, self._tfIFFT(x, rank).shape) def testBasic(self): @@ -327,36 +322,35 @@ class RFFTOpsTest(BaseFFTOpsTest): self._CompareBackward(gen_complex(complex_dims), rank, (size,) * rank) def testError(self): - if test.is_gpu_available(cuda_only=True): - for rank in VALID_FFT_RANKS: - for dims in xrange(0, rank): - x = np.zeros((1,) * dims).astype(np.complex64) - with self.assertRaisesWithPredicateMatch( - ValueError, "Shape must be .*rank {}.*".format(rank)): - self._tfFFT(x, rank) - with self.assertRaisesWithPredicateMatch( - ValueError, "Shape must be .*rank {}.*".format(rank)): - self._tfIFFT(x, rank) - for dims in xrange(rank, rank + 2): - x = np.zeros((1,) * rank) + for rank in VALID_FFT_RANKS: + for dims in xrange(0, rank): + x = np.zeros((1,) * dims).astype(np.complex64) + with self.assertRaisesWithPredicateMatch( + ValueError, "Shape must be .*rank {}.*".format(rank)): + self._tfFFT(x, rank) + with self.assertRaisesWithPredicateMatch( + ValueError, "Shape must be .*rank {}.*".format(rank)): + self._tfIFFT(x, rank) + for dims in xrange(rank, rank + 2): + x = np.zeros((1,) * rank) - # Test non-rank-1 fft_length produces an error. - fft_length = np.zeros((1, 1)).astype(np.int32) - with self.assertRaisesWithPredicateMatch(ValueError, - "Shape must be .*rank 1"): - self._tfFFT(x, rank, fft_length) - with self.assertRaisesWithPredicateMatch(ValueError, - "Shape must be .*rank 1"): - self._tfIFFT(x, rank, fft_length) + # Test non-rank-1 fft_length produces an error. + fft_length = np.zeros((1, 1)).astype(np.int32) + with self.assertRaisesWithPredicateMatch(ValueError, + "Shape must be .*rank 1"): + self._tfFFT(x, rank, fft_length) + with self.assertRaisesWithPredicateMatch(ValueError, + "Shape must be .*rank 1"): + self._tfIFFT(x, rank, fft_length) - # Test wrong fft_length length. - fft_length = np.zeros((rank + 1,)).astype(np.int32) - with self.assertRaisesWithPredicateMatch( - ValueError, "Dimension must be .*but is {}.*".format(rank + 1)): - self._tfFFT(x, rank, fft_length) - with self.assertRaisesWithPredicateMatch( - ValueError, "Dimension must be .*but is {}.*".format(rank + 1)): - self._tfIFFT(x, rank, fft_length) + # Test wrong fft_length length. + fft_length = np.zeros((rank + 1,)).astype(np.int32) + with self.assertRaisesWithPredicateMatch( + ValueError, "Dimension must be .*but is {}.*".format(rank + 1)): + self._tfFFT(x, rank, fft_length) + with self.assertRaisesWithPredicateMatch( + ValueError, "Dimension must be .*but is {}.*".format(rank + 1)): + self._tfIFFT(x, rank, fft_length) def testGrad_Simple(self): if test.is_gpu_available(cuda_only=True): diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index 8594d811e89..5c6d309e6c7 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -354,9 +354,15 @@ def _PreventGradientGrad(op, _): def _GatherGrad(op, grad): """Gradient for Gather op.""" # params can be large, so colocate the shape calculation with it. + # + # params can be very large for sparse model, array_ops.shape raises + # exception on the Windows platform when any dimension is larger than + # int32. params_shape is not used in optimizer apply_sparse gradients, + # so it's fine to convert it back to int32 regardless of truncation. params = op.inputs[0] with ops.colocate_with(params): - params_shape = array_ops.shape(params) + params_shape = array_ops.shape(params, out_type=ops.dtypes.int64) + params_shape = math_ops.to_int32(params_shape) # Build appropriately shaped IndexedSlices indices = op.inputs[1] diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py index c29ae26f04e..75c67dcb3c2 100644 --- a/tensorflow/python/ops/image_ops.py +++ b/tensorflow/python/ops/image_ops.py @@ -18,6 +18,7 @@ See the @{$python/image} guide. +@@decode_bmp @@decode_gif @@decode_jpeg @@encode_jpeg diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 3e140ce0478..b16c1863ddb 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -1312,16 +1312,18 @@ def adjust_saturation(image, saturation_factor, name=None): def decode_image(contents, channels=None, name=None): - """Convenience function for `decode_gif`, `decode_jpeg`, and `decode_png`. + """Convenience function for `decode_bmp`, `decode_gif`, `decode_jpeg`, + and `decode_png`. - Detects whether an image is a GIF, JPEG, or PNG, and performs the appropriate - operation to convert the input bytes `string` into a `Tensor` of type `uint8`. + Detects whether an image is a BMP, GIF, JPEG, or PNG, and performs the + appropriate operation to convert the input bytes `string` into a `Tensor` of + type `uint8`. Note: `decode_gif` returns a 4-D array `[num_frames, height, width, 3]`, as - opposed to `decode_jpeg` and `decode_png`, which return 3-D arrays - `[height, width, num_channels]`. Make sure to take this into account when - constructing your graph if you are intermixing GIF files with JPEG and/or PNG - files. + opposed to `decode_bmp`, `decode_jpeg` and `decode_png`, which return 3-D + arrays `[height, width, num_channels]`. Make sure to take this into account + when constructing your graph if you are intermixing GIF files with BMP, JPEG, + and/or PNG files. Args: contents: 0-D `string`. The encoded image bytes. @@ -1331,8 +1333,8 @@ def decode_image(contents, channels=None, name=None): Returns: `Tensor` with type `uint8` with shape `[height, width, num_channels]` for - JPEG and PNG images and shape `[num_frames, height, width, 3]` for GIF - images. + BMP, JPEG, and PNG images and shape `[num_frames, height, width, 3]` for + GIF images. Raises: ValueError: On incorrect number of channels. @@ -1342,12 +1344,21 @@ def decode_image(contents, channels=None, name=None): raise ValueError('channels must be in (None, 0, 1, 3, 4)') substr = string_ops.substr(contents, 0, 3) - def _gif(): + def _bmp(): """Decodes a GIF image.""" - # Create assert op to check that bytes are GIF decodable - is_gif = math_ops.equal(substr, b'\x47\x49\x46', name='is_gif') - decode_msg = 'Unable to decode bytes as JPEG, PNG, or GIF' - assert_decode = control_flow_ops.Assert(is_gif, [decode_msg]) + signature = string_ops.substr(contents, 0, 2) + # Create assert op to check that bytes are BMP decodable + is_bmp = math_ops.equal(signature, 'BM', name='is_bmp') + decode_msg = 'Unable to decode bytes as JPEG, PNG, GIF, or BMP' + assert_decode = control_flow_ops.Assert(is_bmp, [decode_msg]) + bmp_channels = 0 if channels is None else channels + good_channels = math_ops.not_equal(bmp_channels, 1, name='check_channels') + channels_msg = 'Channels must be in (None, 0, 3) when decoding BMP images' + assert_channels = control_flow_ops.Assert(good_channels, [channels_msg]) + with ops.control_dependencies([assert_decode, assert_channels]): + return gen_image_ops.decode_bmp(contents) + + def _gif(): # Create assert to make sure that channels is not set to 1 # Already checked above that channels is in (None, 0, 1, 3) @@ -1358,9 +1369,14 @@ def decode_image(contents, channels=None, name=None): ) channels_msg = 'Channels must be in (None, 0, 3) when decoding GIF images' assert_channels = control_flow_ops.Assert(good_channels, [channels_msg]) - with ops.control_dependencies([assert_decode, assert_channels]): + with ops.control_dependencies([assert_channels]): return gen_image_ops.decode_gif(contents) + def check_gif(): + # Create assert op to check that bytes are GIF decodable + is_gif = math_ops.equal(substr, b'\x47\x49\x46', name='is_gif') + return control_flow_ops.cond(is_gif, _gif, _bmp, name='cond_gif') + def _png(): """Decodes a PNG image.""" return gen_image_ops.decode_png(contents, channels) @@ -1368,7 +1384,7 @@ def decode_image(contents, channels=None, name=None): def check_png(): """Checks if an image is PNG.""" is_png = math_ops.equal(substr, b'\211PN', name='is_png') - return control_flow_ops.cond(is_png, _png, _gif, name='cond_png') + return control_flow_ops.cond(is_png, _png, check_gif, name='cond_png') def _jpeg(): """Decodes a jpeg image.""" diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index c06c07ccd87..e4eaeff67ad 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -1324,7 +1324,7 @@ def crelu(features, name=None): Concatenates a ReLU which selects only the positive part of the activation with a ReLU which selects only the *negative* part of the activation. Note that as a result this non-linearity doubles the depth of the activations. - Source: https://arxiv.org/abs/1603.05201 + Source: [Understanding and Improving Convolutional Neural Networks via Concatenated Rectified Linear Units. W. Shang, et al.](https://arxiv.org/abs/1603.05201) Args: features: A `Tensor` with type `float`, `double`, `int32`, `int64`, `uint8`, @@ -1342,6 +1342,7 @@ def crelu(features, name=None): def relu6(features, name=None): """Computes Rectified Linear 6: `min(max(features, 0), 6)`. + Source: [Convolutional Deep Belief Networks on CIFAR-10. A. Krizhevsky](http://www.cs.utoronto.ca/~kriz/conv-cifar10-aug2010.pdf) Args: features: A `Tensor` with type `float`, `double`, `int32`, `int64`, `uint8`, diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index 7543c3c111b..3c1ab3248b6 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -1764,7 +1764,7 @@ def sparse_transpose(sp_input, perm=None, name=None): Raises: TypeError: If `sp_input` is not a `SparseTensor`. """ - with ops.op_scope([sp_input], name, "SparseTranspose") as name: + with ops.name_scope(name, "SparseTranspose", [sp_input]) as name: if perm is None: rank = array_ops.rank(sp_input) perm = (rank - 1) - math_ops.range(0, rank, 1) diff --git a/tensorflow/python/training/momentum.py b/tensorflow/python/training/momentum.py index e42f3b639ca..ffd7c12c427 100644 --- a/tensorflow/python/training/momentum.py +++ b/tensorflow/python/training/momentum.py @@ -28,8 +28,11 @@ class MomentumOptimizer(optimizer.Optimizer): """Optimizer that implements the Momentum algorithm. Computes (if `use_nesterov = False`): - accumulation = momentum * accumulation + gradient - variable -= learning_rate * accumulation + + ``` + accumulation = momentum * accumulation + gradient + variable -= learning_rate * accumulation + ``` Note that in the dense version of this algorithm, `accumulation` is updated and applied regardless of a gradient's value, whereas the sparse version (when diff --git a/tensorflow/tools/api/golden/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.image.pbtxt index 6002f36bacb..8f7790f2996 100644 --- a/tensorflow/tools/api/golden/tensorflow.image.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.image.pbtxt @@ -40,6 +40,10 @@ tf_module { name: "crop_to_bounding_box" argspec: "args=[\'image\', \'offset_height\', \'offset_width\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "decode_bmp" + argspec: "args=[\'contents\', \'channels\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } member_method { name: "decode_gif" argspec: "args=[\'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel index a0e0f88d9df..c801ceff938 100644 --- a/tensorflow/tools/docker/Dockerfile.devel +++ b/tensorflow/tools/docker/Dockerfile.devel @@ -17,6 +17,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ unzip \ zip \ zlib1g-dev \ + openjdk-8-jdk \ + openjdk-8-jre-headless \ && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* @@ -46,18 +48,6 @@ COPY run_jupyter.sh / # Set up Bazel. -# We need to add a custom PPA to pick up JDK8, since trusty doesn't -# have an openjdk8 backport. openjdk-r is maintained by a reliable contributor: -# Matthias Klose (https://launchpad.net/~doko). It will do until -# we either update the base image beyond 14.04 or openjdk-8 is -# finally backported to trusty; see e.g. -# https://bugs.launchpad.net/trusty-backports/+bug/1368094 -RUN add-apt-repository -y ppa:openjdk-r/ppa && \ - apt-get update && \ - apt-get install -y --no-install-recommends openjdk-8-jdk openjdk-8-jre-headless && \ - apt-get clean && \ - rm -rf /var/lib/apt/lists/* - # Running bazel inside a `docker build` command causes trouble, cf: # https://github.com/bazelbuild/bazel/issues/134 # The easiest solution is to set up a bazelrc file forcing --batch. diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu index 743c05ef887..24350c507e7 100644 --- a/tensorflow/tools/docker/Dockerfile.devel-gpu +++ b/tensorflow/tools/docker/Dockerfile.devel-gpu @@ -17,6 +17,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ unzip \ zip \ zlib1g-dev \ + openjdk-8-jdk \ + openjdk-8-jre-headless \ && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* @@ -46,18 +48,6 @@ COPY run_jupyter.sh / # Set up Bazel. -# We need to add a custom PPA to pick up JDK8, since trusty doesn't -# have an openjdk8 backport. openjdk-r is maintained by a reliable contributor: -# Matthias Klose (https://launchpad.net/~doko). It will do until -# we either update the base image beyond 14.04 or openjdk-8 is -# finally backported to trusty; see e.g. -# https://bugs.launchpad.net/trusty-backports/+bug/1368094 -RUN add-apt-repository -y ppa:openjdk-r/ppa && \ - apt-get update && \ - apt-get install -y --no-install-recommends openjdk-8-jdk openjdk-8-jre-headless && \ - apt-get clean && \ - rm -rf /var/lib/apt/lists/* - # Running bazel inside a `docker build` command causes trouble, cf: # https://github.com/bazelbuild/bazel/issues/134 # The easiest solution is to set up a bazelrc file forcing --batch. diff --git a/tensorflow/tools/graph_transforms/README.md b/tensorflow/tools/graph_transforms/README.md index 9bb438fa719..7e8c51efe6a 100644 --- a/tensorflow/tools/graph_transforms/README.md +++ b/tensorflow/tools/graph_transforms/README.md @@ -215,7 +215,7 @@ bazel-bin/tensorflow/tools/graph_transforms/transform_graph \ --out_graph=optimized_inception_graph.pb \ --inputs='Mul' \ --outputs='softmax' \ ---transforms='\ +--transforms=' strip_unused_nodes(type=float, shape="1,299,299,3") fold_constants(ignore_errors=true) fold_batch_norms @@ -431,12 +431,11 @@ graph: ```bash bazel build tensorflow/tools/graph_transforms:transform_graph bazel-bin/tensorflow/tools/graph_transforms/transform_graph \ ---logtostderr \ --in_graph=/tmp/quantized_inception.pb \ --out_graph=/tmp/logged_quantized_inception.pb \ --inputs=Mul \ --outputs=softmax \ ---transforms='\ +--transforms=' insert_logging(op=RequantizationRange, show_name=true, message="__requant_min_max:")\ ' ``` @@ -450,12 +449,10 @@ log: bazel build tensorflow/examples/label_image:label_image bazel-bin/tensorflow/examples/label_image/label_image \ --image=${HOME}/Downloads/grace_hopper.jpg \ ---logtostderr \ --input_layer=Mul \ --output_layer=softmax \ --graph=/tmp/logged_quantized_inception.pb \ --labels=${HOME}/Downloads/imagenet_comp_graph_label_strings.txt \ ---logtostderr \ 2>/tmp/min_max_log_small.txt ``` diff --git a/tensorflow/tools/graph_transforms/summarize_graph_main.cc b/tensorflow/tools/graph_transforms/summarize_graph_main.cc index ca421595387..e4925780457 100644 --- a/tensorflow/tools/graph_transforms/summarize_graph_main.cc +++ b/tensorflow/tools/graph_transforms/summarize_graph_main.cc @@ -94,7 +94,6 @@ void PrintBenchmarkUsage(const std::vector<const NodeDef*>& placeholders, std::cout << "bazel run tensorflow/tools/benchmark:benchmark_model --"; std::cout << " --graph=" << graph_path; std::cout << " --show_flops"; - std::cout << " --logtostderr"; std::cout << " --input_layer=" << input_layer_value; std::cout << " --input_layer_type=" << input_layer_type_value; std::cout << " --input_layer_shape=" << input_layer_shape_value; diff --git a/tensorflow/tools/proto_text/gen_proto_text_functions.cc b/tensorflow/tools/proto_text/gen_proto_text_functions.cc index 17ab542a598..ecb29a65a08 100644 --- a/tensorflow/tools/proto_text/gen_proto_text_functions.cc +++ b/tensorflow/tools/proto_text/gen_proto_text_functions.cc @@ -130,6 +130,7 @@ int MainImpl(int argc, char** argv) { const string path = output_root + "/" + proto_path_no_suffix + suffix; FILE* f = fopen(path.c_str(), "w"); + if (f == nullptr) return -1; if (fwrite(data.c_str(), 1, data.size(), f) != data.size()) { return -1; } diff --git a/tensorflow/tools/quantization/quantize_graph_test.py b/tensorflow/tools/quantization/quantize_graph_test.py index 3b458dc6aa3..df71840b64d 100644 --- a/tensorflow/tools/quantization/quantize_graph_test.py +++ b/tensorflow/tools/quantization/quantize_graph_test.py @@ -688,7 +688,7 @@ class QuantizeGraphTest(test.TestCase): def test_quantized_input_range_bias_add(self): input_shape = [1, 1, 2, 6] - input_n = quantize_graph.create_node("PlaceholderV2", "input", []) + input_n = quantize_graph.create_node("Placeholder", "input", []) quantize_graph.set_attr_dtype(input_n, "dtype", dtypes.float32) quantize_graph.set_attr_shape(input_n, "shape", input_shape) offset_n = quantize_graph.create_constant_node( @@ -713,7 +713,7 @@ class QuantizeGraphTest(test.TestCase): shapes = [[3, 2], [2, 4]] inputs = [] for i, shape in enumerate(shapes): - node = quantize_graph.create_node("PlaceholderV2", "input_%s" % i, []) + node = quantize_graph.create_node("Placeholder", "input_%s" % i, []) quantize_graph.set_attr_dtype(node, "dtype", dtypes.float32) quantize_graph.set_attr_shape(node, "shape", shape) inputs.append(node) diff --git a/third_party/grpc.BUILD b/third_party/grpc.BUILD index 1699f6a854a..b79259618f2 100644 --- a/third_party/grpc.BUILD +++ b/third_party/grpc.BUILD @@ -177,8 +177,6 @@ cc_library( "include", ], linkopts = ["-lpthread"], - deps = [ - ], ) cc_library(