Add FakeQuant ops and kernels for use with quantized training.

Change: 137081182
A. Unique TensorFlower 2016-10-24 13:35:38 -08:00 committed by TensorFlower Gardener
parent 4a465522c1
commit 9fb15ea28b
8 changed files with 2061 additions and 1 deletion


@@ -521,6 +521,7 @@ cc_library(
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:ctc_ops",
"//tensorflow/core/kernels:data_flow",
"//tensorflow/core/kernels:fake_quant_ops",
"//tensorflow/core/kernels:function_ops",
"//tensorflow/core/kernels:image",
"//tensorflow/core/kernels:io",


@@ -563,6 +563,24 @@ tf_cc_test(
],
)
tf_cc_test(
name = "fake_quant_ops_test",
size = "small",
srcs = ["fake_quant_ops_test.cc"],
deps = [
":fake_quant_ops",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_cc_test(
name = "fused_batch_norm_op_test",
size = "small",
@@ -1710,6 +1728,22 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "fake_quant_ops",
srcs = ["fake_quant_ops.cc"],
hdrs = ["fake_quant_ops_functor.h"],
gpu_srcs = [
"fake_quant_ops_gpu.cu.cc",
"fake_quant_ops_functor.h",
],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//third_party/eigen3",
],
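# alwayslink ensures the kernel registrations (static initializers) are
# linked in even though nothing references their symbols directly.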
alwayslink = 1,
)
tf_kernel_library(
name = "fused_batch_norm_util",
gpu_srcs = [


@@ -0,0 +1,580 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#ifdef GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA
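// Defining FAKE_QUANT_NO_DEBUG compiles out the coefficient-wise min/max
// sanity checks below, together with the persistent tensor used to evaluate
// them on the device.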
#define FAKE_QUANT_NO_DEBUG
#include "tensorflow/core/kernels/fake_quant_ops_functor.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/protobuf.h"
using tensorflow::BinaryElementWiseOp;
using tensorflow::DEVICE_CPU;
#if GOOGLE_CUDA
using tensorflow::DEVICE_GPU;
#endif
using tensorflow::DT_BOOL;
using tensorflow::OpKernel;
using tensorflow::OpKernelConstruction;
using tensorflow::OpKernelContext;
using tensorflow::PersistentTensor;
using tensorflow::Tensor;
using tensorflow::TensorShape;
using tensorflow::TTypes; // NOLINT This is needed in CUDA mode, do not remove.
using tensorflow::UnaryElementWiseOp;
using tensorflow::errors::InvalidArgument;
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
// -----------------------------------------------------------------------------
// Implementation of FakeQuantWithMinMaxArgsOp, see its documentation in
// core/ops/array_ops.cc.
template <typename Device>
class FakeQuantWithMinMaxArgsOp
: public UnaryElementWiseOp<float, FakeQuantWithMinMaxArgsOp<Device>> {
public:
typedef UnaryElementWiseOp<float, FakeQuantWithMinMaxArgsOp<Device>> Base;
explicit FakeQuantWithMinMaxArgsOp(OpKernelConstruction* context)
: Base::UnaryElementWiseOp(context) {
OP_REQUIRES_OK(context, context->GetAttr("min", &min_));
OP_REQUIRES_OK(context, context->GetAttr("max", &max_));
OP_REQUIRES(context, min_ < max_,
InvalidArgument("min has to be smaller than max, was: ", min_,
" >= ", max_));
}
void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
FakeQuantWithMinMaxArgsFunctor<Device> functor;
functor(context->eigen_device<Device>(), input.flat<float>(), min_, max_,
output->flat<float>());
}
private:
float min_;
float max_;
};
// Implementation of FakeQuantWithMinMaxArgsGradientOp, see its documentation in
// core/ops/array_ops.cc.
template <typename Device>
class FakeQuantWithMinMaxArgsGradientOp
: public BinaryElementWiseOp<float,
FakeQuantWithMinMaxArgsGradientOp<Device>> {
public:
typedef BinaryElementWiseOp<float, FakeQuantWithMinMaxArgsGradientOp<Device>>
Base;
explicit FakeQuantWithMinMaxArgsGradientOp(OpKernelConstruction* context)
: Base::BinaryElementWiseOp(context) {
OP_REQUIRES_OK(context, context->GetAttr("min", &min_));
OP_REQUIRES_OK(context, context->GetAttr("max", &max_));
OP_REQUIRES(context, min_ < max_,
InvalidArgument("min has to be smaller than max, was: ", min_,
" >= ", max_));
}
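// BinaryElementWiseOp instantiates Operate once per input rank (NDIMS); the
// work here is rank-independent, so it is forwarded to a single non-template
// implementation to avoid duplicating the code for every rank.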
template <int NDIMS>
void Operate(OpKernelContext* context, const Tensor& gradient,
const Tensor& input, Tensor* output) {
OperateNoTemplate(context, gradient, input, output);
}
void OperateNoTemplate(OpKernelContext* context, const Tensor& gradient,
const Tensor& input, Tensor* output) {
OP_REQUIRES(context, input.IsSameSize(gradient),
InvalidArgument("gradient and input must be the same size"));
FakeQuantWithMinMaxArgsGradientFunctor<Device> functor;
functor(context->eigen_device<Device>(), gradient.flat<float>(),
input.flat<float>(), min_, max_, output->flat<float>());
}
private:
float min_;
float max_;
};
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxArgs").Device(DEVICE_CPU),
FakeQuantWithMinMaxArgsOp<CPUDevice>);
REGISTER_KERNEL_BUILDER(
Name("FakeQuantWithMinMaxArgsGradient").Device(DEVICE_CPU),
FakeQuantWithMinMaxArgsGradientOp<CPUDevice>);
#if GOOGLE_CUDA
typedef Eigen::GpuDevice GPUDevice;
// Forward declarations for functor specializations for GPU.
template <>
void FakeQuantWithMinMaxArgsFunctor<GPUDevice>::operator()(
const GPUDevice& d,
typename TTypes<float>::ConstFlat inputs,
const float min, const float max,
typename TTypes<float>::Flat outputs);
extern template struct FakeQuantWithMinMaxArgsFunctor<GPUDevice>;
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxArgs").Device(DEVICE_GPU),
FakeQuantWithMinMaxArgsOp<GPUDevice>);
template <>
void FakeQuantWithMinMaxArgsGradientFunctor<GPUDevice>::operator()(
const GPUDevice& d,
typename TTypes<float>::ConstFlat gradients,
typename TTypes<float>::ConstFlat inputs,
const float min, const float max,
typename TTypes<float>::Flat backprops);
REGISTER_KERNEL_BUILDER(
Name("FakeQuantWithMinMaxArgsGradient").Device(DEVICE_GPU),
FakeQuantWithMinMaxArgsGradientOp<GPUDevice>);
#endif // GOOGLE_CUDA
// -----------------------------------------------------------------------------
// Implementation of FakeQuantWithMinMaxVarsOp, see its documentation in
// core/ops/array_ops.cc.
template <typename Device>
class FakeQuantWithMinMaxVarsOp : public OpKernel {
public:
explicit FakeQuantWithMinMaxVarsOp(OpKernelConstruction* context)
: OpKernel::OpKernel(context) {
#ifndef FAKE_QUANT_NO_DEBUG
OP_REQUIRES_OK(context,
context->allocate_persistent(DT_BOOL, {},
&check_min_max_handle_,
nullptr));
#endif
}
void Compute(OpKernelContext* context) override {
CHECK_EQ(3, context->num_inputs());
const Tensor& input = context->input(0);
const Tensor& min = context->input(1);
const Tensor& max = context->input(2);
#ifndef FAKE_QUANT_NO_DEBUG
Tensor* check_min_max = check_min_max_handle_.AccessTensor(context);
#endif
Tensor* output;
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &output));
FakeQuantWithMinMaxVarsFunctor<Device> functor;
functor(context->eigen_device<Device>(), input.flat<float>(),
min.scalar<float>(), max.scalar<float>(),
#ifndef FAKE_QUANT_NO_DEBUG
check_min_max->scalar<bool>(),
#endif
output->flat<float>());
}
private:
#ifndef FAKE_QUANT_NO_DEBUG
PersistentTensor check_min_max_handle_;
#endif
};
// Implementation of FakeQuantWithMinMaxVarsGradientOp, see its documentation in
// core/ops/array_ops.cc.
template <typename Device>
class FakeQuantWithMinMaxVarsGradientOp : public OpKernel {
public:
explicit FakeQuantWithMinMaxVarsGradientOp(OpKernelConstruction* context)
: OpKernel::OpKernel(context) {
#ifndef FAKE_QUANT_NO_DEBUG
OP_REQUIRES_OK(context,
context->allocate_persistent(DT_BOOL, {},
&check_min_max_handle_,
nullptr));
#endif
}
void Compute(OpKernelContext* context) override {
CHECK_EQ(4, context->num_inputs());
const Tensor& gradient = context->input(0);
const Tensor& input = context->input(1);
OP_REQUIRES(context, input.IsSameSize(gradient),
InvalidArgument("gradient and input must be the same size"));
const Tensor& min = context->input(2);
const Tensor& max = context->input(3);
#ifndef FAKE_QUANT_NO_DEBUG
Tensor* check_min_max = check_min_max_handle_.AccessTensor(context);
#endif
Tensor* grad_wrt_input;
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &grad_wrt_input));
TensorShape scalar_shape;
Tensor* grad_wrt_min;
OP_REQUIRES_OK(context,
context->allocate_output(1, scalar_shape, &grad_wrt_min));
Tensor* grad_wrt_max;
OP_REQUIRES_OK(context,
context->allocate_output(2, scalar_shape, &grad_wrt_max));
FakeQuantWithMinMaxVarsGradientFunctor<Device> functor;
functor(context->eigen_device<Device>(), gradient.flat<float>(),
input.flat<float>(), min.scalar<float>(), max.scalar<float>(),
#ifndef FAKE_QUANT_NO_DEBUG
check_min_max->scalar<bool>(),
#endif
grad_wrt_input->flat<float>(), grad_wrt_min->scalar<float>(),
grad_wrt_max->scalar<float>());
}
private:
#ifndef FAKE_QUANT_NO_DEBUG
PersistentTensor check_min_max_handle_;
#endif
};
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVars").Device(DEVICE_CPU),
FakeQuantWithMinMaxVarsOp<CPUDevice>);
REGISTER_KERNEL_BUILDER(
Name("FakeQuantWithMinMaxVarsGradient").Device(DEVICE_CPU),
FakeQuantWithMinMaxVarsGradientOp<CPUDevice>);
#if GOOGLE_CUDA
template <>
void FakeQuantWithMinMaxVarsFunctor<GPUDevice>::operator()(
const GPUDevice& d,
typename TTypes<float>::ConstFlat inputs,
typename TTypes<float>::ConstScalar min,
typename TTypes<float>::ConstScalar max,
#ifndef FAKE_QUANT_NO_DEBUG
typename TTypes<bool>::Scalar check_min_max,
#endif
typename TTypes<float>::Flat output);
extern template struct FakeQuantWithMinMaxVarsFunctor<GPUDevice>;
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVars")
.Device(DEVICE_GPU)
.HostMemory("min")
.HostMemory("max"),
FakeQuantWithMinMaxVarsOp<GPUDevice>);
template <>
void FakeQuantWithMinMaxVarsGradientFunctor<GPUDevice>::operator()(
const GPUDevice& d,
typename TTypes<float>::ConstFlat gradients,
typename TTypes<float>::ConstFlat inputs,
typename TTypes<float>::ConstScalar min,
typename TTypes<float>::ConstScalar max,
#ifndef FAKE_QUANT_NO_DEBUG
typename TTypes<bool>::Scalar check_min_max,
#endif
typename TTypes<float>::Flat backprops_wrt_input,
typename TTypes<float>::Scalar backprop_wrt_min,
typename TTypes<float>::Scalar backprop_wrt_max);
extern template struct FakeQuantWithMinMaxVarsGradientFunctor<GPUDevice>;
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsGradient")
.Device(DEVICE_GPU)
.HostMemory("min")
.HostMemory("max"),
FakeQuantWithMinMaxVarsGradientOp<GPUDevice>);
#endif // GOOGLE_CUDA
// -----------------------------------------------------------------------------
// Implementation of FakeQuantWithMinMaxVarsPerChannelOp, see its documentation
// in core/ops/array_ops.cc.
template <typename Device>
class FakeQuantWithMinMaxVarsPerChannelOp : public OpKernel {
public:
explicit FakeQuantWithMinMaxVarsPerChannelOp(OpKernelConstruction* context)
: OpKernel::OpKernel(context) {
#ifndef FAKE_QUANT_NO_DEBUG
OP_REQUIRES_OK(context,
context->allocate_persistent(DT_BOOL, {},
&check_min_max_handle_,
nullptr));
#endif
}
void Compute(OpKernelContext* context) override {
CHECK_EQ(3, context->num_inputs());
const Tensor& input = context->input(0);
const int depth = input.dim_size(input.dims() - 1); // last dimension size.
const Tensor& min = context->input(1);
OP_REQUIRES(context, min.dim_size(0) == depth,
InvalidArgument("min has incorrect size, expected ", depth,
" was ", min.dim_size(0)));
const Tensor& max = context->input(2);
OP_REQUIRES(context, max.dim_size(0) == depth,
InvalidArgument("max has incorrect size, expected ", depth,
" was ", max.dim_size(0)));
#ifndef FAKE_QUANT_NO_DEBUG
Tensor* check_min_max = check_min_max_handle_.AccessTensor(context);
#endif
Tensor* output;
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &output));
switch (input.dims()) {
case 4: {
FakeQuant4WithMinMaxVarsPerChannelFunctor<Device> functor;
functor(context->eigen_device<Device>(), input.dim_size(0),
input.dim_size(1), input.dim_size(2), input.dim_size(3),
input.flat<float>(), min.vec<float>(), max.vec<float>(),
#ifndef FAKE_QUANT_NO_DEBUG
check_min_max->scalar<bool>(),
#endif
output->flat<float>());
break;
}
case 2: {
FakeQuant2WithMinMaxVarsPerChannelFunctor<Device> functor;
functor(context->eigen_device<Device>(),
input.dim_size(0), input.dim_size(1),
input.flat<float>(), min.vec<float>(), max.vec<float>(),
#ifndef FAKE_QUANT_NO_DEBUG
check_min_max->scalar<bool>(),
#endif
output->flat<float>());
break;
}
case 1: {
FakeQuant1WithMinMaxVarsPerChannelFunctor<Device> functor;
functor(context->eigen_device<Device>(),
input.vec<float>(), min.vec<float>(), max.vec<float>(),
#ifndef FAKE_QUANT_NO_DEBUG
check_min_max->scalar<bool>(),
#endif
output->vec<float>());
break;
}
default:
context->SetStatus(InvalidArgument("Only inputs of dimensions 1, 2 or "
"4 supported, was: ", input.dims()));
break;
}
}
private:
#ifndef FAKE_QUANT_NO_DEBUG
PersistentTensor check_min_max_handle_;
#endif
};
// Implementation of FakeQuantWithMinMaxVarsPerChannelGradientOp, see its
// documentation in core/ops/array_ops.cc.
template <typename Device>
class FakeQuantWithMinMaxVarsPerChannelGradientOp : public OpKernel {
public:
explicit FakeQuantWithMinMaxVarsPerChannelGradientOp(
OpKernelConstruction* context) : OpKernel::OpKernel(context) {
#ifndef FAKE_QUANT_NO_DEBUG
OP_REQUIRES_OK(context,
context->allocate_persistent(DT_BOOL, {},
&check_min_max_handle_,
nullptr));
#endif
}
void Compute(OpKernelContext* context) override {
CHECK_EQ(4, context->num_inputs());
const Tensor& gradient = context->input(0);
const Tensor& input = context->input(1);
OP_REQUIRES(context, input.IsSameSize(gradient),
InvalidArgument("gradient and input must be the same size"));
const int depth = input.dim_size(input.dims() - 1); // last dimension size.
const Tensor& min = context->input(2);
OP_REQUIRES(context, min.dim_size(0) == depth,
InvalidArgument("min has incorrect size, expected ", depth,
" was ", min.dim_size(0)));
const Tensor& max = context->input(3);
OP_REQUIRES(context, max.dim_size(0) == depth,
InvalidArgument("max has incorrect size, expected ", depth,
" was ", max.dim_size(0)));
#ifndef FAKE_QUANT_NO_DEBUG
Tensor* check_min_max = check_min_max_handle_.AccessTensor(context);
#endif
Tensor* grad_wrt_input;
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &grad_wrt_input));
TensorShape min_max_shape({depth});
Tensor* grad_wrt_min;
OP_REQUIRES_OK(context,
context->allocate_output(1, min_max_shape, &grad_wrt_min));
Tensor* grad_wrt_max;
OP_REQUIRES_OK(context,
context->allocate_output(2, min_max_shape, &grad_wrt_max));
switch (input.dims()) {
case 4: {
FakeQuant4WithMinMaxVarsPerChannelGradientFunctor<Device> functor;
functor(context->eigen_device<Device>(), input.dim_size(0),
input.dim_size(1), input.dim_size(2), input.dim_size(3),
gradient.flat<float>(), input.flat<float>(),
min.vec<float>(), max.vec<float>(),
#ifndef FAKE_QUANT_NO_DEBUG
check_min_max->scalar<bool>(),
#endif
grad_wrt_input->flat<float>(),
grad_wrt_min->vec<float>(), grad_wrt_max->vec<float>());
break;
}
case 2: {
FakeQuant2WithMinMaxVarsPerChannelGradientFunctor<Device> functor;
functor(context->eigen_device<Device>(),
input.dim_size(0), input.dim_size(1),
gradient.flat<float>(), input.flat<float>(),
min.vec<float>(), max.vec<float>(),
#ifndef FAKE_QUANT_NO_DEBUG
check_min_max->scalar<bool>(),
#endif
grad_wrt_input->flat<float>(),
grad_wrt_min->vec<float>(), grad_wrt_max->vec<float>());
break;
}
case 1: {
FakeQuant1WithMinMaxVarsPerChannelGradientFunctor<Device> functor;
functor(context->eigen_device<Device>(),
gradient.vec<float>(), input.vec<float>(),
min.vec<float>(), max.vec<float>(),
#ifndef FAKE_QUANT_NO_DEBUG
check_min_max->scalar<bool>(),
#endif
grad_wrt_input->vec<float>(),
grad_wrt_min->vec<float>(), grad_wrt_max->vec<float>());
break;
}
default:
context->SetStatus(InvalidArgument("Only inputs of dimensions 1, 2 or "
"4 supported, was: ", input.dims()));
break;
}
}
private:
#ifndef FAKE_QUANT_NO_DEBUG
PersistentTensor check_min_max_handle_;
#endif
};
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsPerChannel")
.Device(DEVICE_CPU),
FakeQuantWithMinMaxVarsPerChannelOp<CPUDevice>);
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsPerChannelGradient")
.Device(DEVICE_CPU),
FakeQuantWithMinMaxVarsPerChannelGradientOp<CPUDevice>);
#if GOOGLE_CUDA
template <>
void FakeQuant1WithMinMaxVarsPerChannelFunctor<GPUDevice>::operator()(
const GPUDevice& d,
typename TTypes<float>::ConstVec inputs,
typename TTypes<float>::ConstVec min,
typename TTypes<float>::ConstVec max,
#ifndef FAKE_QUANT_NO_DEBUG
typename TTypes<bool>::Scalar check_min_max,
#endif
typename TTypes<float>::Vec outputs);
extern template struct FakeQuant1WithMinMaxVarsPerChannelFunctor<GPUDevice>;
template <>
void FakeQuant2WithMinMaxVarsPerChannelFunctor<GPUDevice>::operator()(
const GPUDevice& d, const Index batch_size, const Index depth,
typename TTypes<float>::ConstFlat inputs,
typename TTypes<float>::ConstFlat min,
typename TTypes<float>::ConstFlat max,
#ifndef FAKE_QUANT_NO_DEBUG
typename TTypes<bool>::Scalar check_min_max,
#endif
typename TTypes<float>::Flat outputs);
extern template struct FakeQuant2WithMinMaxVarsPerChannelFunctor<GPUDevice>;
template <>
void FakeQuant4WithMinMaxVarsPerChannelFunctor<GPUDevice>::operator()(
const GPUDevice& d, const Index batch_size, const Index height,
const Index width, const Index depth,
typename TTypes<float>::ConstFlat inputs,
typename TTypes<float>::ConstFlat min,
typename TTypes<float>::ConstFlat max,
#ifndef FAKE_QUANT_NO_DEBUG
typename TTypes<bool>::Scalar check_min_max,
#endif
typename TTypes<float>::Flat outputs);
extern template struct FakeQuant4WithMinMaxVarsPerChannelFunctor<GPUDevice>;
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsPerChannel")
.Device(DEVICE_GPU)
.HostMemory("min")
.HostMemory("max"),
FakeQuantWithMinMaxVarsPerChannelOp<GPUDevice>);
template <>
void FakeQuant1WithMinMaxVarsPerChannelGradientFunctor<GPUDevice>::operator()(
const GPUDevice& d,
typename TTypes<float>::ConstVec gradients,
typename TTypes<float>::ConstVec inputs,
typename TTypes<float>::ConstVec min,
typename TTypes<float>::ConstVec max,
#ifndef FAKE_QUANT_NO_DEBUG
typename TTypes<bool>::Scalar check_min_max,
#endif
typename TTypes<float>::Vec backprops_wrt_input,
typename TTypes<float>::Vec backprop_wrt_min,
typename TTypes<float>::Vec backprop_wrt_max);
extern template struct
FakeQuant1WithMinMaxVarsPerChannelGradientFunctor<GPUDevice>;
template <>
void FakeQuant2WithMinMaxVarsPerChannelGradientFunctor<GPUDevice>::operator()(
const GPUDevice& d, const Index batch_size, const Index depth,
typename TTypes<float>::ConstFlat gradients,
typename TTypes<float>::ConstFlat inputs,
typename TTypes<float>::ConstVec min,
typename TTypes<float>::ConstVec max,
#ifndef FAKE_QUANT_NO_DEBUG
typename TTypes<bool>::Scalar check_min_max,
#endif
typename TTypes<float>::Flat backprops_wrt_input,
typename TTypes<float>::Vec backprop_wrt_min,
typename TTypes<float>::Vec backprop_wrt_max);
extern template struct
FakeQuant2WithMinMaxVarsPerChannelGradientFunctor<GPUDevice>;
template <>
void FakeQuant4WithMinMaxVarsPerChannelGradientFunctor<GPUDevice>::operator()(
const GPUDevice& d, const Index batch_size, const Index height,
const Index width, const Index depth,
typename TTypes<float>::ConstFlat gradients,
typename TTypes<float>::ConstFlat inputs,
typename TTypes<float>::ConstVec min,
typename TTypes<float>::ConstVec max,
#ifndef FAKE_QUANT_NO_DEBUG
typename TTypes<bool>::Scalar check_min_max,
#endif
typename TTypes<float>::Flat backprops_wrt_input,
typename TTypes<float>::Vec backprop_wrt_min,
typename TTypes<float>::Vec backprop_wrt_max);
extern template struct
FakeQuant4WithMinMaxVarsPerChannelGradientFunctor<GPUDevice>;
REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsPerChannelGradient")
.Device(DEVICE_GPU)
.HostMemory("min")
.HostMemory("max"),
FakeQuantWithMinMaxVarsPerChannelGradientOp<GPUDevice>);
#endif // GOOGLE_CUDA
} // namespace tensorflow


@@ -0,0 +1,434 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_FAKE_QUANT_FUNCTOR_H_
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_FAKE_QUANT_FUNCTOR_H_
#include <tuple>
#define EIGEN_STACK_ALLOCATION_LIMIT 0
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
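// 255 steps correspond to the [0, 255] range of unsigned 8-bit quantization.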
static constexpr int kSteps = 255;
static constexpr float kStepsFloat = static_cast<float>(kSteps);
// The zero point is nudged so that real zero maps exactly to an integer,
// which is required for e.g. zero-padding in convolutional layers.
// Returns (nudged_min, nudged_max, nudged_scale).
template <typename Device>
std::tuple<float, float, float> Nudge(const float min, const float max) {
const float scale = (max - min) / (kStepsFloat - 0.0f);
const float zero_point_from_min = 0.0f - min / scale;
const uint8 nudged_zero_point = [zero_point_from_min] {
if (zero_point_from_min < 0.0f) {
return static_cast<uint8>(0);
} else if (zero_point_from_min > kStepsFloat) {
return static_cast<uint8>(kSteps);
} else {
return static_cast<uint8>(std::round(zero_point_from_min));
}
}();
const float nudged_min = (0.0f - nudged_zero_point) * scale;
const float nudged_max = (kStepsFloat - nudged_zero_point) * scale;
return std::make_tuple(nudged_min, nudged_max, scale);
}
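// Illustrative example (mirroring the tests, not part of the op logic): for
// min = -0.1f and max = 63.65f the scale is (63.65 + 0.1) / 255 = 0.25, the
// zero point from min is 0.1 / 0.25 = 0.4, which rounds to a nudged zero
// point of 0, giving the nudged range [0.0, 63.75].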
template<typename T> using ConstScalar =
typename tensorflow::TTypes<T>::ConstScalar;
template<typename T> using Scalar = typename tensorflow::TTypes<T>::Scalar;
template<typename T> using ConstVec = typename tensorflow::TTypes<T>::ConstVec;
template<typename T> using Vec = typename tensorflow::TTypes<T>::Vec;
template<typename T> using ConstFlat =
typename tensorflow::TTypes<T>::ConstFlat;
template<typename T> using Flat = typename tensorflow::TTypes<T>::Flat;
// Functor called by FakeQuantWithMinMaxArgsOp to do the work. Compiles both
// for CPU and GPU.
template <typename Device>
struct FakeQuantWithMinMaxArgsFunctor {
void operator()(const Device& d, ConstFlat<float> inputs,
const float min, const float max, Flat<float> outputs) {
eigen_assert(min <= 0.0f && "min should be <= 0.0");
eigen_assert(max >= 0.0f && "max should be >= 0.0");
eigen_assert(min < max && "min should be < max");
float nudged_min, nudged_max, nudged_scale;
std::tie(nudged_min, nudged_max, nudged_scale) = Nudge<Device>(min, max);
const float inv_nudged_scale = 1.0f / nudged_scale;
auto clamped = inputs.cwiseMin(nudged_max).cwiseMax(nudged_min);
auto clamped_shifted = clamped - nudged_min;
outputs.device(d) = (clamped_shifted * inv_nudged_scale + 0.5f).floor() *
nudged_scale + nudged_min;
}
};
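// Worked example (illustrative): with the nudged range [0.0, 63.75] and
// scale 0.25 from above, an input of 0.1 clamps to 0.1 and maps to
// floor(0.1 / 0.25 + 0.5) * 0.25 + 0.0 = 0.0, snapping to the nearest
// quantization step.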
// Functor called by FakeQuantWithMinMaxArgsGradientOp to do the work. Compiles
// both for CPU and GPU.
template <typename Device>
struct FakeQuantWithMinMaxArgsGradientFunctor {
void operator()(const Device& d, ConstFlat<float> gradients,
ConstFlat<float> inputs, const float min, const float max,
Flat<float> backprops) {
eigen_assert(min <= 0.0f && "min should be <= 0.0");
eigen_assert(max >= 0.0f && "max should be >= 0.0");
eigen_assert(min < max && "min should be < max");
float nudged_min, nudged_max, nudged_scale;
std::tie(nudged_min, nudged_max, nudged_scale) = Nudge<Device>(min, max);
auto between_nudged_min_max = (inputs >= nudged_min && inputs <= nudged_max)
.select(inputs.constant(1.0f), inputs.constant(0.0f));
backprops.device(d) = gradients * between_nudged_min_max;
}
};
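// This is the straight-through estimator: the gradient passes through
// unchanged wherever the input fell inside the nudged range and is zeroed
// outside it, mirroring the clamping done in the forward pass.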
// Functor called by FakeQuantWithMinMaxVarsOp to do the work. Compiles both
// for CPU and GPU.
template <typename Device>
struct FakeQuantWithMinMaxVarsFunctor {
void operator()(const Device& d, ConstFlat<float> inputs,
ConstScalar<float> min, ConstScalar<float> max,
#ifndef FAKE_QUANT_NO_DEBUG
Scalar<bool> check_min_max,
#endif
Flat<float> outputs) {
#ifndef FAKE_QUANT_NO_DEBUG
check_min_max.device(d) = (min <= 0.0f).all();
eigen_assert(check_min_max() && "min should be <= 0.0 coeff-wise");
check_min_max.device(d) = (max >= 0.0f).all();
eigen_assert(check_min_max() && "max should be >= 0.0 coeff-wise");
check_min_max.device(d) = (min < max).all();
eigen_assert(check_min_max() && "min should be < max coeff-wise");
#endif
float nudged_min, nudged_max, nudged_scale;
std::tie(nudged_min, nudged_max, nudged_scale) =
Nudge<Device>(min(), max());
const auto nudged_scale_repl = inputs.constant(nudged_scale);
const auto clamped = inputs.cwiseMin(nudged_max).cwiseMax(nudged_min);
const auto clamped_shifted = clamped - nudged_min;
outputs.device(d) = (clamped_shifted / nudged_scale_repl + 0.5f).floor() *
nudged_scale_repl + nudged_min;
}
};
// Functor called by FakeQuantWithMinMaxVarsGradientOp to do the work. Compiles
// both for CPU and GPU.
template <typename Device>
struct FakeQuantWithMinMaxVarsGradientFunctor {
void operator()(const Device& d,
ConstFlat<float> gradients, ConstFlat<float> inputs,
ConstScalar<float> min, ConstScalar<float> max,
#ifndef FAKE_QUANT_NO_DEBUG
Scalar<bool> check_min_max,
#endif
Flat<float> backprops_wrt_input,
Scalar<float> backprop_wrt_min,
Scalar<float> backprop_wrt_max) {
#ifndef FAKE_QUANT_NO_DEBUG
check_min_max.device(d) = (min <= 0.0f).all();
eigen_assert(check_min_max() && "min should be <= 0.0 coeff-wise");
check_min_max.device(d) = (max >= 0.0f).all();
eigen_assert(check_min_max() && "max should be >= 0.0 coeff-wise");
check_min_max.device(d) = (min < max).all();
eigen_assert(check_min_max() && "min should be < max coeff-wise");
#endif
float nudged_min, nudged_max, nudged_scale;
std::tie(nudged_min, nudged_max, nudged_scale) =
Nudge<Device>(min(), max());
const auto between_min_max = (inputs >= nudged_min && inputs <= nudged_max)
.select(inputs.constant(1.0f), inputs.constant(0.0f));
backprops_wrt_input.device(d) = gradients * between_min_max;
const auto below_min = (inputs < nudged_min)
.select(inputs.constant(1.0f), inputs.constant(0.0f));
backprop_wrt_min.device(d) = (gradients * below_min).sum();
const auto above_max = (inputs > nudged_max)
.select(inputs.constant(1.0f), inputs.constant(0.0f));
backprop_wrt_max.device(d) = (gradients * above_max).sum();
}
};
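// Here min and max are trainable: the backprop w.r.t. min sums the gradient
// over all inputs clamped from below, and the backprop w.r.t. max over all
// inputs clamped from above, since those inputs saturate at the
// corresponding endpoint in the forward pass.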
using Index = typename tensorflow::TTypes<float>::ConstTensor::Index;
// Functor called by FakeQuantWithMinMaxVarsPerChannelOp to do the work.
// Compiles both for CPU and GPU.
//
// Already verified: inputs, outputs, min, max are of shape [d].
template <typename Device>
struct FakeQuant1WithMinMaxVarsPerChannelFunctor {
void operator()(const Device& d, ConstVec<float> inputs,
ConstVec<float> min, ConstVec<float> max,
#ifndef FAKE_QUANT_NO_DEBUG
Scalar<bool> check_min_max,
#endif
Vec<float> outputs) {
#ifndef FAKE_QUANT_NO_DEBUG
check_min_max.device(d) = (min <= 0.0f).all();
eigen_assert(check_min_max() && "min should be <= 0.0 coeff-wise");
check_min_max.device(d) = (max >= 0.0f).all();
eigen_assert(check_min_max() && "max should be >= 0.0 coeff-wise");
check_min_max.device(d) = (min < max).all();
eigen_assert(check_min_max() && "min should be < max coeff-wise");
#endif
for (Index i = 0; i < min.size(); ++i) {
float nudged_min, nudged_max, nudged_scale;
std::tie(nudged_min, nudged_max, nudged_scale) =
Nudge<Device>(min(i), max(i));
const float clamped =
std::max(std::min(inputs(i), nudged_max), nudged_min);
const float clamped_shifted = clamped - nudged_min;
outputs(i) = std::round(clamped_shifted / nudged_scale) * nudged_scale +
nudged_min;
}
}
};
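// The rank-1 case runs a plain scalar loop: each channel holds a single
// value, so the nudged range is recomputed per index and std::round stands
// in for the floor(x + 0.5f) pattern used by the Eigen expressions above.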
// Already verified: inputs, outputs are of shape [b, d], min, max are of shape
// [d].
template <typename Device>
struct FakeQuant2WithMinMaxVarsPerChannelFunctor {
void operator()(const Device& d, const Index batch_size, const Index depth,
ConstFlat<float> inputs,
ConstVec<float> min, ConstVec<float> max,
#ifndef FAKE_QUANT_NO_DEBUG
Scalar<bool> check_min_max,
#endif
Flat<float> outputs) {
#ifndef FAKE_QUANT_NO_DEBUG
check_min_max.device(d) = (min <= 0.0f).all();
eigen_assert(check_min_max() && "min should be <= 0.0 coeff-wise");
check_min_max.device(d) = (max >= 0.0f).all();
eigen_assert(check_min_max() && "max should be >= 0.0 coeff-wise");
check_min_max.device(d) = (min < max).all();
eigen_assert(check_min_max() && "min should be < max coeff-wise");
#endif
Eigen::DSizes<Index, 2> restored(batch_size, depth);
const auto inputs_restored = inputs.reshape(restored);
for (Index i = 0; i < min.size(); ++i) {
float nudged_min, nudged_max, nudged_scale;
std::tie(nudged_min, nudged_max, nudged_scale) =
Nudge<Device>(min(i), max(i));
const auto clamped = inputs_restored.chip<1>(i)
.cwiseMin(nudged_max).cwiseMax(nudged_min);
const auto clamped_shifted = clamped - nudged_min;
outputs.reshape(restored).chip<1>(i).device(d) =
(clamped_shifted / nudged_scale + 0.5f).floor() * nudged_scale +
nudged_min;
}
}
};
// Already verified: inputs, outputs are of shape [b, h, w, d], min, max are
// of shape [d].
template <typename Device>
struct FakeQuant4WithMinMaxVarsPerChannelFunctor {
void operator()(const Device& d, const Index batch_size, const Index height,
const Index width, const Index depth,
ConstFlat<float> inputs,
ConstVec<float> min, ConstVec<float> max,
#ifndef FAKE_QUANT_NO_DEBUG
Scalar<bool> check_min_max,
#endif
Flat<float> outputs) {
#ifndef FAKE_QUANT_NO_DEBUG
check_min_max.device(d) = (min <= 0.0f).all();
eigen_assert(check_min_max() && "min should be <= 0.0 coeff-wise");
check_min_max.device(d) = (max >= 0.0f).all();
eigen_assert(check_min_max() && "max should be >= 0.0 coeff-wise");
check_min_max.device(d) = (min < max).all();
eigen_assert(check_min_max() && "min should be < max coeff-wise");
#endif
Eigen::DSizes<Index, 4> restored(batch_size, height, width, depth);
const auto inputs_restored = inputs.reshape(restored);
for (Index i = 0; i < min.size(); ++i) {
float nudged_min, nudged_max, nudged_scale;
std::tie(nudged_min, nudged_max, nudged_scale) =
Nudge<Device>(min(i), max(i));
const auto clamped = inputs_restored.chip<3>(i)
.cwiseMin(nudged_max).cwiseMax(nudged_min);
const auto clamped_shifted = clamped - nudged_min;
outputs.reshape(restored).chip<3>(i).device(d) =
(clamped_shifted / nudged_scale + 0.5f).floor() * nudged_scale +
nudged_min;
}
}
};
// Functor called by FakeQuantWithMinMaxVarsPerChannelGradientOp to do the work.
// Compiles both for CPU and GPU.
//
// Already verified: gradients, inputs, outputs, min, max, backprops_wrt_input,
// backprop_wrt_min, backprop_wrt_max are of shape [d].
template <typename Device>
struct FakeQuant1WithMinMaxVarsPerChannelGradientFunctor {
void operator()(const Device& d,
ConstVec<float> gradients, ConstVec<float> inputs,
ConstVec<float> min, ConstVec<float> max,
#ifndef FAKE_QUANT_NO_DEBUG
Scalar<bool> check_min_max,
#endif
Vec<float> backprops_wrt_input, Vec<float> backprop_wrt_min,
Vec<float> backprop_wrt_max) {
#ifndef FAKE_QUANT_NO_DEBUG
check_min_max.device(d) = (min <= 0.0f).all();
eigen_assert(check_min_max() && "min should be <= 0.0 coeff-wise");
check_min_max.device(d) = (max >= 0.0f).all();
eigen_assert(check_min_max() && "max should be >= 0.0 coeff-wise");
check_min_max.device(d) = (min < max).all();
eigen_assert(check_min_max() && "min should be < max coeff-wise");
#endif
for (Index i = 0; i < min.size(); ++i) {
float nudged_min, nudged_max, nudged_scale;
std::tie(nudged_min, nudged_max, nudged_scale) =
Nudge<Device>(min(i), max(i));
const bool between_min_max =
inputs(i) >= nudged_min && inputs(i) <= nudged_max;
backprops_wrt_input(i) = between_min_max ? gradients(i) : 0.0f;
const bool below_min = inputs(i) < nudged_min;
backprop_wrt_min(i) = below_min ? gradients(i) : 0.0f;
const bool above_max = inputs(i) > nudged_max;
backprop_wrt_max(i) = above_max ? gradients(i) : 0.0f;
}
}
};
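// With a single element per channel, the per-channel sums degenerate to one
// term each: the gradient is routed to input, min or max depending on where
// the input fell relative to its nudged range.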
// Already verified: gradients, inputs, backprops_wrt_input are of shape [b, d],
// min, max, backprop_wrt_min, backprop_wrt_max are of shape [d].
template <typename Device>
struct FakeQuant2WithMinMaxVarsPerChannelGradientFunctor {
void operator()(const Device& d, const Index batch_size, const Index depth,
ConstFlat<float> gradients, ConstFlat<float> inputs,
ConstVec<float> min, ConstVec<float> max,
#ifndef FAKE_QUANT_NO_DEBUG
Scalar<bool> check_min_max,
#endif
Flat<float> backprops_wrt_input,
Vec<float> backprop_wrt_min, Vec<float> backprop_wrt_max) {
#ifndef FAKE_QUANT_NO_DEBUG
check_min_max.device(d) = (min <= 0.0f).all();
eigen_assert(check_min_max() && "min should be <= 0.0 coeff-wise");
check_min_max.device(d) = (max >= 0.0f).all();
eigen_assert(check_min_max() && "max should be >= 0.0 coeff-wise");
check_min_max.device(d) = (min < max).all();
eigen_assert(check_min_max() && "min should be < max coeff-wise");
#endif
Eigen::DSizes<Index, 2> restored(batch_size, depth);
const auto gradients_restored = gradients.reshape(restored);
const auto inputs_restored = inputs.reshape(restored);
for (Index i = 0; i < min.size(); ++i) {
float nudged_min, nudged_max, nudged_scale;
std::tie(nudged_min, nudged_max, nudged_scale) =
Nudge<Device>(min(i), max(i));
const auto gradients_chip = gradients_restored.chip<1>(i);
const auto inputs_chip = inputs_restored.chip<1>(i);
const auto between_min_max =
(inputs_chip >= nudged_min && inputs_chip <= nudged_max)
.select(inputs_chip.constant(1.0f), inputs_chip.constant(0.0f));
backprops_wrt_input.reshape(restored).chip<1>(i).device(d) =
gradients_chip * between_min_max;
const auto below_min = (inputs_chip < nudged_min)
.select(inputs_chip.constant(1.0f), inputs_chip.constant(0.0f));
Eigen::DSizes<Index, 1> reduce(0);
backprop_wrt_min.chip<0>(i).device(d) =
(gradients_chip * below_min).sum(reduce);
const auto above_max = (inputs_chip > nudged_max)
.select(inputs_chip.constant(1.0f), inputs_chip.constant(0.0f));
backprop_wrt_max.chip<0>(i).device(d) =
(gradients_chip * above_max).sum(reduce);
}
}
};
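// In the rank-2 gradient above, reduce(0) sums each channel chip over the
// batch dimension, leaving one backprop value per channel for min and max.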
// Already verified: gradients, inputs, backprops_wrt_input are of shape
// [b, h, w, d], min, max, backprop_wrt_min, backprop_wrt_max are of shape [d].
template <typename Device>
struct FakeQuant4WithMinMaxVarsPerChannelGradientFunctor {
void operator()(const Device& d, const Index batch_size, const Index height,
const Index width, const Index depth,
ConstFlat<float> gradients, ConstFlat<float> inputs,
ConstVec<float> min, ConstVec<float> max,
#ifndef FAKE_QUANT_NO_DEBUG
Scalar<bool> check_min_max,
#endif
Flat<float> backprops_wrt_input,
Vec<float> backprop_wrt_min, Vec<float> backprop_wrt_max) {
#ifndef FAKE_QUANT_NO_DEBUG
check_min_max.device(d) = (min <= 0.0f).all();
eigen_assert(check_min_max() && "min should be <= 0.0 coeff-wise");
check_min_max.device(d) = (max >= 0.0f).all();
eigen_assert(check_min_max() && "max should be >= 0.0 coeff-wise");
check_min_max.device(d) = (min < max).all();
eigen_assert(check_min_max() && "min should be < max coeff-wise");
#endif
Eigen::DSizes<Index, 4> restored(batch_size, height, width, depth);
const auto gradients_restored = gradients.reshape(restored);
const auto inputs_restored = inputs.reshape(restored);
for (Index i = 0; i < min.size(); ++i) {
float nudged_min, nudged_max, nudged_scale;
std::tie(nudged_min, nudged_max, nudged_scale) =
Nudge<Device>(min(i), max(i));
const auto gradients_chip = gradients_restored.chip<3>(i);
const auto inputs_chip = inputs_restored.chip<3>(i);
const auto between_min_max =
(inputs_chip >= nudged_min && inputs_chip <= nudged_max)
.select(inputs_chip.constant(1.0f), inputs_chip.constant(0.0f));
backprops_wrt_input.reshape(restored).chip<3>(i).device(d) =
gradients_chip * between_min_max;
const auto below_min = (inputs_chip < nudged_min)
.select(inputs_chip.constant(1.0f), inputs_chip.constant(0.0f));
Eigen::DSizes<Index, 3> reduce(0, 1, 2);
backprop_wrt_min.chip<0>(i).device(d) =
(gradients_chip * below_min).sum(reduce);
const auto above_max = (inputs_chip > nudged_max)
.select(inputs_chip.constant(1.0f), inputs_chip.constant(0.0f));
backprop_wrt_max.chip<0>(i).device(d) =
(gradients_chip * above_max).sum(reduce);
}
}
};
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_FAKE_QUANT_FUNCTOR_H_


@@ -0,0 +1,41 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#define FAKE_QUANT_NO_DEBUG
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/fake_quant_ops_functor.h"
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
// Just instantiate GPU functor implementations.
template struct FakeQuantWithMinMaxArgsFunctor<GPUDevice>;
template struct FakeQuantWithMinMaxArgsGradientFunctor<GPUDevice>;
template struct FakeQuantWithMinMaxVarsFunctor<GPUDevice>;
template struct FakeQuantWithMinMaxVarsGradientFunctor<GPUDevice>;
template struct FakeQuant1WithMinMaxVarsPerChannelFunctor<GPUDevice>;
template struct FakeQuant2WithMinMaxVarsPerChannelFunctor<GPUDevice>;
template struct FakeQuant4WithMinMaxVarsPerChannelFunctor<GPUDevice>;
template struct FakeQuant1WithMinMaxVarsPerChannelGradientFunctor<GPUDevice>;
template struct FakeQuant2WithMinMaxVarsPerChannelGradientFunctor<GPUDevice>;
template struct FakeQuant4WithMinMaxVarsPerChannelGradientFunctor<GPUDevice>;
} // namespace tensorflow
#endif // GOOGLE_CUDA


@@ -0,0 +1,821 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/kernels/ops_testutil.h"
namespace tensorflow {
using tensorflow::AllocatorAttributes;
using tensorflow::DT_FLOAT;
using tensorflow::NodeDefBuilder;
using tensorflow::OpsTestBase;
using tensorflow::Tensor;
using tensorflow::TensorShape;
using tensorflow::test::ExpectClose;
using tensorflow::test::FillValues;
class QuantOpsTest : public OpsTestBase {
protected:
void AddRandomInput(const TensorShape& shape) {
CHECK_GT(input_types_.size(), inputs_.size())
<< "Adding more inputs than types; perhaps you need to call MakeOp";
Tensor* input = new Tensor(device_->GetAllocator(AllocatorAttributes()),
DT_FLOAT, shape);
input->flat<float>().setRandom();
tensors_.push_back(input);
bool is_ref = IsRefType(input_types_[inputs_.size()]);
if (is_ref) {
CHECK_EQ(RemoveRefType(input_types_[inputs_.size()]), DT_FLOAT);
inputs_.push_back({&lock_for_refs_, input});
} else {
CHECK_EQ(input_types_[inputs_.size()], DT_FLOAT);
inputs_.push_back({nullptr, input});
}
}
};
TEST_F(QuantOpsTest, WithArgsNoNudging) {
// Original quantization range: [-10 + 0 / 4, -10 + 255 / 4], scale: 1/4.
// Original zero point: 40, no nudging necessary.
// Expected quantized values: -10.0, -9.75, ..., 53.75.
TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxArgs")
.Input(FakeInput(DT_FLOAT)) // inputs
.Attr("min", -10.0f)
.Attr("max", 53.75f)
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
// Downstream inputs.
AddInputFromArray<float>(TensorShape({2, 3}),
{-10.1f, -10.0f, -9.9f, -9.75f, 53.75f, 53.8f});
// Tested code.
TF_ASSERT_OK(RunOpKernel());
Tensor* output = GetOutput(0);
Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3}));
FillValues<float>(&expected,
{-10.0f, -10.0f, -10.0f, -9.75f, 53.75f, 53.75f});
ExpectClose(expected, *output);
}
TEST_F(QuantOpsTest, WithArgsNudgedZeroIs0) {
// Original quantization range: [-0.4 / 4 + 0 / 4, -0.4 / 4 + 255 / 4].
// Scale: 1/4, original zero point: 0.4, nudged to 0.
// Nudged range: [0.0; 63.75].
// Expected quantized values: 0.0, 0.25, 0.5, ..., 63.75.
TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxArgs")
.Input(FakeInput(DT_FLOAT)) // inputs
.Attr("min", -0.1f)
.Attr("max", 63.65f)
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
// Downstream inputs.
AddInputFromArray<float>(TensorShape({2, 3}),
{-0.1f, 0.0f, 0.1f, 0.25f, 63.75f, 63.8f});
// Tested code.
TF_ASSERT_OK(RunOpKernel());
Tensor* output = GetOutput(0);
Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3}));
FillValues<float>(&expected, {0.0f, 0.0f, 0.0f, 0.25f, 63.75f, 63.75f});
ExpectClose(expected, *output);
}
TEST_F(QuantOpsTest, WithArgsNudgedZeroIs1) {
// Original quantization range: [-0.5 / 4 + 0 / 4, -0.5 / 4 + 255 / 4].
// Scale: 1/4, original zero point: 0.5, nudged to 1.
// Nudged range: [-0.25; 63.5].
// Expected quantized values: -0.25, 0.0, 0.25, ..., 63.5.
TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxArgs")
.Input(FakeInput(DT_FLOAT)) // inputs
.Attr("min", -0.125f)
.Attr("max", 63.625f)
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
// Downstream inputs.
AddInputFromArray<float>(TensorShape({2, 3}),
{-0.26f, -0.25f, -0.24f, 0.0f, 63.5f, 63.6f});
// Tested code.
TF_ASSERT_OK(RunOpKernel());
Tensor* output = GetOutput(0);
Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3}));
FillValues<float>(&expected, {-0.25f, -0.25f, -0.25f, 0.0f, 63.5f, 63.5f});
ExpectClose(expected, *output);
}
TEST_F(QuantOpsTest, WithArgsNudgedZeroIs255) {
// Original quantization range: [0.4 / 4 - 255 / 4, 0.4 / 4 + 0 / 4].
// Scale: 1/4, original zero point: 254.6, nudged to 255.
// Nudged range: [-63.75; 0.0].
// Expected quantized values: -63.75, -63.5, -63.25, ..., 0.0.
TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxArgs")
.Input(FakeInput(DT_FLOAT)) // inputs
.Attr("min", -63.65f)
.Attr("max", 0.1f)
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
// Downstream inputs.
AddInputFromArray<float>(TensorShape({2, 3}),
{-63.8f, -63.75f, -63.7f, -63.5f, 0.0f, 0.1f});
// Tested code.
TF_ASSERT_OK(RunOpKernel());
Tensor* output = GetOutput(0);
Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3}));
FillValues<float>(&expected, {-63.75f, -63.75f, -63.75f, -63.5f, 0.0f, 0.0f});
ExpectClose(expected, *output);
}
TEST_F(QuantOpsTest, WithArgsGradient) {
// Original quantization range: [-0.5 / 4 + 0 / 4, -0.5 / 4 + 255 / 4].
// Scale: 1/4, original zero point: 0.5, nudged to 1.
// Nudged range: [-0.25; 63.5].
// Expected quantized values: -0.25, 0.0, 0.25, ..., 63.5.
TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxArgsGradient")
.Input(FakeInput(DT_FLOAT)) // gradient
.Input(FakeInput(DT_FLOAT)) // inputs
.Attr("min", -0.125f)
.Attr("max", 63.625f)
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
// Upstream gradients.
AddRandomInput(TensorShape({2, 3}));
// Downstream inputs.
AddInputFromArray<float>(TensorShape({2, 3}),
{-0.26f, -0.25f, -0.24f, 0.0f, 63.5f, 63.6f});
// Tested code.
TF_ASSERT_OK(RunOpKernel());
Tensor* output = GetOutput(0);
auto grad_flat = GetInput(0).flat<float>();
Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3}));
FillValues<float>(&expected,
{0.0f, grad_flat(1), grad_flat(2),
grad_flat(3), grad_flat(4), 0.0f});
ExpectClose(expected, *output);
}
TEST_F(QuantOpsTest, WithVarsNoNudging) {
// Original quantization range: [-10 + 0 / 4, -10 + 255 / 4], scale: 1/4.
// Original zero point: 40, no nudging necessary.
// Expected quantized values: -10.0, -9.75, ..., 53.75.
TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVars")
.Input(FakeInput(DT_FLOAT)) // inputs
.Input(FakeInput(DT_FLOAT)) // min
.Input(FakeInput(DT_FLOAT)) // max
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
// Downstream inputs.
AddInputFromArray<float>(TensorShape({2, 3}),
{-10.1f, -10.0f, -9.9f, -9.75f, 53.75f, 53.8f});
// Min.
AddInputFromArray<float>(TensorShape({}), {-10.0f});
// Max.
AddInputFromArray<float>(TensorShape({}), {53.75f});
// Tested code.
TF_ASSERT_OK(RunOpKernel());
Tensor* output = GetOutput(0);
Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3}));
FillValues<float>(&expected,
{-10.0f, -10.0f, -10.0f, -9.75f, 53.75f, 53.75f});
ExpectClose(expected, *output);
}
TEST_F(QuantOpsTest, WithVarsNudgedZeroIs0) {
// Original quantization range: [-0.4 / 4 + 0 / 4, -0.4 / 4 + 255 / 4].
// Scale: 1/4, original zero point: 0.4, nudged to 0.
// Nudged range: [0.0; 63.75].
// Expected quantized values: 0.0, 0.25, 0.5, ..., 63.75.
TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVars")
.Input(FakeInput(DT_FLOAT)) // inputs
.Input(FakeInput(DT_FLOAT)) // min
.Input(FakeInput(DT_FLOAT)) // max
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
// Downstream inputs.
AddInputFromArray<float>(TensorShape({2, 3}),
{-0.1f, 0.0f, 0.1f, 0.25f, 63.75f, 63.8f});
// Min.
AddInputFromArray<float>(TensorShape({}), {-0.1f});
// Max.
AddInputFromArray<float>(TensorShape({}), {63.65f});
// Tested code.
TF_ASSERT_OK(RunOpKernel());
Tensor* output = GetOutput(0);
Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3}));
FillValues<float>(&expected,
{0.0f, 0.0f, 0.0f, 0.25f, 63.75f, 63.75f});
ExpectClose(expected, *output);
}
TEST_F(QuantOpsTest, WithVarsNudgedZeroIs1) {
// Original quantization range: [-0.5 / 4 + 0 / 4, -0.5 / 4 + 255 / 4].
// Scale: 1/4, original zero point: 0.5, nudged to 1.
// Nudged range: [-0.25; 63.5].
// Expected quantized values: -0.25, 0.0, 0.25, ..., 63.5.
TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVars")
.Input(FakeInput(DT_FLOAT)) // inputs
.Input(FakeInput(DT_FLOAT)) // min
.Input(FakeInput(DT_FLOAT)) // max
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
// Downstream inputs.
AddInputFromArray<float>(TensorShape({2, 3}),
{-0.26f, -0.25f, -0.24f, 0.0f, 63.5f, 63.6f});
// Min.
AddInputFromArray<float>(TensorShape({}), {-0.125f});
// Max.
AddInputFromArray<float>(TensorShape({}), {63.625f});
// Tested code.
TF_ASSERT_OK(RunOpKernel());
Tensor* output = GetOutput(0);
Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3}));
FillValues<float>(&expected,
{-0.25f, -0.25f, -0.25f, 0.0f, 63.5f, 63.5f});
ExpectClose(expected, *output);
}
TEST_F(QuantOpsTest, WithVarsGradient) {
// Original quantization range: [-0.5 / 4 + 0 / 4, -0.5 / 4 + 255 / 4].
// Scale: 1/4, original zero point: 0.5, nudged to 1.
// Nudged range: [-0.25; 63.5].
// Expected quantized values: -0.25, 0.0, 0.25, ..., 63.5.
TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsGradient")
.Input(FakeInput(DT_FLOAT)) // gradients
.Input(FakeInput(DT_FLOAT)) // inputs
.Input(FakeInput(DT_FLOAT)) // min
.Input(FakeInput(DT_FLOAT)) // max
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
// Upstream gradients.
AddRandomInput(TensorShape({2, 3}));
// Downstream inputs.
AddInputFromArray<float>(TensorShape({2, 3}),
{-0.26f, -0.25f, -0.24f, 0.0f, 63.5f, 63.6f});
// Min.
AddInputFromArray<float>(TensorShape({}), {-0.125f});
// Max.
AddInputFromArray<float>(TensorShape({}), {63.625f});
// Tested code.
TF_ASSERT_OK(RunOpKernel());
Tensor* output_bprop_wrt_input = GetOutput(0);
Tensor expected_bprop_wrt_input(allocator(), DT_FLOAT, TensorShape({2, 3}));
auto grad_flat = GetInput(0).flat<float>();
FillValues<float>(&expected_bprop_wrt_input,
{0.0f, grad_flat(1),
grad_flat(2), grad_flat(3),
grad_flat(4), 0.0f});
ExpectClose(expected_bprop_wrt_input, *output_bprop_wrt_input);
Tensor* output_bprop_wrt_min = GetOutput(1);
Tensor expected_bprop_wrt_min(allocator(), DT_FLOAT, TensorShape({}));
expected_bprop_wrt_min.flat<float>()(0) = grad_flat(0);
ExpectClose(expected_bprop_wrt_min, *output_bprop_wrt_min);
Tensor* output_bprop_wrt_max = GetOutput(2);
Tensor expected_bprop_wrt_max(allocator(), DT_FLOAT, TensorShape({}));
expected_bprop_wrt_max.flat<float>()(0) = grad_flat(5);
ExpectClose(expected_bprop_wrt_max, *output_bprop_wrt_max);
}
TEST_F(QuantOpsTest, WithVarsPerChannelDim1NudgedZeroIs0) {
// Original quantization ranges: [-0.4 / 4 + 0 / 4, -0.4 / 4 + 255 / 4].
// Scale: 1/4, original zero point: 0.4, nudged to 0.
// Nudged ranges: [0.0; 63.75].
// Expected quantized values: 0.0, 0.25, 0.5, ..., 63.75.
TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannel")
.Input(FakeInput(DT_FLOAT)) // inputs
.Input(FakeInput(DT_FLOAT)) // min
.Input(FakeInput(DT_FLOAT)) // max
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
// Downstream inputs.
AddInputFromArray<float>(TensorShape({4}), {-0.1f, 0.0f, 63.75f, 63.8f});
// Min.
AddInputFromArray<float>(TensorShape({4}), {-0.1f, -0.1f, -0.1f, -0.1f});
// Max.
AddInputFromArray<float>(TensorShape({4}), {63.65f, 63.65f, 63.65f, 63.65f});
// Tested code.
TF_ASSERT_OK(RunOpKernel());
Tensor* output = GetOutput(0);
Tensor expected(allocator(), DT_FLOAT, TensorShape({4}));
FillValues<float>(&expected, {0.0f, 0.0f, 63.75f, 63.75f});
ExpectClose(expected, *output);
}
TEST_F(QuantOpsTest, WithVarsPerChannelDim1NudgedZeroIs1) {
// Original quantization ranges: [-0.5 / 4 + 0 / 4, -0.5 / 4 + 255 / 4].
// Scale: 1/4, original zero point: 0.5, nudged to 1.
// Nudged ranges: [-0.25; 63.5].
// Expected quantized values: -0.25, 0.0, 0.25, ..., 63.5.
TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannel")
.Input(FakeInput(DT_FLOAT)) // inputs
.Input(FakeInput(DT_FLOAT)) // min
.Input(FakeInput(DT_FLOAT)) // max
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
// Downstream inputs.
AddInputFromArray<float>(TensorShape({4}), {-0.26f, -0.25f, -0.24f, 63.6f});
// Min.
AddInputFromArray<float>(TensorShape({4}),
{-0.125f, -0.125f, -0.125f, -0.125f});
// Max.
AddInputFromArray<float>(TensorShape({4}),
{63.625f, 63.625f, 63.625f, 63.625f});
// Tested code.
TF_ASSERT_OK(RunOpKernel());
Tensor* output = GetOutput(0);
Tensor expected(allocator(), DT_FLOAT, TensorShape({4}));
FillValues<float>(&expected, {-0.25f, -0.25f, -0.25f, 63.5f});
ExpectClose(expected, *output);
}
TEST_F(QuantOpsTest, WithVarsPerChannelDim2NudgedZeroIs0) {
// Original quantization ranges: [-0.4 / 4 + 0 / 4, -0.4 / 4 + 255 / 4].
// Scale: 1/4, original zero point: 0.4, nudged to 0.
// Nudged ranges: [0.0; 63.75].
// Expected quantized values: 0.0, 0.25, 0.5, ..., 63.75.
TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannel")
.Input(FakeInput(DT_FLOAT)) // inputs
.Input(FakeInput(DT_FLOAT)) // min
.Input(FakeInput(DT_FLOAT)) // max
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
// Downstream inputs.
AddInputFromArray<float>(TensorShape({2, 3}),
{-0.1f, 0.0f, 0.1f,
0.25f, 63.75f, 63.8f});
// Min.
AddInputFromArray<float>(TensorShape({3}), {-0.1f, -0.1f, -0.1f});
// Max.
AddInputFromArray<float>(TensorShape({3}), {63.65f, 63.65f, 63.65f});
// Tested code.
TF_ASSERT_OK(RunOpKernel());
Tensor* output = GetOutput(0);
Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3}));
FillValues<float>(&expected, {0.0f, 0.0f, 0.0f,
0.25f, 63.75f, 63.75f});
ExpectClose(expected, *output);
}
TEST_F(QuantOpsTest, WithVarsPerChannelDim2NudgedZeroIs1) {
// Original quantization ranges: [-0.5 / 4 + 0 / 4, -0.5 / 4 + 255 / 4].
// Scale: 1/4, original zero point: 0.5, nudged to 1.
// Nudged ranges: [-0.25; 63.5].
// Expected quantized values: -0.25, 0.0, 0.25, ..., 63.5.
TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannel")
.Input(FakeInput(DT_FLOAT)) // inputs
.Input(FakeInput(DT_FLOAT)) // min
.Input(FakeInput(DT_FLOAT)) // max
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
// Downstream inputs.
AddInputFromArray<float>(TensorShape({2, 3}),
{-0.26f, -0.25f, -0.24f,
0.0f, 63.5f, 63.6f});
// Min.
AddInputFromArray<float>(TensorShape({3}), {-0.125f, -0.125f, -0.125f});
// Max.
AddInputFromArray<float>(TensorShape({3}), {63.625f, 63.625f, 63.625f});
// Tested code.
TF_ASSERT_OK(RunOpKernel());
Tensor* output = GetOutput(0);
Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3}));
FillValues<float>(&expected, {-0.25f, -0.25f, -0.25f,
0.0f, 63.5f, 63.5f});
ExpectClose(expected, *output);
}
TEST_F(QuantOpsTest, WithVarsPerChannelDim4NudgedZeroIs0) {
// Original quantization ranges: [-0.4 / 4 + 0 / 4, -0.4 / 4 + 255 / 4].
// Scale: 1/4, original zero point: 0.4, nudged to 0.
// Nudged ranges: [0.0; 63.75].
// Expected quantized values: 0.0, 0.25, 0.5, ..., 63.75.
TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannel")
.Input(FakeInput(DT_FLOAT)) // inputs
.Input(FakeInput(DT_FLOAT)) // min
.Input(FakeInput(DT_FLOAT)) // max
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
// Downstream inputs.
AddInputFromArray<float>(TensorShape({1, 2, 3, 4}),
{-0.1f, 0.0f, 0.1f, 0.25f,
0.5f, 0.75f, 1.0f, 1.25f,
1.5f, 1.75f, 2.0f, 2.25f,
63.0f, 63.25f, 63.5f, 63.7f,
63.75f, 63.8f, 63.9f, 100.0f,
100.0f, 100.0f, 100.0f, 1000.0f});
// Min.
AddInputFromArray<float>(TensorShape({4}), {-0.1f, -0.1f, -0.1f, -0.1f});
// Max.
AddInputFromArray<float>(TensorShape({4}), {63.65f, 63.65f, 63.65f, 63.65f});
// Tested code.
TF_ASSERT_OK(RunOpKernel());
Tensor* output = GetOutput(0);
Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 2, 3, 4}));
FillValues<float>(&expected,
{0.0f, 0.0f, 0.0f, 0.25f,
0.5f, 0.75f, 1.0f, 1.25f,
1.5f, 1.75f, 2.0f, 2.25f,
63.0f, 63.25f, 63.5f, 63.75f,
63.75f, 63.75f, 63.75f, 63.75f,
63.75f, 63.75f, 63.75f, 63.75f});
ExpectClose(expected, *output);
}
TEST_F(QuantOpsTest, WithVarsPerChannelDim4NudgedZeroIs1) {
// Original quantization ranges: [-0.5 / 4 + 0 / 4, -0.5 / 4 + 255 / 4].
// Scale: 1/4, original zero point: 0.5, nudged to 1.
// Nudged ranges: [-0.25; 63.5].
// Expected quantized values: -0.25, 0.0, 0.25, ..., 63.5.
TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannel")
.Input(FakeInput(DT_FLOAT)) // inputs
.Input(FakeInput(DT_FLOAT)) // min
.Input(FakeInput(DT_FLOAT)) // max
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
// Downstream inputs.
AddInputFromArray<float>(TensorShape({1, 2, 3, 4}),
{-0.3f, -0.25f, -0.2f, 0.0f,
0.25f, 0.5f, 0.75f, 1.0f,
1.25f, 1.5f, 1.75f, 2.0f,
63.0f, 63.25f, 63.4f, 63.5f,
63.6f, 63.7f, 100.0f, 100.0f,
100.0f, 100.0f, 100.0f, 1000.0f});
// Min.
AddInputFromArray<float>(TensorShape({4}),
{-0.125f, -0.125f, -0.125f, -0.125f});
// Max.
AddInputFromArray<float>(TensorShape({4}),
{63.625f, 63.625f, 63.625f, 63.625f});
// Tested code.
TF_ASSERT_OK(RunOpKernel());
Tensor* output = GetOutput(0);
Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 2, 3, 4}));
FillValues<float>(&expected,
{-0.25f, -0.25f, -0.25f, 0.0f,
0.25f, 0.5f, 0.75f, 1.0f,
1.25f, 1.5f, 1.75f, 2.0f,
63.0f, 63.25f, 63.5f, 63.5f,
63.5f, 63.5f, 63.5f, 63.5f,
63.5f, 63.5f, 63.5f, 63.5f});
ExpectClose(expected, *output);
}
TEST_F(QuantOpsTest, WithVarsPerChannelDim1GradientNudgedZeroIs0) {
// Original quantization ranges: [-0.4 / 4 + 0 / 4, -0.4 / 4 + 255 / 4].
// Scale: 1/4, original zero point: 0.4, nudged to 0.
// Nudged ranges: [0.0; 63.75].
// Expected quantized values: 0.0, 0.25, 0.5, ..., 63.75.
TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannelGradient")
.Input(FakeInput(DT_FLOAT)) // gradients
.Input(FakeInput(DT_FLOAT)) // inputs
.Input(FakeInput(DT_FLOAT)) // min
.Input(FakeInput(DT_FLOAT)) // max
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
// Upstream gradients.
AddRandomInput(TensorShape({4}));
// Downstream inputs.
AddInputFromArray<float>(TensorShape({4}), {-0.1f, 0.0f, 63.75f, 63.8f});
// Min.
AddInputFromArray<float>(TensorShape({4}), {-0.1f, -0.1f, -0.1f, -0.1f});
// Max.
AddInputFromArray<float>(TensorShape({4}), {63.65f, 63.65f, 63.65f, 63.65f});
// Tested code.
TF_ASSERT_OK(RunOpKernel());
Tensor* output_bprop_wrt_input = GetOutput(0);
Tensor expected_bprop_wrt_input(allocator(), DT_FLOAT, TensorShape({4}));
auto grad_flat = GetInput(0).flat<float>();
FillValues<float>(&expected_bprop_wrt_input,
{0.0f, grad_flat(1), grad_flat(2), 0.0f});
ExpectClose(expected_bprop_wrt_input, *output_bprop_wrt_input);
Tensor* output_bprop_wrt_min = GetOutput(1);
Tensor expected_bprop_wrt_min(allocator(), DT_FLOAT, TensorShape({4}));
FillValues<float>(&expected_bprop_wrt_min,
{grad_flat(0), 0.0f, 0.0f, 0.0f});
ExpectClose(expected_bprop_wrt_min, *output_bprop_wrt_min);
Tensor* output_bprop_wrt_max = GetOutput(2);
Tensor expected_bprop_wrt_max(allocator(), DT_FLOAT, TensorShape({4}));
FillValues<float>(&expected_bprop_wrt_max,
{0.0f, 0.0f, 0.0f, grad_flat(3)});
ExpectClose(expected_bprop_wrt_max, *output_bprop_wrt_max);
}
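// A sketch of the gradient routing the *Gradient tests above and below verify
// (illustrative, assuming nudged_min/nudged_max were computed as in
// NudgeSketch): in-range gradients pass straight through to the input, while
// out-of-range gradients accumulate into the min/max backprops instead.
static void FakeQuantGradSketch(float grad, float input, float nudged_min,
                                float nudged_max, float* bprop_wrt_input,
                                float* bprop_wrt_min, float* bprop_wrt_max) {
  const bool below = input < nudged_min;
  const bool above = input > nudged_max;
  *bprop_wrt_input = (!below && !above) ? grad : 0.0f;  // straight-through
  if (below) *bprop_wrt_min += grad;  // summed per channel over all elements
  if (above) *bprop_wrt_max += grad;
}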
TEST_F(QuantOpsTest, WithVarsPerChannelDim1GradientNudgedZeroIs1) {
// Original quantization ranges: [-0.5 / 4 + 0 / 4, -0.5 / 4 + 255 / 4].
// Scale: 1/4, original zero point: 0.5, nudged to 1.
// Nudged ranges: [-0.25; 63.5].
// Expected quantized values: -0.25, 0.0, 0.25, ..., 63.5.
TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannelGradient")
.Input(FakeInput(DT_FLOAT)) // gradients
.Input(FakeInput(DT_FLOAT)) // inputs
.Input(FakeInput(DT_FLOAT)) // min
.Input(FakeInput(DT_FLOAT)) // max
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
// Upstream gradients.
AddRandomInput(TensorShape({4}));
// Downstream inputs.
AddInputFromArray<float>(TensorShape({4}), {-0.3f, -0.25f, 63.5f, 63.6f});
// Min.
AddInputFromArray<float>(TensorShape({4}),
{-0.125f, -0.125f, -0.125f, -0.125f});
// Max.
AddInputFromArray<float>(TensorShape({4}),
{63.625f, 63.625f, 63.625f, 63.625f});
// Tested code.
TF_ASSERT_OK(RunOpKernel());
Tensor* output_bprop_wrt_input = GetOutput(0);
Tensor expected_bprop_wrt_input(allocator(), DT_FLOAT, TensorShape({4}));
auto grad_flat = GetInput(0).flat<float>();
FillValues<float>(&expected_bprop_wrt_input,
{0.0f, grad_flat(1), grad_flat(2), 0.0f});
ExpectClose(expected_bprop_wrt_input, *output_bprop_wrt_input);
Tensor* output_bprop_wrt_min = GetOutput(1);
Tensor expected_bprop_wrt_min(allocator(), DT_FLOAT, TensorShape({4}));
FillValues<float>(&expected_bprop_wrt_min,
{grad_flat(0), 0.0f, 0.0f, 0.0f});
ExpectClose(expected_bprop_wrt_min, *output_bprop_wrt_min);
Tensor* output_bprop_wrt_max = GetOutput(2);
Tensor expected_bprop_wrt_max(allocator(), DT_FLOAT, TensorShape({4}));
FillValues<float>(&expected_bprop_wrt_max,
{0.0f, 0.0f, 0.0f, grad_flat(3)});
ExpectClose(expected_bprop_wrt_max, *output_bprop_wrt_max);
}
TEST_F(QuantOpsTest, WithVarsPerChannelDim2GradientNudgedZeroIs0) {
// Original quantization ranges: [-0.4 / 4 + 0 / 4, -0.4 / 4 + 255 / 4].
// Scale: 1/4, original zero point: 0.4, nudged to 0.
// Nudged ranges: [0.0; 63.75].
// Expected quantized values: 0.0, 0.25, 0.5, ..., 63.75.
TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannelGradient")
.Input(FakeInput(DT_FLOAT)) // gradients
.Input(FakeInput(DT_FLOAT)) // inputs
.Input(FakeInput(DT_FLOAT)) // min
.Input(FakeInput(DT_FLOAT)) // max
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
// Upstream gradients.
AddRandomInput(TensorShape({2, 3}));
// Downstream inputs.
AddInputFromArray<float>(TensorShape({2, 3}),
{-0.1f, 0.0f, 0.1f,
0.25f, 63.75f, 63.8f});
// Min.
AddInputFromArray<float>(TensorShape({3}), {-0.1f, -0.1f, -0.1f});
// Max.
AddInputFromArray<float>(TensorShape({3}), {63.65f, 63.65f, 63.65f});
// Tested code.
TF_ASSERT_OK(RunOpKernel());
Tensor* output_bprop_wrt_input = GetOutput(0);
Tensor expected_bprop_wrt_input(allocator(), DT_FLOAT, TensorShape({2, 3}));
auto grad_flat = GetInput(0).flat<float>();
FillValues<float>(&expected_bprop_wrt_input,
{0.0f, grad_flat(1), grad_flat(2),
grad_flat(3), grad_flat(4), 0.0f});
ExpectClose(expected_bprop_wrt_input, *output_bprop_wrt_input);
Tensor* output_bprop_wrt_min = GetOutput(1);
Tensor expected_bprop_wrt_min(allocator(), DT_FLOAT, TensorShape({3}));
FillValues<float>(&expected_bprop_wrt_min,
{grad_flat(0), 0.0f, 0.0f});
ExpectClose(expected_bprop_wrt_min, *output_bprop_wrt_min);
Tensor* output_bprop_wrt_max = GetOutput(2);
Tensor expected_bprop_wrt_max(allocator(), DT_FLOAT, TensorShape({3}));
FillValues<float>(&expected_bprop_wrt_max,
{0.0f, 0.0f, grad_flat(5)});
ExpectClose(expected_bprop_wrt_max, *output_bprop_wrt_max);
}
TEST_F(QuantOpsTest, WithVarsPerChannelDim2GradientNudgedZeroIs1) {
// Original quantization ranges: [-0.5 / 4 + 0 / 4, -0.5 / 4 + 255 / 4].
// Scale: 1/4, original zero point: 0.5, nudged to 1.
// Nudged ranges: [-0.25; 63.5].
// Expected quantized values: -0.25, 0.0, 0.25, ..., 63.5.
TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannelGradient")
.Input(FakeInput(DT_FLOAT)) // gradients
.Input(FakeInput(DT_FLOAT)) // inputs
.Input(FakeInput(DT_FLOAT)) // min
.Input(FakeInput(DT_FLOAT)) // max
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
// Upstream gradients.
AddRandomInput(TensorShape({2, 3}));
// Downstream inputs.
AddInputFromArray<float>(TensorShape({2, 3}),
{-0.3f, -0.25f, -0.2f,
0.0f, 63.5f, 63.6f});
// Min.
AddInputFromArray<float>(TensorShape({3}), {-0.125f, -0.125f, -0.125f});
// Max.
AddInputFromArray<float>(TensorShape({3}), {63.625f, 63.625f, 63.625f});
// Tested code.
TF_ASSERT_OK(RunOpKernel());
Tensor* output_bprop_wrt_input = GetOutput(0);
Tensor expected_bprop_wrt_input(allocator(), DT_FLOAT, TensorShape({2, 3}));
auto grad_flat = GetInput(0).flat<float>();
FillValues<float>(&expected_bprop_wrt_input,
{0.0f, grad_flat(1), grad_flat(2),
grad_flat(3), grad_flat(4), 0.0f});
ExpectClose(expected_bprop_wrt_input, *output_bprop_wrt_input);
Tensor* output_bprop_wrt_min = GetOutput(1);
Tensor expected_bprop_wrt_min(allocator(), DT_FLOAT, TensorShape({3}));
FillValues<float>(&expected_bprop_wrt_min,
{grad_flat(0), 0.0f, 0.0f});
ExpectClose(expected_bprop_wrt_min, *output_bprop_wrt_min);
Tensor* output_bprop_wrt_max = GetOutput(2);
Tensor expected_bprop_wrt_max(allocator(), DT_FLOAT, TensorShape({3}));
FillValues<float>(&expected_bprop_wrt_max,
{0.0f, 0.0f, grad_flat(5)});
ExpectClose(expected_bprop_wrt_max, *output_bprop_wrt_max);
}
TEST_F(QuantOpsTest, WithVarsPerChannelDim4GradientNudgedZeroIs0) {
// Original quantization ranges: [-0.4 / 4 + 0 / 4, -0.4 / 4 + 255 / 4].
// Scale: 1/4, original zero point: 0.4, nudged to 0.
// Nudged ranges: [0.0; 63.75].
// Expected quantized values: 0.0, 0.25, 0.5, ..., 63.75.
TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannelGradient")
.Input(FakeInput(DT_FLOAT)) // gradients
.Input(FakeInput(DT_FLOAT)) // inputs
.Input(FakeInput(DT_FLOAT)) // min
.Input(FakeInput(DT_FLOAT)) // max
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
// Upstream gradients.
AddRandomInput(TensorShape({1, 2, 3, 4}));
// Downstream inputs.
AddInputFromArray<float>(TensorShape({1, 2, 3, 4}),
{-0.1f, 0.0f, 63.75f, 63.8f,
-0.1f, 0.0f, 63.75f, 63.8f,
-0.1f, 0.0f, 63.75f, 63.8f,
-0.1f, 0.0f, 63.75f, 63.8f,
-0.1f, 0.0f, 63.75f, 63.8f,
-0.1f, 0.0f, 63.75f, 63.8f});
// Min.
AddInputFromArray<float>(TensorShape({4}), {-0.1f, -0.1f, -0.1f, -0.1f});
// Max.
AddInputFromArray<float>(TensorShape({4}), {63.65f, 63.65f, 63.65f, 63.65f});
// Tested code.
TF_ASSERT_OK(RunOpKernel());
Tensor* output_bprop_wrt_input = GetOutput(0);
Tensor expected_bprop_wrt_input(allocator(), DT_FLOAT,
TensorShape({1, 2, 3, 4}));
auto grad_flat = GetInput(0).flat<float>();
FillValues<float>(
&expected_bprop_wrt_input,
{0.0f, grad_flat(1), grad_flat(2), 0.0f,
0.0f, grad_flat(5), grad_flat(6), 0.0f,
0.0f, grad_flat(9), grad_flat(10), 0.0f,
0.0f, grad_flat(13), grad_flat(14), 0.0f,
0.0f, grad_flat(17), grad_flat(18), 0.0f,
0.0f, grad_flat(21), grad_flat(22), 0.0f});
ExpectClose(expected_bprop_wrt_input, *output_bprop_wrt_input);
Tensor* output_bprop_wrt_min = GetOutput(1);
Tensor expected_bprop_wrt_min(allocator(), DT_FLOAT, TensorShape({4}));
FillValues<float>(&expected_bprop_wrt_min,
{grad_flat(0) + grad_flat(4) + grad_flat(8) +
grad_flat(12) + grad_flat(16) + grad_flat(20),
0.0f, 0.0f, 0.0f});
ExpectClose(expected_bprop_wrt_min, *output_bprop_wrt_min);
Tensor* output_bprop_wrt_max = GetOutput(2);
Tensor expected_bprop_wrt_max(allocator(), DT_FLOAT, TensorShape({4}));
FillValues<float>(&expected_bprop_wrt_max,
{0.0f, 0.0f, 0.0f,
grad_flat(3) + grad_flat(7) + grad_flat(11) +
grad_flat(15) + grad_flat(19) + grad_flat(23)});
ExpectClose(expected_bprop_wrt_max, *output_bprop_wrt_max);
}
TEST_F(QuantOpsTest, WithVarsPerChannelDim4GradientNudgedZeroIs1) {
// Original quantization ranges: [-0.5 / 4 + 0 / 4, -0.5 / 4 + 255 / 4].
// Scale: 1/4, original zero point: 0.5, nudged to 1.
// Nudged ranges: [-0.25; 63.5].
// Expected quantized values: -0.25, 0.0, 0.25, ..., 63.5.
TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannelGradient")
.Input(FakeInput(DT_FLOAT)) // gradients
.Input(FakeInput(DT_FLOAT)) // inputs
.Input(FakeInput(DT_FLOAT)) // min
.Input(FakeInput(DT_FLOAT)) // max
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
// Upstream gradients.
AddRandomInput(TensorShape({1, 2, 3, 4}));
// Downstream inputs.
AddInputFromArray<float>(TensorShape({1, 2, 3, 4}),
{-0.3f, -0.25f, 63.5f, 63.6f,
-0.3f, -0.25f, 63.5f, 63.6f,
-0.3f, -0.25f, 63.5f, 63.6f,
-0.3f, -0.25f, 63.5f, 63.6f,
-0.3f, -0.25f, 63.5f, 63.6f,
-0.3f, -0.25f, 63.5f, 63.6f});
// Min.
AddInputFromArray<float>(TensorShape({4}),
{-0.125f, -0.125f, -0.125f, -0.125f});
// Max.
AddInputFromArray<float>(TensorShape({4}),
{63.625f, 63.625f, 63.625f, 63.625f});
// Tested code.
TF_ASSERT_OK(RunOpKernel());
Tensor* output_bprop_wrt_input = GetOutput(0);
Tensor expected_bprop_wrt_input(allocator(), DT_FLOAT,
TensorShape({1, 2, 3, 4}));
auto grad_flat = GetInput(0).flat<float>();
FillValues<float>(&expected_bprop_wrt_input,
{0.0f, grad_flat(1), grad_flat(2), 0.0f,
0.0f, grad_flat(5), grad_flat(6), 0.0f,
0.0f, grad_flat(9), grad_flat(10), 0.0f,
0.0f, grad_flat(13), grad_flat(14), 0.0f,
0.0f, grad_flat(17), grad_flat(18), 0.0f,
0.0f, grad_flat(21), grad_flat(22), 0.0f});
ExpectClose(expected_bprop_wrt_input, *output_bprop_wrt_input);
Tensor* output_bprop_wrt_min = GetOutput(1);
Tensor expected_bprop_wrt_min(allocator(), DT_FLOAT, TensorShape({4}));
FillValues<float>(&expected_bprop_wrt_min,
{grad_flat(0) + grad_flat(4) + grad_flat(8) +
grad_flat(12) + grad_flat(16) + grad_flat(20),
0.0f, 0.0f, 0.0f});
ExpectClose(expected_bprop_wrt_min, *output_bprop_wrt_min);
Tensor* output_bprop_wrt_max = GetOutput(2);
Tensor expected_bprop_wrt_max(allocator(), DT_FLOAT, TensorShape({4}));
FillValues<float>(&expected_bprop_wrt_max,
{0.0f, 0.0f, 0.0f,
grad_flat(3) + grad_flat(7) + grad_flat(11) +
grad_flat(15) + grad_flat(19) + grad_flat(23)});
ExpectClose(expected_bprop_wrt_max, *output_bprop_wrt_max);
}
} // namespace tensorflow

View File

@ -4383,6 +4383,117 @@ output_min: This value is copied from input_min.
output_max: This value is copied from input_max.
)Doc");
REGISTER_OP("FakeQuantWithMinMaxArgs")
.Attr("min: float = -6.0")
.Attr("max: float = 6.0")
.Input("inputs: float")
.Output("outputs: float")
.Doc(R"doc(
Fake-quantize the 'inputs' tensor, of type float, to an 'outputs' tensor of the same type.
Attributes [min; max] define the clamping range for the 'inputs' data. Op
divides this range into 255 steps (total of 256 values), then replaces each
'inputs' value with the closest of the quantized step values.
Quantization is called fake since the output is still in floating point.
)doc");
REGISTER_OP("FakeQuantWithMinMaxArgsGradient")
.Attr("min: float = -6.0")
.Attr("max: float = 6.0")
.Input("gradients: float")
.Input("inputs: float")
.Output("backprops: float")
.Doc(R"doc(
Compute gradients for a FakeQuantWithMinMaxArgs operation.
gradients: Backpropagated gradients above the FakeQuantWithMinMaxArgs operation.
inputs: Values passed as inputs to the FakeQuantWithMinMaxArgs operation.
backprops: Backpropagated gradients below the FakeQuantWithMinMaxArgs operation:
`gradients * (inputs >= min && inputs <= max)`.
)doc");
REGISTER_OP("FakeQuantWithMinMaxVars")
.Input("inputs: float")
.Input("min: float")
.Input("max: float")
.Output("outputs: float")
.Doc(R"doc(
Fake-quantize the 'inputs' tensor of type float and shape `[b, h, w, d]` via
global float scalars `min` and `max` to 'outputs' tensor of same shape as
`inputs`.
[min; max] is the clamping range for the 'inputs' data. Op divides this range
into 255 steps (total of 256 values), then replaces each 'inputs' value with the
closest of the quantized step values.
This operation has a gradient and thus allows for training `min` and `max` values.
)doc");
REGISTER_OP("FakeQuantWithMinMaxVarsGradient")
.Input("gradients: float")
.Input("inputs: float")
.Input("min: float")
.Input("max: float")
.Output("backprops_wrt_input: float")
.Output("backprop_wrt_min: float")
.Output("backprop_wrt_max: float")
.Doc(R"doc(
Compute gradients for a FakeQuantWithMinMaxVars operation.
gradients: Backpropagated gradients above the FakeQuantWithMinMaxVars operation.
inputs: Values passed as inputs to the FakeQuantWithMinMaxVars operation.
min, max: Quantization interval, scalar floats.
backprops_wrt_input: Backpropagated gradients w.r.t. inputs:
`gradients * (inputs >= min && inputs <= max)`.
backprop_wrt_min: Backpropagated gradients w.r.t. min parameter:
`sum(gradients * (inputs < min))`.
backprop_wrt_max: Backpropagated gradients w.r.t. max parameter:
`sum(gradients * (inputs > max))`.
)doc");
REGISTER_OP("FakeQuantWithMinMaxVarsPerChannel")
.Input("inputs: float")
.Input("min: float")
.Input("max: float")
.Output("outputs: float")
.Doc(R"doc(
Fake-quantize the 'inputs' tensor of type float and one of the shapes: `[d]`,
`[b, d]`, `[b, h, w, d]` via per-channel floats `min` and `max` of shape `[d]`
to 'outputs' tensor of same shape as `inputs`.
[min; max] is the clamping range for the 'inputs' data in the corresponding
depth channel. Op divides this range into 255 steps (total of 256 values), then
replaces each 'inputs' value with the closest of the quantized step values.
This operation has a gradient and thus allows for training `min` and `max` values.
)doc");
REGISTER_OP("FakeQuantWithMinMaxVarsPerChannelGradient")
.Input("gradients: float")
.Input("inputs: float")
.Input("min: float")
.Input("max: float")
.Output("backprops_wrt_input: float")
.Output("backprop_wrt_min: float")
.Output("backprop_wrt_max: float")
.Doc(R"doc(
Compute gradients for a FakeQuantWithMinMaxVarsPerChannel operation.
gradients: Backpropagated gradients above the
  FakeQuantWithMinMaxVarsPerChannel operation, shape one of: `[d]`, `[b, d]`,
  `[b, h, w, d]`.
inputs: Values passed as inputs to the FakeQuantWithMinMaxVarsPerChannel
  operation, shape same as `gradients`.
min, max: Quantization interval, floats of shape `[d]`.
backprops_wrt_input: Backpropagated gradients w.r.t. inputs, shape same as
`inputs`:
`gradients * (inputs >= min && inputs <= max)`.
backprop_wrt_min: Backpropagated gradients w.r.t. min parameter, shape `[d]`:
`sum_per_d(gradients * (inputs < min))`.
backprop_wrt_max: Backpropagated gradients w.r.t. max parameter, shape `[d]`:
`sum_per_d(gradients * (inputs > max))`.
)doc");
// Deprecated op registrations:
// The following can be deleted after 10mar2017.

View File

@ -1905,7 +1905,6 @@ def _EditDistanceShape(op):
  return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[2, 5])
# The remaining ops do not change the shape of their inputs.
@ops.RegisterShape("Quantize")
@ops.RegisterShape("Dequantize")
def _QuantizeDequantizeShape(op):
@ -1914,6 +1913,45 @@ def _QuantizeDequantizeShape(op):
  return common_shapes.unchanged_shape(op)
@ops.RegisterShape("FakeQuantWithMinMaxArgs")
def _FakeQuantWithMinMaxArgsShape(op):
"""Shape function for FakeQuantWithMinMaxArgs op: preserve the input shape."""
return [op.inputs[0].get_shape()]
@ops.RegisterGradient("FakeQuantWithMinMaxArgs")
def _FakeQuantWithMinMaxArgsGradient(op, grad):
"""Gradient for FakeQuantWithMinMaxArgs op."""
return fake_quant_with_min_max_args_gradient(grad, op.inputs[0])
@ops.RegisterShape("FakeQuantWithMinMaxVars")
def _FakeQuantWithMinMaxVarsShape(op):
"""Shape function for FakeQuantWithMinMaxVars op: preserve the input shape."""
return [op.inputs[0].get_shape()]
@ops.RegisterGradient("FakeQuantWithMinMaxVars")
def _FakeQuantWithMinMaxVarsGradient(op, grad):
"""Gradient for FakeQuantWithMinMaxVars op."""
return fake_quant_with_min_max_vars_gradient(grad, op.inputs[0], op.inputs[1],
op.inputs[2])
@ops.RegisterShape("FakeQuantWithMinMaxVarsPerChannel")
def _FakeQuantWithMinMaxVarsPerChannelShape(op):
"""Shape function for FakeQuantWithMinMaxVarsPerChannel op: input shape."""
return [op.inputs[0].get_shape()]
@ops.RegisterGradient("FakeQuantWithMinMaxVarsPerChannel")
def _FakeQuantWithMinMaxVarsPerChannelGradient(op, grad):
"""Gradient for FakeQuantWithMinMaxVarsPerChannel op."""
return fake_quant_with_min_max_vars_per_channel_gradient(grad, op.inputs[0],
op.inputs[1],
op.inputs[2])
ops.RegisterShape("ExtractImagePatches")(common_shapes.call_cpp_shape_fn)