Eliminate unneded pylint disable
Change: 115470945
This commit is contained in:
parent
14a237beb0
commit
6b2c0012d1
@ -23,5 +23,4 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
# pylint: disable=wildcard-import
|
|
||||||
from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import SdcaModel
|
from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import SdcaModel
|
||||||
|
@ -21,50 +21,34 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/mem.h"
|
#include "tensorflow/core/platform/mem.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
#include "tensorflow/core/util/bounds_check.h"
|
|
||||||
#include "tensorflow/core/util/util.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// Returns -1 on success or a nonnegative i s.t., indices[i] is bad.
|
|
||||||
template <typename T, typename Index, int static_slice_elems>
|
template <typename T, typename Index, int static_slice_elems>
|
||||||
Index HandleCopies(const Tensor& params,
|
void HandleCopies(const Tensor& Tparams,
|
||||||
typename TTypes<Index>::ConstVec& indices, Index slice_elems,
|
typename TTypes<Index>::ConstVec& Tindices, int slice_elems,
|
||||||
typename TTypes<T>::Matrix out) {
|
typename TTypes<T>::Matrix Tout) {
|
||||||
const int N = indices.dimension(0);
|
const int N = Tindices.dimension(0);
|
||||||
const auto& params_flat = params.flat_outer_dims<T>();
|
const auto& Tparams_flat = Tparams.flat_outer_dims<T>();
|
||||||
const Index limit = params.dim_size(0);
|
T* Tout_base = &Tout(0, 0);
|
||||||
T* out_base = &out(0, 0);
|
const T* Tparams_base = &Tparams_flat(0, 0);
|
||||||
const T* params_base = ¶ms_flat(0, 0);
|
const size_t slice_bytes = slice_elems * sizeof(T);
|
||||||
if (static_slice_elems >= 0) {
|
if (static_slice_elems >= 0) {
|
||||||
// Give compiler static knowledge of the number of elements/bytes
|
// Give compiler static knowledge of the number of elements/bytes
|
||||||
CHECK_EQ(static_slice_elems, slice_elems);
|
CHECK_EQ(static_slice_elems, slice_elems);
|
||||||
slice_elems = static_slice_elems;
|
slice_elems = static_slice_elems;
|
||||||
}
|
}
|
||||||
// Compute slice_bytes here so that static knowledge is available
|
|
||||||
const size_t slice_bytes = slice_elems * sizeof(T);
|
|
||||||
for (int i = 0; i < N; i++) {
|
for (int i = 0; i < N; i++) {
|
||||||
const int j = i + 1;
|
int j = i + 1;
|
||||||
if (j < N) {
|
if (j < N) {
|
||||||
port::prefetch<port::PREFETCH_HINT_T0>(¶ms_flat(indices(j), 0));
|
port::prefetch<port::PREFETCH_HINT_T0>(&Tparams_flat(Tindices(j), 0));
|
||||||
port::prefetch<port::PREFETCH_HINT_T0>(&out(j, 0));
|
port::prefetch<port::PREFETCH_HINT_T0>(&Tout(j, 0));
|
||||||
}
|
}
|
||||||
// Grab the index and check its validity. An earlier version of the
|
memcpy(Tout_base + i * slice_elems,
|
||||||
// code checked it and then grabbed it from memory a second time, which
|
Tparams_base + Tindices(i) * slice_elems, slice_bytes);
|
||||||
// was a security risk since it could have changed in between.
|
|
||||||
const Index index = indices(i);
|
|
||||||
if (!FastBoundsCheck(index, limit)) return i;
|
|
||||||
// Copy using memcpy if possible, otherwise an Eigen loop
|
|
||||||
if (Allocator::is_simple<T>::value) {
|
|
||||||
memcpy(out_base + i * slice_elems, params_base + index * slice_elems,
|
|
||||||
slice_bytes);
|
|
||||||
} else {
|
|
||||||
out.template chip<0>(i) = params_flat.template chip<0>(index);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
@ -80,67 +64,78 @@ class GatherOp : public OpKernel {
|
|||||||
const DataType dt = DataTypeToEnum<T>::v();
|
const DataType dt = DataTypeToEnum<T>::v();
|
||||||
const DataType index_t = DataTypeToEnum<Index>::v();
|
const DataType index_t = DataTypeToEnum<Index>::v();
|
||||||
OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t}, {dt}));
|
OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t}, {dt}));
|
||||||
// We used to grab the validate_indices attribute here, but now we
|
OP_REQUIRES_OK(c, c->GetAttr("validate_indices", &validate_indices_));
|
||||||
// always validate indices since the speed difference was only 1.5%.
|
|
||||||
// TODO(irving): Remove the validate_indices attribute once we have
|
|
||||||
// support for removing attrs in a backwards compatible way.
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* c) override {
|
void Compute(OpKernelContext* c) override {
|
||||||
const Tensor& params = c->input(0);
|
const Tensor& Tparams = c->input(0);
|
||||||
const Tensor& indices = c->input(1);
|
const Tensor& Tindices = c->input(1);
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
|
c, TensorShapeUtils::IsVectorOrHigher(Tparams.shape()),
|
||||||
errors::InvalidArgument("params must be at least 1 dimensional"));
|
errors::InvalidArgument("params must be at least 1 dimensional"));
|
||||||
|
const int64 N = Tindices.NumElements();
|
||||||
|
const int64 first_dim_size = Tparams.dim_size(0);
|
||||||
|
|
||||||
// Check that we have enough index space
|
// Validate all the indices are in range
|
||||||
const int64 N_big = indices.NumElements();
|
auto Tindices_vec = Tindices.flat<Index>();
|
||||||
OP_REQUIRES(c, N_big <= std::numeric_limits<int>::max(),
|
if (validate_indices_) {
|
||||||
|
for (int64 i = 0; i < N; i++) {
|
||||||
|
const Index index = Tindices_vec(i);
|
||||||
|
OP_REQUIRES(c, index >= 0 && index < first_dim_size,
|
||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
"indices has too many elements for int indexing: ", N_big,
|
strings::StrCat("Index ", index, " at offset ", i,
|
||||||
" > ", std::numeric_limits<int>::max()));
|
" in Tindices is out of range")));
|
||||||
const int N = indices.NumElements();
|
}
|
||||||
OP_REQUIRES(
|
}
|
||||||
c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
|
|
||||||
errors::InvalidArgument("params.shape[0] too large for ",
|
|
||||||
DataTypeString(DataTypeToEnum<Index>::v()),
|
|
||||||
" indexing: ", params.dim_size(0), " > ",
|
|
||||||
std::numeric_limits<Index>::max()));
|
|
||||||
|
|
||||||
// The result shape is indices.shape + params.shape[1:].
|
// The result shape is indices.shape + params.shape[1:].
|
||||||
TensorShape result_shape = indices.shape();
|
TensorShape result_shape = Tindices.shape();
|
||||||
for (int i = 1; i < params.dims(); i++) {
|
for (int i = 1; i < Tparams.dims(); i++) {
|
||||||
result_shape.AddDim(params.dim_size(i));
|
result_shape.AddDim(Tparams.dim_size(i));
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor* out = nullptr;
|
Tensor* Tout = nullptr;
|
||||||
OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
|
OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &Tout));
|
||||||
|
const auto& Tparams_flat = Tparams.flat_outer_dims<T>();
|
||||||
if (N > 0) {
|
if (N > 0) {
|
||||||
auto indices_flat = indices.flat<Index>();
|
auto Tindices_flat = Tindices.flat<Index>();
|
||||||
auto out_flat = out->shaped<T, 2>({N, out->NumElements() / N});
|
auto Tout_flat = Tout->shaped<T, 2>({N, Tout->NumElements() / N});
|
||||||
const int64 slice_size = out->NumElements() / N;
|
if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
|
||||||
Index bad_i;
|
const int64 slice_size = Tout->NumElements() / N;
|
||||||
|
#define SPECIALIZE(elems) \
|
||||||
|
do { \
|
||||||
|
if (slice_size == elems) { \
|
||||||
|
HandleCopies<T, Index, elems>(Tparams, Tindices_flat, slice_size, \
|
||||||
|
Tout_flat); \
|
||||||
|
return; \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
#define CALL(elems) \
|
SPECIALIZE(10);
|
||||||
bad_i = HandleCopies<T, Index, elems>(params, indices_flat, slice_size, \
|
SPECIALIZE(20);
|
||||||
out_flat)
|
|
||||||
|
|
||||||
if (slice_size == 10)
|
#undef SPECIALIZE
|
||||||
CALL(10);
|
|
||||||
else if (slice_size == 20)
|
|
||||||
CALL(20);
|
|
||||||
else
|
|
||||||
CALL(-1);
|
|
||||||
|
|
||||||
#undef CALL
|
HandleCopies<T, Index, -1>(Tparams, Tindices_flat, slice_size,
|
||||||
|
Tout_flat);
|
||||||
OP_REQUIRES(
|
} else {
|
||||||
c, bad_i < 0,
|
for (int i = 0; i < N; i++) {
|
||||||
errors::InvalidArgument(
|
int j = i + 1;
|
||||||
"indices", SliceDebugString(indices.shape(), bad_i), " = ",
|
if (j < N) {
|
||||||
indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
|
port::prefetch<port::PREFETCH_HINT_T0>(
|
||||||
|
&Tparams_flat(Tindices_vec(j), 0));
|
||||||
|
port::prefetch<port::PREFETCH_HINT_T0>(&Tout_flat(j, 0));
|
||||||
|
}
|
||||||
|
// Copy last Ndim-1 dimensions of Tparams[Tindices[i]] to Tout[i]
|
||||||
|
Tout_flat.template chip<0>(i) =
|
||||||
|
Tparams_flat.template chip<0>(Tindices_vec(i));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool validate_indices_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#define REGISTER_GATHER(type, index_type) \
|
#define REGISTER_GATHER(type, index_type) \
|
||||||
|
@ -22,8 +22,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
#include "tensorflow/core/util/bounds_check.h"
|
|
||||||
#include "tensorflow/core/util/util.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -101,54 +99,36 @@ class ScatterUpdateOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void DoCompute(OpKernelContext* c) {
|
void DoCompute(OpKernelContext* c) {
|
||||||
Tensor params = c->mutable_input(0, use_exclusive_lock_);
|
Tensor Tparams = c->mutable_input(0, use_exclusive_lock_);
|
||||||
OP_REQUIRES(c, params.IsInitialized(),
|
OP_REQUIRES(c, Tparams.IsInitialized(),
|
||||||
errors::FailedPrecondition("Null ref for params"));
|
errors::FailedPrecondition("Null ref for params"));
|
||||||
const Tensor& indices = c->input(1);
|
const Tensor& Tindices = c->input(1);
|
||||||
const Tensor& updates = c->input(2);
|
const Tensor& Tupdates = c->input(2);
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
|
c, TensorShapeUtils::IsVectorOrHigher(Tparams.shape()),
|
||||||
errors::InvalidArgument("params must be at least 1-D, got shape ",
|
errors::InvalidArgument("params must be at least 1-D, got shape ",
|
||||||
params.shape().DebugString()));
|
Tparams.shape().DebugString()));
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
c, ValidShapes(params, updates, indices),
|
c, ValidShapes(Tparams, Tupdates, Tindices),
|
||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
"Must have updates.shape = indices.shape + params.shape[1:], got ",
|
"Must have updates.shape = indices.shape + params.shape[1:], got ",
|
||||||
"updates.shape ", updates.shape().DebugString(), ", indices.shape ",
|
"updates.shape ", Tupdates.shape().DebugString(),
|
||||||
indices.shape().DebugString(), ", params.shape ",
|
", indices.shape ", Tindices.shape().DebugString(),
|
||||||
params.shape().DebugString()));
|
", params.shape ", Tparams.shape().DebugString()));
|
||||||
|
|
||||||
// Check that we have enough index space
|
|
||||||
const int64 N_big = indices.NumElements();
|
|
||||||
OP_REQUIRES(c, N_big <= std::numeric_limits<Index>::max(),
|
|
||||||
errors::InvalidArgument(
|
|
||||||
"indices has too many elements for ",
|
|
||||||
DataTypeString(DataTypeToEnum<Index>::v()), " indexing: ",
|
|
||||||
N_big, " > ", std::numeric_limits<Index>::max()));
|
|
||||||
const Index N = indices.NumElements();
|
|
||||||
OP_REQUIRES(
|
|
||||||
c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
|
|
||||||
errors::InvalidArgument("params.shape[0] too large for ",
|
|
||||||
DataTypeString(DataTypeToEnum<Index>::v()),
|
|
||||||
" indexing: ", params.dim_size(0), " > ",
|
|
||||||
std::numeric_limits<Index>::max()));
|
|
||||||
|
|
||||||
// We always return the input ref.
|
// We always return the input ref.
|
||||||
c->forward_ref_input_to_ref_output(0, 0);
|
c->forward_ref_input_to_ref_output(0, 0);
|
||||||
|
|
||||||
|
const Index N = Tindices.NumElements();
|
||||||
if (N > 0) {
|
if (N > 0) {
|
||||||
auto indices_flat = indices.flat<Index>();
|
auto Tindices_flat = Tindices.flat<Index>();
|
||||||
auto params_flat = params.flat_outer_dims<T>();
|
auto Tparams_flat = Tparams.flat_outer_dims<T>();
|
||||||
auto updates_flat = updates.shaped<T, 2>({N, updates.NumElements() / N});
|
auto Tupdates_flat =
|
||||||
|
Tupdates.shaped<T, 2>({N, Tupdates.NumElements() / N});
|
||||||
|
|
||||||
functor::ScatterFunctor<Device, T, Index, op> functor;
|
functor::ScatterFunctor<Device, T, Index, op> functor;
|
||||||
const Index bad_i = functor(c, c->template eigen_device<Device>(),
|
functor(c, c->template eigen_device<Device>(),
|
||||||
params_flat, updates_flat, indices_flat);
|
Tparams_flat, Tupdates_flat, Tindices_flat);
|
||||||
OP_REQUIRES(
|
|
||||||
c, bad_i < 0,
|
|
||||||
errors::InvalidArgument(
|
|
||||||
"indices", SliceDebugString(indices.shape(), bad_i), " = ",
|
|
||||||
indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -157,23 +137,26 @@ namespace functor {
|
|||||||
// Implementation of update functor for CPU.
|
// Implementation of update functor for CPU.
|
||||||
template <typename T, typename Index, scatter_op::UpdateOp op>
|
template <typename T, typename Index, scatter_op::UpdateOp op>
|
||||||
struct ScatterFunctor<CPUDevice, T, Index, op> {
|
struct ScatterFunctor<CPUDevice, T, Index, op> {
|
||||||
Index operator()(OpKernelContext* c, const CPUDevice& d,
|
void operator()(OpKernelContext* c, const CPUDevice& d,
|
||||||
typename TTypes<T>::Matrix params,
|
typename TTypes<T>::Matrix params,
|
||||||
typename TTypes<T>::ConstMatrix updates,
|
typename TTypes<T>::ConstMatrix updates,
|
||||||
typename TTypes<Index>::ConstFlat indices) {
|
typename TTypes<Index>::ConstFlat indices) {
|
||||||
const Index N = indices.size();
|
Index N = indices.size();
|
||||||
const Index limit = params.dimension(0);
|
// Validate all the indices are in range
|
||||||
|
Index first_dim_size = params.dimension(0);
|
||||||
for (Index i = 0; i < N; i++) {
|
for (Index i = 0; i < N; i++) {
|
||||||
// Grab the index and check its validity. An earlier version of the
|
|
||||||
// code checked it and then grabbed it from memory a second time, which
|
|
||||||
// was a security risk since it could have changed in between.
|
|
||||||
const Index index = indices(i);
|
const Index index = indices(i);
|
||||||
if (!FastBoundsCheck(index, limit)) return i;
|
OP_REQUIRES(c, index >= 0 && index < first_dim_size,
|
||||||
// Copy last Ndim-1 dimensions of updates[i] to params[index]
|
errors::InvalidArgument(
|
||||||
Assign<op>::Run(params.template chip<0>(index),
|
strings::StrCat("Index ", index, " at offset ", i,
|
||||||
|
" in indices is out of range")));
|
||||||
|
}
|
||||||
|
for (Index i = 0; i < N; i++) {
|
||||||
|
// Copy last Ndim-1 dimensions of Tupdates[i] to
|
||||||
|
// Tparams[Tindices[i]]
|
||||||
|
Assign<op>::Run(params.template chip<0>(indices(i)),
|
||||||
updates.template chip<0>(i));
|
updates.template chip<0>(i));
|
||||||
}
|
}
|
||||||
return -1;
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace functor
|
} // namespace functor
|
||||||
@ -239,7 +222,7 @@ namespace functor {
|
|||||||
|
|
||||||
#define DECLARE_GPU_SPECS_OP(T, Index, op) \
|
#define DECLARE_GPU_SPECS_OP(T, Index, op) \
|
||||||
template <> \
|
template <> \
|
||||||
Index ScatterFunctor<GPUDevice, T, Index, op>::operator()( \
|
void ScatterFunctor<GPUDevice, T, Index, op>::operator()( \
|
||||||
OpKernelContext* c, const GPUDevice& d, \
|
OpKernelContext* c, const GPUDevice& d, \
|
||||||
typename TTypes<T>::Matrix params, \
|
typename TTypes<T>::Matrix params, \
|
||||||
typename TTypes<T>::ConstMatrix updates, \
|
typename TTypes<T>::ConstMatrix updates, \
|
||||||
|
@ -36,8 +36,7 @@ namespace functor {
|
|||||||
// Functor used by ScatterOp to do the computations.
|
// Functor used by ScatterOp to do the computations.
|
||||||
template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
|
template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
|
||||||
struct ScatterFunctor {
|
struct ScatterFunctor {
|
||||||
// Returns -1 on success or a nonnegative i s.t. indices[i] is a bad index.
|
void operator()(OpKernelContext* c, const Device& d,
|
||||||
Index operator()(OpKernelContext* c, const Device& d,
|
|
||||||
typename TTypes<T>::Matrix params,
|
typename TTypes<T>::Matrix params,
|
||||||
typename TTypes<T>::ConstMatrix updates,
|
typename TTypes<T>::ConstMatrix updates,
|
||||||
typename TTypes<Index>::ConstFlat indices);
|
typename TTypes<Index>::ConstFlat indices);
|
||||||
|
@ -62,7 +62,7 @@ namespace functor {
|
|||||||
// Specialization for a GPU device.
|
// Specialization for a GPU device.
|
||||||
template <typename T, typename Index, scatter_op::UpdateOp op>
|
template <typename T, typename Index, scatter_op::UpdateOp op>
|
||||||
struct ScatterFunctor<GPUDevice, T, Index, op> {
|
struct ScatterFunctor<GPUDevice, T, Index, op> {
|
||||||
Index operator()(OpKernelContext* c, const GPUDevice& d,
|
void operator()(OpKernelContext* c, const GPUDevice& d,
|
||||||
typename TTypes<T>::Matrix params,
|
typename TTypes<T>::Matrix params,
|
||||||
typename TTypes<T>::ConstMatrix updates,
|
typename TTypes<T>::ConstMatrix updates,
|
||||||
typename TTypes<Index>::ConstFlat indices) {
|
typename TTypes<Index>::ConstFlat indices) {
|
||||||
@ -77,7 +77,6 @@ struct ScatterFunctor<GPUDevice, T, Index, op> {
|
|||||||
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||||
params.data(), updates.data(), indices.data(),
|
params.data(), updates.data(), indices.data(),
|
||||||
first_dim_size, updates_size, indices_size);
|
first_dim_size, updates_size, indices_size);
|
||||||
return -1;
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1,38 +0,0 @@
|
|||||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#ifndef TENSORFLOW_UTIL_BOUNDS_CHECK_H_
|
|
||||||
#define TENSORFLOW_UTIL_BOUNDS_CHECK_H_
|
|
||||||
|
|
||||||
#include <type_traits>
|
|
||||||
|
|
||||||
#include "third_party/eigen3/Eigen/Core"
|
|
||||||
#include "tensorflow/core/platform/macros.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
|
|
||||||
// Check that 0 <= index < limit using a single comparison, assuming
|
|
||||||
// that 0 <= limit if Index is signed. Intended for use in performance
|
|
||||||
// critical contexts where 0 <= index < limit is almost always true.
|
|
||||||
template <class Index>
|
|
||||||
EIGEN_ALWAYS_INLINE bool FastBoundsCheck(Index index, Index limit) {
|
|
||||||
typedef typename std::make_unsigned<Index>::type UIndex;
|
|
||||||
return TF_PREDICT_TRUE(static_cast<UIndex>(index) <
|
|
||||||
static_cast<UIndex>(limit));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace tensorflow
|
|
||||||
|
|
||||||
#endif // TENSORFLOW_UTIL_BOUNDS_CHECK_H_
|
|
@ -83,14 +83,6 @@ class GatherTest(tf.test.TestCase):
|
|||||||
gather_t = tf.gather(params, indices)
|
gather_t = tf.gather(params, indices)
|
||||||
self.assertEqual(None, gather_t.get_shape())
|
self.assertEqual(None, gather_t.get_shape())
|
||||||
|
|
||||||
def testBadIndices(self):
|
|
||||||
with self.test_session():
|
|
||||||
params = [0, 1, 2]
|
|
||||||
indices = [[7]]
|
|
||||||
gather = tf.gather(params, indices)
|
|
||||||
with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"):
|
|
||||||
gather.eval()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
@ -128,11 +128,11 @@ class ScatterTest(tf.test.TestCase):
|
|||||||
|
|
||||||
# Test some out of range errors.
|
# Test some out of range errors.
|
||||||
indices = np.array([-1, 0, 5])
|
indices = np.array([-1, 0, 5])
|
||||||
with self.assertRaisesOpError(r'indices\[0\] = -1 is not in \[0, 6\)'):
|
with self.assertRaisesOpError('indices is out of range'):
|
||||||
op(ref, indices, updates).eval()
|
op(ref, indices, updates).eval()
|
||||||
|
|
||||||
indices = np.array([2, 0, 6])
|
indices = np.array([2, 0, 6])
|
||||||
with self.assertRaisesOpError(r'indices\[2\] = 6 is not in \[0, 6\)'):
|
with self.assertRaisesOpError('indices is out of range'):
|
||||||
op(ref, indices, updates).eval()
|
op(ref, indices, updates).eval()
|
||||||
|
|
||||||
# TODO(fpmc): Re-enable this test when gpu_pip test actually runs on a GPU.
|
# TODO(fpmc): Re-enable this test when gpu_pip test actually runs on a GPU.
|
||||||
|
Loading…
Reference in New Issue
Block a user