From 497606904be87f7a4078ed7ee0784afaa094b258 Mon Sep 17 00:00:00 2001 From: Geoffrey Irving Date: Wed, 24 Feb 2016 14:46:43 -0800 Subject: [PATCH] Fix build issue with safety fix to gather and scatter Change: 115495726 --- tensorflow/core/kernels/BUILD | 12 ++ tensorflow/core/kernels/bounds_check.h | 38 +++++ tensorflow/core/kernels/gather_op.cc | 143 +++++++++--------- tensorflow/core/kernels/scatter_op.cc | 97 +++++++----- tensorflow/core/kernels/scatter_op.h | 9 +- tensorflow/core/kernels/scatter_op_gpu.cu.cc | 9 +- .../python/kernel_tests/gather_op_test.py | 8 + .../python/kernel_tests/scatter_ops_test.py | 4 +- 8 files changed, 201 insertions(+), 119 deletions(-) create mode 100644 tensorflow/core/kernels/bounds_check.h diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index f1fe55e1fc2..4e5073b8bee 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -30,6 +30,15 @@ cc_library( ], ) +cc_library( + name = "bounds_check", + hdrs = ["bounds_check.h"], + deps = [ + "//tensorflow/core:framework", + "//third_party/eigen3", + ], +) + tf_kernel_library( name = "concat_lib", srcs = ["concat_lib_cpu.cc"], @@ -226,6 +235,7 @@ tf_kernel_libraries( "where_op", ], deps = [ + ":bounds_check", ":concat_lib", ":fill_functor", ":ops_util", @@ -874,6 +884,7 @@ tf_kernel_libraries( ], deps = [ ":assign_op", + ":bounds_check", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:state_ops_op_lib", @@ -955,6 +966,7 @@ filegroup( "assign_op.h", "bias_op.cc", "bias_op.h", + "bounds_check.h", "cast_op.cc", "cast_op.h", "concat_lib.h", diff --git a/tensorflow/core/kernels/bounds_check.h b/tensorflow/core/kernels/bounds_check.h new file mode 100644 index 00000000000..286ef8959c9 --- /dev/null +++ b/tensorflow/core/kernels/bounds_check.h @@ -0,0 +1,38 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_UTIL_BOUNDS_CHECK_H_ +#define TENSORFLOW_UTIL_BOUNDS_CHECK_H_ + +#include + +#include "third_party/eigen3/Eigen/Core" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +// Check that 0 <= index < limit using a single comparison, assuming +// that 0 <= limit if Index is signed. Intended for use in performance +// critical contexts where 0 <= index < limit is almost always true. +template +EIGEN_ALWAYS_INLINE bool FastBoundsCheck(Index index, Index limit) { + typedef typename std::make_unsigned::type UIndex; + return TF_PREDICT_TRUE(static_cast(index) < + static_cast(limit)); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_UTIL_BOUNDS_CHECK_H_ diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc index 207337778c8..d7a4e20fbd0 100644 --- a/tensorflow/core/kernels/gather_op.cc +++ b/tensorflow/core/kernels/gather_op.cc @@ -18,36 +18,52 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/util.h" namespace tensorflow { namespace { +// Returns -1 on success or a nonnegative i s.t., indices[i] is bad. template -void HandleCopies(const Tensor& Tparams, - typename TTypes::ConstVec& Tindices, int slice_elems, - typename TTypes::Matrix Tout) { - const int N = Tindices.dimension(0); - const auto& Tparams_flat = Tparams.flat_outer_dims(); - T* Tout_base = &Tout(0, 0); - const T* Tparams_base = &Tparams_flat(0, 0); - const size_t slice_bytes = slice_elems * sizeof(T); +Index HandleCopies(const Tensor& params, + typename TTypes::ConstVec indices, Index slice_elems, + typename TTypes::Matrix out) { + const int N = indices.dimension(0); + const auto& params_flat = params.flat_outer_dims(); + const Index limit = params.dim_size(0); + T* out_base = &out(0, 0); + const T* params_base = ¶ms_flat(0, 0); if (static_slice_elems >= 0) { // Give compiler static knowledge of the number of elements/bytes CHECK_EQ(static_slice_elems, slice_elems); slice_elems = static_slice_elems; } + // Compute slice_bytes here so that static knowledge is available + const size_t slice_bytes = slice_elems * sizeof(T); for (int i = 0; i < N; i++) { - int j = i + 1; + const int j = i + 1; if (j < N) { - port::prefetch(&Tparams_flat(Tindices(j), 0)); - port::prefetch(&Tout(j, 0)); + port::prefetch(¶ms_flat(indices(j), 0)); + port::prefetch(&out(j, 0)); + } + // Grab the index and check its validity. An earlier version of the + // code checked it and then grabbed it from memory a second time, which + // was a security risk since it could have changed in between. + const Index index = indices(i); + if (!FastBoundsCheck(index, limit)) return i; + // Copy using memcpy if possible, otherwise an Eigen loop + if (Allocator::is_simple::value) { + memcpy(out_base + i * slice_elems, params_base + index * slice_elems, + slice_bytes); + } else { + out.template chip<0>(i) = params_flat.template chip<0>(index); } - memcpy(Tout_base + i * slice_elems, - Tparams_base + Tindices(i) * slice_elems, slice_bytes); } + return -1; } } // anonymous namespace @@ -64,78 +80,67 @@ class GatherOp : public OpKernel { const DataType dt = DataTypeToEnum::v(); const DataType index_t = DataTypeToEnum::v(); OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t}, {dt})); - OP_REQUIRES_OK(c, c->GetAttr("validate_indices", &validate_indices_)); + // We used to grab the validate_indices attribute here, but now we + // always validate indices since the speed difference was only 1.5%. + // TODO(irving): Remove the validate_indices attribute once we have + // support for removing attrs in a backwards compatible way. } void Compute(OpKernelContext* c) override { - const Tensor& Tparams = c->input(0); - const Tensor& Tindices = c->input(1); + const Tensor& params = c->input(0); + const Tensor& indices = c->input(1); OP_REQUIRES( - c, TensorShapeUtils::IsVectorOrHigher(Tparams.shape()), + c, TensorShapeUtils::IsVectorOrHigher(params.shape()), errors::InvalidArgument("params must be at least 1 dimensional")); - const int64 N = Tindices.NumElements(); - const int64 first_dim_size = Tparams.dim_size(0); - // Validate all the indices are in range - auto Tindices_vec = Tindices.flat(); - if (validate_indices_) { - for (int64 i = 0; i < N; i++) { - const Index index = Tindices_vec(i); - OP_REQUIRES(c, index >= 0 && index < first_dim_size, - errors::InvalidArgument( - strings::StrCat("Index ", index, " at offset ", i, - " in Tindices is out of range"))); - } - } + // Check that we have enough index space + const int64 N_big = indices.NumElements(); + OP_REQUIRES(c, N_big <= std::numeric_limits::max(), + errors::InvalidArgument( + "indices has too many elements for int indexing: ", N_big, + " > ", std::numeric_limits::max())); + const int N = indices.NumElements(); + OP_REQUIRES( + c, params.dim_size(0) <= std::numeric_limits::max(), + errors::InvalidArgument("params.shape[0] too large for ", + DataTypeString(DataTypeToEnum::v()), + " indexing: ", params.dim_size(0), " > ", + std::numeric_limits::max())); // The result shape is indices.shape + params.shape[1:]. - TensorShape result_shape = Tindices.shape(); - for (int i = 1; i < Tparams.dims(); i++) { - result_shape.AddDim(Tparams.dim_size(i)); + TensorShape result_shape = indices.shape(); + for (int i = 1; i < params.dims(); i++) { + result_shape.AddDim(params.dim_size(i)); } - Tensor* Tout = nullptr; - OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &Tout)); - const auto& Tparams_flat = Tparams.flat_outer_dims(); + Tensor* out = nullptr; + OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out)); if (N > 0) { - auto Tindices_flat = Tindices.flat(); - auto Tout_flat = Tout->shaped({N, Tout->NumElements() / N}); - if (DataTypeCanUseMemcpy(DataTypeToEnum::v())) { - const int64 slice_size = Tout->NumElements() / N; -#define SPECIALIZE(elems) \ - do { \ - if (slice_size == elems) { \ - HandleCopies(Tparams, Tindices_flat, slice_size, \ - Tout_flat); \ - return; \ - } \ - } while (0) + auto indices_flat = indices.flat(); + auto out_flat = out->shaped({N, out->NumElements() / N}); + const int64 slice_size = out->NumElements() / N; + Index bad_i; - SPECIALIZE(10); - SPECIALIZE(20); +#define CALL(elems) \ + bad_i = HandleCopies(params, indices_flat, slice_size, \ + out_flat) -#undef SPECIALIZE + if (slice_size == 10) + CALL(10); + else if (slice_size == 20) + CALL(20); + else + CALL(-1); - HandleCopies(Tparams, Tindices_flat, slice_size, - Tout_flat); - } else { - for (int i = 0; i < N; i++) { - int j = i + 1; - if (j < N) { - port::prefetch( - &Tparams_flat(Tindices_vec(j), 0)); - port::prefetch(&Tout_flat(j, 0)); - } - // Copy last Ndim-1 dimensions of Tparams[Tindices[i]] to Tout[i] - Tout_flat.template chip<0>(i) = - Tparams_flat.template chip<0>(Tindices_vec(i)); - } - } +#undef CALL + + OP_REQUIRES( + c, bad_i < 0, + errors::InvalidArgument( + "indices", SliceDebugString(indices.shape(), bad_i), " = ", + indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")")); } } - - private: - bool validate_indices_; }; #define REGISTER_GATHER(type, index_type) \ diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc index 30fd105b5f6..518053cc0be 100644 --- a/tensorflow/core/kernels/scatter_op.cc +++ b/tensorflow/core/kernels/scatter_op.cc @@ -20,8 +20,10 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/util.h" namespace tensorflow { @@ -99,36 +101,54 @@ class ScatterUpdateOp : public OpKernel { } void DoCompute(OpKernelContext* c) { - Tensor Tparams = c->mutable_input(0, use_exclusive_lock_); - OP_REQUIRES(c, Tparams.IsInitialized(), + Tensor params = c->mutable_input(0, use_exclusive_lock_); + OP_REQUIRES(c, params.IsInitialized(), errors::FailedPrecondition("Null ref for params")); - const Tensor& Tindices = c->input(1); - const Tensor& Tupdates = c->input(2); + const Tensor& indices = c->input(1); + const Tensor& updates = c->input(2); OP_REQUIRES( - c, TensorShapeUtils::IsVectorOrHigher(Tparams.shape()), + c, TensorShapeUtils::IsVectorOrHigher(params.shape()), errors::InvalidArgument("params must be at least 1-D, got shape ", - Tparams.shape().DebugString())); + params.shape().DebugString())); OP_REQUIRES( - c, ValidShapes(Tparams, Tupdates, Tindices), + c, ValidShapes(params, updates, indices), errors::InvalidArgument( "Must have updates.shape = indices.shape + params.shape[1:], got ", - "updates.shape ", Tupdates.shape().DebugString(), - ", indices.shape ", Tindices.shape().DebugString(), - ", params.shape ", Tparams.shape().DebugString())); + "updates.shape ", updates.shape().DebugString(), ", indices.shape ", + indices.shape().DebugString(), ", params.shape ", + params.shape().DebugString())); + + // Check that we have enough index space + const int64 N_big = indices.NumElements(); + OP_REQUIRES(c, N_big <= std::numeric_limits::max(), + errors::InvalidArgument( + "indices has too many elements for ", + DataTypeString(DataTypeToEnum::v()), " indexing: ", + N_big, " > ", std::numeric_limits::max())); + const Index N = indices.NumElements(); + OP_REQUIRES( + c, params.dim_size(0) <= std::numeric_limits::max(), + errors::InvalidArgument("params.shape[0] too large for ", + DataTypeString(DataTypeToEnum::v()), + " indexing: ", params.dim_size(0), " > ", + std::numeric_limits::max())); // We always return the input ref. c->forward_ref_input_to_ref_output(0, 0); - const Index N = Tindices.NumElements(); if (N > 0) { - auto Tindices_flat = Tindices.flat(); - auto Tparams_flat = Tparams.flat_outer_dims(); - auto Tupdates_flat = - Tupdates.shaped({N, Tupdates.NumElements() / N}); + auto indices_flat = indices.flat(); + auto params_flat = params.flat_outer_dims(); + auto updates_flat = updates.shaped({N, updates.NumElements() / N}); functor::ScatterFunctor functor; - functor(c, c->template eigen_device(), - Tparams_flat, Tupdates_flat, Tindices_flat); + const Index bad_i = functor(c, c->template eigen_device(), + params_flat, updates_flat, indices_flat); + OP_REQUIRES( + c, bad_i < 0, + errors::InvalidArgument( + "indices", SliceDebugString(indices.shape(), bad_i), " = ", + indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")")); } } }; @@ -137,26 +157,23 @@ namespace functor { // Implementation of update functor for CPU. template struct ScatterFunctor { - void operator()(OpKernelContext* c, const CPUDevice& d, - typename TTypes::Matrix params, - typename TTypes::ConstMatrix updates, - typename TTypes::ConstFlat indices) { - Index N = indices.size(); - // Validate all the indices are in range - Index first_dim_size = params.dimension(0); + Index operator()(OpKernelContext* c, const CPUDevice& d, + typename TTypes::Matrix params, + typename TTypes::ConstMatrix updates, + typename TTypes::ConstFlat indices) { + const Index N = indices.size(); + const Index limit = params.dimension(0); for (Index i = 0; i < N; i++) { + // Grab the index and check its validity. An earlier version of the + // code checked it and then grabbed it from memory a second time, which + // was a security risk since it could have changed in between. const Index index = indices(i); - OP_REQUIRES(c, index >= 0 && index < first_dim_size, - errors::InvalidArgument( - strings::StrCat("Index ", index, " at offset ", i, - " in indices is out of range"))); - } - for (Index i = 0; i < N; i++) { - // Copy last Ndim-1 dimensions of Tupdates[i] to - // Tparams[Tindices[i]] - Assign::Run(params.template chip<0>(indices(i)), + if (!FastBoundsCheck(index, limit)) return i; + // Copy last Ndim-1 dimensions of updates[i] to params[index] + Assign::Run(params.template chip<0>(index), updates.template chip<0>(i)); } + return -1; } }; } // namespace functor @@ -220,13 +237,13 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_GPU); // Forward declarations of the functor specializations for GPU. namespace functor { -#define DECLARE_GPU_SPECS_OP(T, Index, op) \ - template <> \ - void ScatterFunctor::operator()( \ - OpKernelContext* c, const GPUDevice& d, \ - typename TTypes::Matrix params, \ - typename TTypes::ConstMatrix updates, \ - typename TTypes::ConstFlat indices); \ +#define DECLARE_GPU_SPECS_OP(T, Index, op) \ + template <> \ + Index ScatterFunctor::operator()( \ + OpKernelContext* c, const GPUDevice& d, \ + typename TTypes::Matrix params, \ + typename TTypes::ConstMatrix updates, \ + typename TTypes::ConstFlat indices); \ extern template struct ScatterFunctor; #define DECLARE_GPU_SPECS_INDEX(T, Index) \ diff --git a/tensorflow/core/kernels/scatter_op.h b/tensorflow/core/kernels/scatter_op.h index b7c7df97a76..f17cfef9247 100644 --- a/tensorflow/core/kernels/scatter_op.h +++ b/tensorflow/core/kernels/scatter_op.h @@ -36,10 +36,11 @@ namespace functor { // Functor used by ScatterOp to do the computations. template struct ScatterFunctor { - void operator()(OpKernelContext* c, const Device& d, - typename TTypes::Matrix params, - typename TTypes::ConstMatrix updates, - typename TTypes::ConstFlat indices); + // Returns -1 on success or a nonnegative i s.t. indices[i] is a bad index. + Index operator()(OpKernelContext* c, const Device& d, + typename TTypes::Matrix params, + typename TTypes::ConstMatrix updates, + typename TTypes::ConstFlat indices); }; } // namespace functor diff --git a/tensorflow/core/kernels/scatter_op_gpu.cu.cc b/tensorflow/core/kernels/scatter_op_gpu.cu.cc index 6ef23419aba..9543aedfde5 100644 --- a/tensorflow/core/kernels/scatter_op_gpu.cu.cc +++ b/tensorflow/core/kernels/scatter_op_gpu.cu.cc @@ -62,10 +62,10 @@ namespace functor { // Specialization for a GPU device. template struct ScatterFunctor { - void operator()(OpKernelContext* c, const GPUDevice& d, - typename TTypes::Matrix params, - typename TTypes::ConstMatrix updates, - typename TTypes::ConstFlat indices) { + Index operator()(OpKernelContext* c, const GPUDevice& d, + typename TTypes::Matrix params, + typename TTypes::ConstMatrix updates, + typename TTypes::ConstFlat indices) { // TODO: Implement indices range check. The hardest part is with returning // a value after the range check, as we do not want to do device to host // memcpy during a stream. @@ -77,6 +77,7 @@ struct ScatterFunctor { <<>>( params.data(), updates.data(), indices.data(), first_dim_size, updates_size, indices_size); + return -1; } }; diff --git a/tensorflow/python/kernel_tests/gather_op_test.py b/tensorflow/python/kernel_tests/gather_op_test.py index 59ce04681eb..5292a5bbad0 100644 --- a/tensorflow/python/kernel_tests/gather_op_test.py +++ b/tensorflow/python/kernel_tests/gather_op_test.py @@ -83,6 +83,14 @@ class GatherTest(tf.test.TestCase): gather_t = tf.gather(params, indices) self.assertEqual(None, gather_t.get_shape()) + def testBadIndices(self): + with self.test_session(): + params = [0, 1, 2] + indices = [[7]] + gather = tf.gather(params, indices) + with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"): + gather.eval() + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/kernel_tests/scatter_ops_test.py b/tensorflow/python/kernel_tests/scatter_ops_test.py index 357c8f1a983..674ce2b2d11 100644 --- a/tensorflow/python/kernel_tests/scatter_ops_test.py +++ b/tensorflow/python/kernel_tests/scatter_ops_test.py @@ -128,11 +128,11 @@ class ScatterTest(tf.test.TestCase): # Test some out of range errors. indices = np.array([-1, 0, 5]) - with self.assertRaisesOpError('indices is out of range'): + with self.assertRaisesOpError(r'indices\[0\] = -1 is not in \[0, 6\)'): op(ref, indices, updates).eval() indices = np.array([2, 0, 6]) - with self.assertRaisesOpError('indices is out of range'): + with self.assertRaisesOpError(r'indices\[2\] = 6 is not in \[0, 6\)'): op(ref, indices, updates).eval() # TODO(fpmc): Re-enable this test when gpu_pip test actually runs on a GPU.