Eliminate unneeded pylint disable
Change: 115470945
This commit is contained in:
parent 14a237beb0
commit 6b2c0012d1
@@ -23,5 +23,4 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-# pylint: disable=wildcard-import
 from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import SdcaModel
@@ -21,49 +21,33 @@ limitations under the License.
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/mem.h"
 #include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/bounds_check.h"
 #include "tensorflow/core/util/util.h"
 
 namespace tensorflow {
 
 namespace {
-// Returns -1 on success or a nonnegative i s.t., indices[i] is bad.
 template <typename T, typename Index, int static_slice_elems>
-Index HandleCopies(const Tensor& params,
-                   typename TTypes<Index>::ConstVec& indices, Index slice_elems,
-                   typename TTypes<T>::Matrix out) {
-  const int N = indices.dimension(0);
-  const auto& params_flat = params.flat_outer_dims<T>();
-  const Index limit = params.dim_size(0);
-  T* out_base = &out(0, 0);
-  const T* params_base = &params_flat(0, 0);
+void HandleCopies(const Tensor& Tparams,
+                  typename TTypes<Index>::ConstVec& Tindices, int slice_elems,
+                  typename TTypes<T>::Matrix Tout) {
+  const int N = Tindices.dimension(0);
+  const auto& Tparams_flat = Tparams.flat_outer_dims<T>();
+  T* Tout_base = &Tout(0, 0);
+  const T* Tparams_base = &Tparams_flat(0, 0);
+  const size_t slice_bytes = slice_elems * sizeof(T);
   if (static_slice_elems >= 0) {
     // Give compiler static knowledge of the number of elements/bytes
     CHECK_EQ(static_slice_elems, slice_elems);
     slice_elems = static_slice_elems;
   }
-  // Compute slice_bytes here so that static knowledge is available
-  const size_t slice_bytes = slice_elems * sizeof(T);
   for (int i = 0; i < N; i++) {
-    const int j = i + 1;
+    int j = i + 1;
     if (j < N) {
-      port::prefetch<port::PREFETCH_HINT_T0>(&params_flat(indices(j), 0));
-      port::prefetch<port::PREFETCH_HINT_T0>(&out(j, 0));
-    }
-    // Grab the index and check its validity.  An earlier version of the
-    // code checked it and then grabbed it from memory a second time, which
-    // was a security risk since it could have changed in between.
-    const Index index = indices(i);
-    if (!FastBoundsCheck(index, limit)) return i;
-    // Copy using memcpy if possible, otherwise an Eigen loop
-    if (Allocator::is_simple<T>::value) {
-      memcpy(out_base + i * slice_elems, params_base + index * slice_elems,
-             slice_bytes);
-    } else {
-      out.template chip<0>(i) = params_flat.template chip<0>(index);
+      port::prefetch<port::PREFETCH_HINT_T0>(&Tparams_flat(Tindices(j), 0));
+      port::prefetch<port::PREFETCH_HINT_T0>(&Tout(j, 0));
     }
+    memcpy(Tout_base + i * slice_elems,
+           Tparams_base + Tindices(i) * slice_elems, slice_bytes);
   }
-  return -1;
 }
 
 }  // anonymous namespace
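The static_slice_elems template parameter is what makes the memcpy path above pay off: when the slice length is one of a few common sizes, HandleCopies is instantiated with that size baked in, so the copy works on a compile-time-constant byte count. A minimal standalone sketch of the trick (illustrative names, not TensorFlow code):

#include <cstring>

// Copies n slices of `elems` T's each out of src (rows picked by idx) into
// dst. If kStaticElems >= 0 it must equal `elems`; the copy size then becomes
// a compile-time constant the compiler can lower to fixed-size moves.
template <typename T, int kStaticElems>
void CopySlices(const T* src, const int* idx, int n, int elems, T* dst) {
  if (kStaticElems >= 0) elems = kStaticElems;
  const std::size_t slice_bytes = elems * sizeof(T);
  for (int i = 0; i < n; ++i) {
    std::memcpy(dst + i * elems, src + idx[i] * elems, slice_bytes);
  }
}

// Dispatch in the spirit of the SPECIALIZE macro in the next hunk: known
// sizes get a specialized instantiation, everything else the dynamic one.
template <typename T>
void CopySlicesAnySize(const T* src, const int* idx, int n, int elems,
                       T* dst) {
  if (elems == 10)
    CopySlices<T, 10>(src, idx, n, elems, dst);
  else if (elems == 20)
    CopySlices<T, 20>(src, idx, n, elems, dst);
  else
    CopySlices<T, -1>(src, idx, n, elems, dst);
}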
@@ -80,67 +64,78 @@ class GatherOp : public OpKernel {
     const DataType dt = DataTypeToEnum<T>::v();
     const DataType index_t = DataTypeToEnum<Index>::v();
     OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t}, {dt}));
-    // We used to grab the validate_indices attribute here, but now we
-    // always validate indices since the speed difference was only 1.5%.
-    // TODO(irving): Remove the validate_indices attribute once we have
-    // support for removing attrs in a backwards compatible way.
+    OP_REQUIRES_OK(c, c->GetAttr("validate_indices", &validate_indices_));
   }
 
   void Compute(OpKernelContext* c) override {
-    const Tensor& params = c->input(0);
-    const Tensor& indices = c->input(1);
+    const Tensor& Tparams = c->input(0);
+    const Tensor& Tindices = c->input(1);
     OP_REQUIRES(
-        c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
+        c, TensorShapeUtils::IsVectorOrHigher(Tparams.shape()),
         errors::InvalidArgument("params must be at least 1 dimensional"));
+    const int64 N = Tindices.NumElements();
+    const int64 first_dim_size = Tparams.dim_size(0);
 
-    // Check that we have enough index space
-    const int64 N_big = indices.NumElements();
-    OP_REQUIRES(c, N_big <= std::numeric_limits<int>::max(),
-                errors::InvalidArgument(
-                    "indices has too many elements for int indexing: ", N_big,
-                    " > ", std::numeric_limits<int>::max()));
-    const int N = indices.NumElements();
-    OP_REQUIRES(
-        c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
-        errors::InvalidArgument("params.shape[0] too large for ",
-                                DataTypeString(DataTypeToEnum<Index>::v()),
-                                " indexing: ", params.dim_size(0), " > ",
-                                std::numeric_limits<Index>::max()));
+    // Validate all the indices are in range
+    auto Tindices_vec = Tindices.flat<Index>();
+    if (validate_indices_) {
+      for (int64 i = 0; i < N; i++) {
+        const Index index = Tindices_vec(i);
+        OP_REQUIRES(c, index >= 0 && index < first_dim_size,
+                    errors::InvalidArgument(
+                        strings::StrCat("Index ", index, " at offset ", i,
+                                        " in Tindices is out of range")));
+      }
+    }
 
     // The result shape is indices.shape + params.shape[1:].
-    TensorShape result_shape = indices.shape();
-    for (int i = 1; i < params.dims(); i++) {
-      result_shape.AddDim(params.dim_size(i));
+    TensorShape result_shape = Tindices.shape();
+    for (int i = 1; i < Tparams.dims(); i++) {
+      result_shape.AddDim(Tparams.dim_size(i));
     }
 
-    Tensor* out = nullptr;
-    OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
+    Tensor* Tout = nullptr;
+    OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &Tout));
+    const auto& Tparams_flat = Tparams.flat_outer_dims<T>();
     if (N > 0) {
-      auto indices_flat = indices.flat<Index>();
-      auto out_flat = out->shaped<T, 2>({N, out->NumElements() / N});
-      const int64 slice_size = out->NumElements() / N;
-      Index bad_i;
-
-#define CALL(elems)                                                       \
-  bad_i = HandleCopies<T, Index, elems>(params, indices_flat, slice_size, \
-                                        out_flat)
-
-      if (slice_size == 10)
-        CALL(10);
-      else if (slice_size == 20)
-        CALL(20);
-      else
-        CALL(-1);
-
-#undef CALL
-
-      OP_REQUIRES(
-          c, bad_i < 0,
-          errors::InvalidArgument(
-              "indices", SliceDebugString(indices.shape(), bad_i), " = ",
-              indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
+      auto Tindices_flat = Tindices.flat<Index>();
+      auto Tout_flat = Tout->shaped<T, 2>({N, Tout->NumElements() / N});
+      if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
+        const int64 slice_size = Tout->NumElements() / N;
+#define SPECIALIZE(elems)                                               \
+  do {                                                                  \
+    if (slice_size == elems) {                                          \
+      HandleCopies<T, Index, elems>(Tparams, Tindices_flat, slice_size, \
+                                    Tout_flat);                         \
+      return;                                                           \
+    }                                                                   \
+  } while (0)
+
+        SPECIALIZE(10);
+        SPECIALIZE(20);
+#undef SPECIALIZE
+
+        HandleCopies<T, Index, -1>(Tparams, Tindices_flat, slice_size,
+                                   Tout_flat);
+      } else {
+        for (int i = 0; i < N; i++) {
+          int j = i + 1;
+          if (j < N) {
+            port::prefetch<port::PREFETCH_HINT_T0>(
+                &Tparams_flat(Tindices_vec(j), 0));
+            port::prefetch<port::PREFETCH_HINT_T0>(&Tout_flat(j, 0));
+          }
+          // Copy last Ndim-1 dimensions of Tparams[Tindices[i]] to Tout[i]
+          Tout_flat.template chip<0>(i) =
+              Tparams_flat.template chip<0>(Tindices_vec(i));
+        }
+      }
     }
   }
 
  private:
+  bool validate_indices_;
 };
 
 #define REGISTER_GATHER(type, index_type) \
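Net effect of this hunk: bounds checking moves out of HandleCopies (where FastBoundsCheck reported a bad index through bad_i) back into an explicit up-front loop gated on the validate_indices attribute. In plain terms the op computes out[i, :] = params[indices[i], :]. A rough reference version of those semantics, as a hypothetical helper rather than the TensorFlow API:

#include <algorithm>
#include <stdexcept>
#include <vector>

// params is a flattened (first_dim x slice) buffer; the result has one row of
// `slice` elements per entry of indices.
template <typename T>
std::vector<T> GatherRef(const std::vector<T>& params, int first_dim,
                         const std::vector<int>& indices) {
  const int slice = static_cast<int>(params.size()) / first_dim;
  std::vector<T> out(indices.size() * static_cast<std::size_t>(slice));
  for (std::size_t i = 0; i < indices.size(); ++i) {
    const int index = indices[i];
    if (index < 0 || index >= first_dim)  // what validate_indices guards
      throw std::out_of_range("index out of range");
    std::copy(params.begin() + index * slice,
              params.begin() + (index + 1) * slice, out.begin() + i * slice);
  }
  return out;
}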
@@ -22,8 +22,6 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/bounds_check.h"
-#include "tensorflow/core/util/util.h"
 
 namespace tensorflow {
 
@@ -101,54 +99,36 @@ class ScatterUpdateOp : public OpKernel {
   }
 
   void DoCompute(OpKernelContext* c) {
-    Tensor params = c->mutable_input(0, use_exclusive_lock_);
-    OP_REQUIRES(c, params.IsInitialized(),
+    Tensor Tparams = c->mutable_input(0, use_exclusive_lock_);
+    OP_REQUIRES(c, Tparams.IsInitialized(),
                 errors::FailedPrecondition("Null ref for params"));
-    const Tensor& indices = c->input(1);
-    const Tensor& updates = c->input(2);
+    const Tensor& Tindices = c->input(1);
+    const Tensor& Tupdates = c->input(2);
     OP_REQUIRES(
-        c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
+        c, TensorShapeUtils::IsVectorOrHigher(Tparams.shape()),
         errors::InvalidArgument("params must be at least 1-D, got shape ",
-                                params.shape().DebugString()));
+                                Tparams.shape().DebugString()));
     OP_REQUIRES(
-        c, ValidShapes(params, updates, indices),
+        c, ValidShapes(Tparams, Tupdates, Tindices),
         errors::InvalidArgument(
             "Must have updates.shape = indices.shape + params.shape[1:], got ",
-            "updates.shape ", updates.shape().DebugString(), ", indices.shape ",
-            indices.shape().DebugString(), ", params.shape ",
-            params.shape().DebugString()));
-
-    // Check that we have enough index space
-    const int64 N_big = indices.NumElements();
-    OP_REQUIRES(c, N_big <= std::numeric_limits<Index>::max(),
-                errors::InvalidArgument(
-                    "indices has too many elements for ",
-                    DataTypeString(DataTypeToEnum<Index>::v()), " indexing: ",
-                    N_big, " > ", std::numeric_limits<Index>::max()));
-    const Index N = indices.NumElements();
-    OP_REQUIRES(
-        c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
-        errors::InvalidArgument("params.shape[0] too large for ",
-                                DataTypeString(DataTypeToEnum<Index>::v()),
-                                " indexing: ", params.dim_size(0), " > ",
-                                std::numeric_limits<Index>::max()));
+            "updates.shape ", Tupdates.shape().DebugString(),
+            ", indices.shape ", Tindices.shape().DebugString(),
+            ", params.shape ", Tparams.shape().DebugString()));
 
     // We always return the input ref.
     c->forward_ref_input_to_ref_output(0, 0);
 
+    const Index N = Tindices.NumElements();
     if (N > 0) {
-      auto indices_flat = indices.flat<Index>();
-      auto params_flat = params.flat_outer_dims<T>();
-      auto updates_flat = updates.shaped<T, 2>({N, updates.NumElements() / N});
+      auto Tindices_flat = Tindices.flat<Index>();
+      auto Tparams_flat = Tparams.flat_outer_dims<T>();
+      auto Tupdates_flat =
+          Tupdates.shaped<T, 2>({N, Tupdates.NumElements() / N});
 
       functor::ScatterFunctor<Device, T, Index, op> functor;
-      const Index bad_i = functor(c, c->template eigen_device<Device>(),
-                                  params_flat, updates_flat, indices_flat);
-      OP_REQUIRES(
-          c, bad_i < 0,
-          errors::InvalidArgument(
-              "indices", SliceDebugString(indices.shape(), bad_i), " = ",
-              indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
+      functor(c, c->template eigen_device<Device>(),
+              Tparams_flat, Tupdates_flat, Tindices_flat);
     }
   }
 };
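For orientation, scatter runs in the opposite direction from gather: row indices[i] of params receives row i of updates, with the op template parameter selecting assignment or an accumulate. A minimal sketch of the assign case on a flattened buffer (illustrative helper, not the TensorFlow API):

#include <vector>

// Reference semantics of scatter_update on a flattened (first_dim x slice)
// params buffer: row indices[i] is overwritten with row i of updates.
template <typename T>
void ScatterUpdateRef(std::vector<T>& params, int slice,
                      const std::vector<int>& indices,
                      const std::vector<T>& updates) {
  for (std::size_t i = 0; i < indices.size(); ++i) {
    for (int k = 0; k < slice; ++k) {
      params[indices[i] * slice + k] = updates[i * slice + k];
    }
  }
}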
@@ -157,23 +137,26 @@ namespace functor {
 // Implementation of update functor for CPU.
 template <typename T, typename Index, scatter_op::UpdateOp op>
 struct ScatterFunctor<CPUDevice, T, Index, op> {
-  Index operator()(OpKernelContext* c, const CPUDevice& d,
-                   typename TTypes<T>::Matrix params,
-                   typename TTypes<T>::ConstMatrix updates,
-                   typename TTypes<Index>::ConstFlat indices) {
-    const Index N = indices.size();
-    const Index limit = params.dimension(0);
+  void operator()(OpKernelContext* c, const CPUDevice& d,
+                  typename TTypes<T>::Matrix params,
+                  typename TTypes<T>::ConstMatrix updates,
+                  typename TTypes<Index>::ConstFlat indices) {
+    Index N = indices.size();
+    // Validate all the indices are in range
+    Index first_dim_size = params.dimension(0);
     for (Index i = 0; i < N; i++) {
-      // Grab the index and check its validity.  An earlier version of the
-      // code checked it and then grabbed it from memory a second time, which
-      // was a security risk since it could have changed in between.
       const Index index = indices(i);
-      if (!FastBoundsCheck(index, limit)) return i;
-      // Copy last Ndim-1 dimensions of updates[i] to params[index]
-      Assign<op>::Run(params.template chip<0>(index),
+      OP_REQUIRES(c, index >= 0 && index < first_dim_size,
+                  errors::InvalidArgument(
+                      strings::StrCat("Index ", index, " at offset ", i,
+                                      " in indices is out of range")));
+    }
+    for (Index i = 0; i < N; i++) {
+      // Copy last Ndim-1 dimensions of Tupdates[i] to
+      // Tparams[Tindices[i]]
+      Assign<op>::Run(params.template chip<0>(indices(i)),
                       updates.template chip<0>(i));
     }
-    return -1;
   }
 };
 } // namespace functor
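Note the structural change in this functor: the new version validates every index before applying any update, so a bad index rejects the whole batch with params untouched, while the deleted version checked inside the copy loop and returned the offending position, by which point earlier rows had already been written. A sketch of the all-or-nothing shape (illustrative helper, not TensorFlow code):

#include <vector>

// Validate-then-apply: the first pass does no writes, so failure leaves
// params exactly as it was; a check fused into the copy loop cannot offer
// that guarantee.
template <typename T>
bool ScatterAllOrNothing(std::vector<T>& params, int first_dim, int slice,
                         const std::vector<int>& indices,
                         const std::vector<T>& updates) {
  for (int index : indices) {
    if (index < 0 || index >= first_dim) return false;  // no writes yet
  }
  for (std::size_t i = 0; i < indices.size(); ++i) {
    for (int k = 0; k < slice; ++k) {
      params[indices[i] * slice + k] = updates[i * slice + k];
    }
  }
  return true;
}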
@@ -237,13 +220,13 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_GPU);
 // Forward declarations of the functor specializations for GPU.
 namespace functor {
 
-#define DECLARE_GPU_SPECS_OP(T, Index, op)                   \
-  template <>                                                \
-  Index ScatterFunctor<GPUDevice, T, Index, op>::operator()( \
-      OpKernelContext* c, const GPUDevice& d,                \
-      typename TTypes<T>::Matrix params,                     \
-      typename TTypes<T>::ConstMatrix updates,               \
-      typename TTypes<Index>::ConstFlat indices);            \
+#define DECLARE_GPU_SPECS_OP(T, Index, op)                  \
+  template <>                                               \
+  void ScatterFunctor<GPUDevice, T, Index, op>::operator()( \
+      OpKernelContext* c, const GPUDevice& d,               \
+      typename TTypes<T>::Matrix params,                    \
+      typename TTypes<T>::ConstMatrix updates,              \
+      typename TTypes<Index>::ConstFlat indices);           \
   extern template struct ScatterFunctor<GPUDevice, T, Index, op>;
 
 #define DECLARE_GPU_SPECS_INDEX(T, Index) \
@@ -36,11 +36,10 @@ namespace functor {
 // Functor used by ScatterOp to do the computations.
 template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
 struct ScatterFunctor {
-  // Returns -1 on success or a nonnegative i s.t. indices[i] is a bad index.
-  Index operator()(OpKernelContext* c, const Device& d,
-                   typename TTypes<T>::Matrix params,
-                   typename TTypes<T>::ConstMatrix updates,
-                   typename TTypes<Index>::ConstFlat indices);
+  void operator()(OpKernelContext* c, const Device& d,
+                  typename TTypes<T>::Matrix params,
+                  typename TTypes<T>::ConstMatrix updates,
+                  typename TTypes<Index>::ConstFlat indices);
 };
 
 } // namespace functor
@@ -62,10 +62,10 @@ namespace functor {
 // Specialization for a GPU device.
 template <typename T, typename Index, scatter_op::UpdateOp op>
 struct ScatterFunctor<GPUDevice, T, Index, op> {
-  Index operator()(OpKernelContext* c, const GPUDevice& d,
-                   typename TTypes<T>::Matrix params,
-                   typename TTypes<T>::ConstMatrix updates,
-                   typename TTypes<Index>::ConstFlat indices) {
+  void operator()(OpKernelContext* c, const GPUDevice& d,
+                  typename TTypes<T>::Matrix params,
+                  typename TTypes<T>::ConstMatrix updates,
+                  typename TTypes<Index>::ConstFlat indices) {
     // TODO: Implement indices range check. The hardest part is with returning
     // a value after the range check, as we do not want to do device to host
     // memcpy during a stream.
@@ -77,7 +77,6 @@ struct ScatterFunctor<GPUDevice, T, Index, op> {
         <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
             params.data(), updates.data(), indices.data(),
             first_dim_size, updates_size, indices_size);
-    return -1;
   }
 };
 
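The GPU specialization drops the Index return entirely: as the TODO notes, reporting a bad index back from the device would need a device-to-host copy synchronized with the stream, so no range check is performed there. What each launched thread computes, written as sequential C++ (parameter names mirror the launch arguments above, but this is an illustration, not the real kernel):

#include <cassert>

// `i` stands in for the CUDA thread index; each thread handles one element of
// the flattened updates buffer (assumed row-major layout).
template <typename T, typename Index>
void ScatterAssignFlat(T* params, const T* updates, const Index* indices,
                       Index first_dim_size, Index updates_size,
                       Index indices_size) {
  const Index slice = updates_size / indices_size;
  for (Index i = 0; i < updates_size; ++i) {
    const Index row = indices[i / slice];
    // A real range check would compare row against first_dim_size here (see
    // the TODO above); the assert is only a debug-build stand-in.
    assert(row >= 0 && row < first_dim_size);
    params[row * slice + i % slice] = updates[i];
  }
}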
@@ -1,38 +0,0 @@
-/* Copyright 2015 Google Inc. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_UTIL_BOUNDS_CHECK_H_
-#define TENSORFLOW_UTIL_BOUNDS_CHECK_H_
-
-#include <type_traits>
-
-#include "third_party/eigen3/Eigen/Core"
-#include "tensorflow/core/platform/macros.h"
-
-namespace tensorflow {
-
-// Check that 0 <= index < limit using a single comparison, assuming
-// that 0 <= limit if Index is signed.  Intended for use in performance
-// critical contexts where 0 <= index < limit is almost always true.
-template <class Index>
-EIGEN_ALWAYS_INLINE bool FastBoundsCheck(Index index, Index limit) {
-  typedef typename std::make_unsigned<Index>::type UIndex;
-  return TF_PREDICT_TRUE(static_cast<UIndex>(index) <
-                         static_cast<UIndex>(limit));
-}
-
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_UTIL_BOUNDS_CHECK_H_
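The deleted header is small but subtle: casting a signed index to unsigned maps every negative value above the largest positive one, so a single unsigned comparison checks 0 <= index && index < limit at once, provided limit >= 0. For example, an int32 index of -1 becomes 4294967295, which fails `< limit` for any valid limit. A self-contained sketch without the TF_PREDICT_TRUE branch hint:

#include <cassert>
#include <cstdint>
#include <type_traits>

// Same idea as the deleted FastBoundsCheck, minus the branch-prediction hint.
template <class Index>
inline bool FastBoundsCheckSketch(Index index, Index limit) {
  typedef typename std::make_unsigned<Index>::type UIndex;
  return static_cast<UIndex>(index) < static_cast<UIndex>(limit);
}

int main() {
  assert(!FastBoundsCheckSketch<std::int32_t>(-1, 6));  // -1 wraps to 2^32-1
  assert(!FastBoundsCheckSketch<std::int32_t>(6, 6));   // upper bound excluded
  assert(FastBoundsCheckSketch<std::int32_t>(5, 6));
  return 0;
}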
@@ -83,14 +83,6 @@ class GatherTest(tf.test.TestCase):
       gather_t = tf.gather(params, indices)
       self.assertEqual(None, gather_t.get_shape())
 
-  def testBadIndices(self):
-    with self.test_session():
-      params = [0, 1, 2]
-      indices = [[7]]
-      gather = tf.gather(params, indices)
-      with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"):
-        gather.eval()
-
 
 if __name__ == "__main__":
   tf.test.main()
@@ -128,11 +128,11 @@ class ScatterTest(tf.test.TestCase):
 
       # Test some out of range errors.
       indices = np.array([-1, 0, 5])
-      with self.assertRaisesOpError(r'indices\[0\] = -1 is not in \[0, 6\)'):
+      with self.assertRaisesOpError('indices is out of range'):
         op(ref, indices, updates).eval()
 
       indices = np.array([2, 0, 6])
-      with self.assertRaisesOpError(r'indices\[2\] = 6 is not in \[0, 6\)'):
+      with self.assertRaisesOpError('indices is out of range'):
         op(ref, indices, updates).eval()
 
 # TODO(fpmc): Re-enable this test when gpu_pip test actually runs on a GPU.