Eigen BiasAdd and BiasAddGrad Fix for NCHW Format. (#13158)
parent 01854b6d40
commit 8e22eb8748
@@ -39,6 +39,48 @@ typedef Eigen::GpuDevice GPUDevice;
 typedef Eigen::SyclDevice SYCLDevice;
 #endif  // TENSORFLOW_USE_SYCL
 
+namespace {
+
+void GetBiasValueDims(const Tensor& value_tensor, TensorFormat data_format,
+                      int32* batch, int32* height, int32* width,
+                      int32* channel) {
+  *batch = 1;
+  *width = 1;
+  *height = 1;
+  *channel = 1;
+  if (data_format == FORMAT_NHWC) {
+    int32 channel_dim = value_tensor.dims() - 1;
+    *channel = static_cast<int32>(value_tensor.dim_size(channel_dim));
+    for (int32 i = 0; i < channel_dim; i++) {
+      *batch *= static_cast<int32>(value_tensor.dim_size(i));
+    }
+  } else if (data_format == FORMAT_NCHW) {
+    int32 channel_dim = value_tensor.dims() - 3;
+    int32 height_dim = value_tensor.dims() - 2;
+    int32 width_dim = value_tensor.dims() - 1;
+    *channel = static_cast<int32>(value_tensor.dim_size(channel_dim));
+    *height = static_cast<int32>(value_tensor.dim_size(height_dim));
+    *width = static_cast<int32>(value_tensor.dim_size(width_dim));
+    for (int32 i = 0; i < channel_dim; i++) {
+      *batch *= static_cast<int32>(value_tensor.dim_size(i));
+    }
+  }
+}
+
+template <class T>
+struct AccumulatorType {
+  typedef T type;
+};
+
+// float is faster on the CPU than half, and also more precise,
+// so use float for the temporary accumulators.
+template <>
+struct AccumulatorType<Eigen::half> {
+  typedef float type;
+};
+
+}  // namespace
+
 template <typename Device, typename T>
 class BiasOp : public BinaryOp<T> {
  public:
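For intuition about the helper this hunk introduces: GetBiasValueDims folds every dimension before the channel axis into batch, so NHWC inputs of any rank collapse to a (batch, channel) pair, while NCHW keeps height and width separate. Below is a minimal standalone sketch of the same logic, not part of the commit, using a plain std::vector<int64_t> in place of tensorflow::Tensor so it compiles on its own:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

enum TensorFormat { FORMAT_NHWC, FORMAT_NCHW };

// Mirrors the GetBiasValueDims logic above on a plain shape vector.
void GetBiasValueDims(const std::vector<int64_t>& dims, TensorFormat format,
                      int32_t* batch, int32_t* height, int32_t* width,
                      int32_t* channel) {
  *batch = *height = *width = *channel = 1;
  if (format == FORMAT_NHWC) {
    // Channel is last; every leading dimension folds into "batch".
    *channel = static_cast<int32_t>(dims.back());
    for (std::size_t i = 0; i + 1 < dims.size(); ++i)
      *batch *= static_cast<int32_t>(dims[i]);
  } else if (format == FORMAT_NCHW) {  // assumes rank >= 3
    std::size_t c = dims.size() - 3;
    *channel = static_cast<int32_t>(dims[c]);
    *height = static_cast<int32_t>(dims[c + 1]);
    *width = static_cast<int32_t>(dims[c + 2]);
    for (std::size_t i = 0; i < c; ++i)
      *batch *= static_cast<int32_t>(dims[i]);
  }
}

int main() {
  int32_t n, h, w, c;
  GetBiasValueDims({2, 3, 4, 5}, FORMAT_NCHW, &n, &h, &w, &c);
  std::cout << n << " " << c << " " << h << " " << w << "\n";  // 2 3 4 5
  GetBiasValueDims({2, 4, 5, 3}, FORMAT_NHWC, &n, &h, &w, &c);
  std::cout << n << " " << c << " " << h << " " << w << "\n";  // 40 3 1 1
}
```

Note how the NHWC case reports batch = 40 for a [2, 4, 5, 3] input: H and W are absorbed into the batch product, which is all the elementwise bias add needs.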
@@ -50,9 +92,6 @@ class BiasOp : public BinaryOp<T> {
     } else {
       data_format_ = FORMAT_NHWC;
     }
-    OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
-                errors::InvalidArgument(context->device()->name() +
-                                        " BiasOp only supports NHWC."));
   }
 
   void Compute(OpKernelContext* context) override {
@@ -65,9 +104,21 @@ class BiasOp : public BinaryOp<T> {
     OP_REQUIRES(context, TensorShapeUtils::IsVector(bias.shape()),
                 errors::InvalidArgument("Biases must be 1D: ",
                                         bias.shape().DebugString()));
-    const auto last_dim = input.shape().dims() - 1;
+
+    // Added by intel_tf to support NCHW on CPU regardless of MKL used or not.
+    size_t channel_dim;
+    if (data_format_ == FORMAT_NCHW) {
+      OP_REQUIRES(context, input.dims() == 4,
+                  errors::InvalidArgument(
+                      "NCHW format supports only 4D input tensor."));
+      channel_dim = 1;
+    } else {
+      channel_dim = input.shape().dims() - 1;  // End of code by intel_tf.
+    }
+
     OP_REQUIRES(
-        context, bias.shape().dim_size(0) == input.shape().dim_size(last_dim),
+        context,
+        bias.shape().dim_size(0) == input.shape().dim_size(channel_dim),
         errors::InvalidArgument(
             "Must provide as many biases as the last dimension "
             "of the input tensor: ",
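The hunk above boils down to one rule: the 1-D bias must be as long as the channel dimension, which sits at index 1 for NCHW and at the last index for NHWC. A tiny hypothetical sketch of that rule (ChannelDim is an illustrative helper, not from the commit):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

enum TensorFormat { FORMAT_NHWC, FORMAT_NCHW };

// Illustrative helper: index of the channel dimension for each layout.
std::size_t ChannelDim(TensorFormat format, std::size_t rank) {
  return format == FORMAT_NCHW ? 1 : rank - 1;
}

int main() {
  std::vector<long> nchw = {2, 3, 4, 5};  // N, C, H, W
  std::vector<long> nhwc = {2, 4, 5, 3};  // N, H, W, C
  // A bias of length 3 matches the channel dim in either layout.
  assert(nchw[ChannelDim(FORMAT_NCHW, nchw.size())] == 3);
  assert(nhwc[ChannelDim(FORMAT_NHWC, nhwc.size())] == 3);
  return 0;
}
```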
@@ -78,6 +129,19 @@ class BiasOp : public BinaryOp<T> {
                                 {0}, 0, input.shape(), &output));
     if (input.NumElements() == 0) return;
 
+    // Added by intel_tf to support NCHW on CPU regardless of MKL used or not.
+    if (data_format_ == FORMAT_NCHW) {
+      int32 batch, height, width, channel;
+      GetBiasValueDims(input, data_format_, &batch, &height, &width,
+                       &channel);
+      Eigen::DSizes<int32, 4> four_dims(1, channel, 1, 1);
+      Eigen::DSizes<int32, 4> broad_cast_dims(batch, 1, height, width);
+      const Device& d = context->eigen_device<Device>();
+      output->tensor<T, 4>().device(d) = input.tensor<T, 4>() +
+          bias.tensor<T, 1>().reshape(four_dims).broadcast(broad_cast_dims);
+      return;
+    }  // End of code by intel_tf.
+
     switch (input.shape().dims()) {
       case 2:
         Compute<2>(context, input, bias, output);
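The forward NCHW path relies on Eigen's reshape + broadcast: the 1-D bias becomes [1, C, 1, 1] and is then replicated across N, H, and W. A self-contained illustration of the same pattern (not from the commit; assumes Eigen's unsupported Tensor module is available):

```cpp
#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  const int batch = 2, channel = 3, height = 4, width = 5;
  Eigen::Tensor<float, 4> input(batch, channel, height, width);
  input.setRandom();
  Eigen::Tensor<float, 1> bias(channel);
  bias.setValues({1.f, 2.f, 3.f});

  // Reshape bias to [1, C, 1, 1], then replicate it N, H, and W times,
  // the same four_dims / broad_cast_dims pattern as the hunk above.
  Eigen::DSizes<int, 4> four_dims(1, channel, 1, 1);
  Eigen::DSizes<int, 4> broad_cast_dims(batch, 1, height, width);
  Eigen::Tensor<float, 4> output =
      input + bias.reshape(four_dims).broadcast(broad_cast_dims);

  // Every element in channel c is shifted by bias(c).
  std::cout << output(0, 2, 1, 1) - input(0, 2, 1, 1) << "\n";  // prints 3
}
```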
@@ -137,48 +201,6 @@ REGISTER_KERNEL(double);
 #undef REGISTER_KERNEL
 #endif  // TENSORFLOW_USE_SYCL
 
-namespace {
-
-void GetBiasValueDims(const Tensor& value_tensor, TensorFormat data_format,
-                      int32* batch, int32* height, int32* width,
-                      int32* channel) {
-  *batch = 1;
-  *width = 1;
-  *height = 1;
-  *channel = 1;
-  if (data_format == FORMAT_NHWC) {
-    int32 channel_dim = value_tensor.dims() - 1;
-    *channel = static_cast<int32>(value_tensor.dim_size(channel_dim));
-    for (int32 i = 0; i < channel_dim; i++) {
-      *batch *= static_cast<int32>(value_tensor.dim_size(i));
-    }
-  } else if (data_format == FORMAT_NCHW) {
-    int32 channel_dim = value_tensor.dims() - 3;
-    int32 height_dim = value_tensor.dims() - 2;
-    int32 width_dim = value_tensor.dims() - 1;
-    *channel = static_cast<int32>(value_tensor.dim_size(channel_dim));
-    *height = static_cast<int32>(value_tensor.dim_size(height_dim));
-    *width = static_cast<int32>(value_tensor.dim_size(width_dim));
-    for (int32 i = 0; i < channel_dim; i++) {
-      *batch *= static_cast<int32>(value_tensor.dim_size(i));
-    }
-  }
-}
-
-template <class T>
-struct AccumulatorType {
-  typedef T type;
-};
-
-// float is faster on the CPU than half, and also more precise,
-// so use float for the temporary accumulators.
-template <>
-struct AccumulatorType<Eigen::half> {
-  typedef float type;
-};
-
-}  // namespace
-
 template <typename Device, typename T>
 class BiasGradOp : public OpKernel {
  public:
@@ -190,9 +212,6 @@ class BiasGradOp : public OpKernel {
     } else {
       data_format_ = FORMAT_NHWC;
     }
-    OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
-                errors::InvalidArgument(context->device()->name() +
-                                        " BiasGradOp only supports NHWC."));
   }
 
   void Compute(OpKernelContext* context) override {
@@ -222,18 +241,40 @@ class BiasGradOp : public OpKernel {
       // Eigen often crashes by design on empty tensors, but setZero is safe
       output->template flat<T>().setZero();
     } else {
-      Eigen::DSizes<int, 2> two_dims(batch * height * width, channel);
+      // Added by intel_tf to support NCHW on CPU regardless of MKL used or not.
+      if (data_format_ == FORMAT_NCHW) {
+        OP_REQUIRES(context, output_backprop.dims() == 4,
+                    errors::InvalidArgument(
+                        "NCHW format supports only 4D input/output tensor."));
+        Eigen::DSizes<int, 4> four_dims(batch, channel, height, width);
 #ifdef EIGEN_HAS_INDEX_LIST
-      Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
+        using idx0 = Eigen::type2index<0>;
+        using idx2 = Eigen::type2index<2>;
+        using idx3 = Eigen::type2index<3>;
+        Eigen::IndexList<idx0, idx2, idx3> reduction_axes;
 #else
-      Eigen::array<int, 1> reduction_axis = {0};
+        Eigen::array<int, 3> reduction_axes = {0, 2, 3};
 #endif
-      output->template flat<T>().device(context->eigen_device<Device>()) =
-          output_backprop.flat<T>()
-              .template cast<typename AccumulatorType<T>::type>()
-              .reshape(two_dims)
-              .sum(reduction_axis)
-              .template cast<T>();
+        output->template flat<T>().device(context->eigen_device<Device>()) =
+            output_backprop.flat<T>()
+                .template cast<typename AccumulatorType<T>::type>()
+                .reshape(four_dims)
+                .sum(reduction_axes)
+                .template cast<T>();  // End of code by intel_tf.
+      } else {
+        Eigen::DSizes<int, 2> two_dims(batch * height * width, channel);
+#ifdef EIGEN_HAS_INDEX_LIST
+        Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
+#else
+        Eigen::array<int, 1> reduction_axis = {0};
+#endif
+        output->template flat<T>().device(context->eigen_device<Device>()) =
+            output_backprop.flat<T>()
+                .template cast<typename AccumulatorType<T>::type>()
+                .reshape(two_dims)
+                .sum(reduction_axis)
+                .template cast<T>();
+      }
     }
   }
 
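The gradient side is the mirror image of the forward broadcast: summing output_backprop over axes {0, 2, 3} (N, H, W) leaves one value per channel, and, as AccumulatorType<Eigen::half> suggests, half inputs are accumulated in float and cast back at the end. A self-contained sketch of that reduction (not from the commit):

```cpp
#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  const int batch = 2, channel = 3, height = 4, width = 5;
  Eigen::Tensor<Eigen::half, 4> output_backprop(batch, channel, height, width);
  output_backprop.setConstant(Eigen::half(1.f));

  // Sum over axes {0, 2, 3} (N, H, W), keeping only the channel axis,
  // mirroring the reduction_axes used in the FORMAT_NCHW branch above.
  // As with AccumulatorType<Eigen::half>, accumulate in float for precision.
  Eigen::array<int, 3> reduction_axes = {0, 2, 3};
  Eigen::Tensor<Eigen::half, 1> bias_grad =
      output_backprop.cast<float>().sum(reduction_axes).cast<Eigen::half>();

  // Each channel accumulates batch * height * width ones.
  std::cout << static_cast<float>(bias_grad(0)) << "\n";  // prints 40
}
```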