Add FakeQuant ops and kernels for use with quantized training.
Change: 137081182
This commit is contained in:
parent
4a465522c1
commit
9fb15ea28b
@ -521,6 +521,7 @@ cc_library(
|
||||
"//tensorflow/core/kernels:control_flow_ops",
|
||||
"//tensorflow/core/kernels:ctc_ops",
|
||||
"//tensorflow/core/kernels:data_flow",
|
||||
"//tensorflow/core/kernels:fake_quant_ops",
|
||||
"//tensorflow/core/kernels:function_ops",
|
||||
"//tensorflow/core/kernels:image",
|
||||
"//tensorflow/core/kernels:io",
|
||||
|
@ -563,6 +563,24 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
# Test target for the kernels built by ":fake_quant_ops" (source:
# fake_quant_ops_test.cc).
tf_cc_test(
    name = "fake_quant_ops_test",
    size = "small",
    srcs = ["fake_quant_ops_test.cc"],
    deps = [
        ":fake_quant_ops",
        ":ops_testutil",
        ":ops_util",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "fused_batch_norm_op_test",
|
||||
size = "small",
|
||||
@ -1710,6 +1728,22 @@ tf_kernel_library(
|
||||
],
|
||||
)
|
||||
|
||||
# FakeQuant* kernels. The functor header is listed both in hdrs and in
# gpu_srcs, next to the CUDA source fake_quant_ops_gpu.cu.cc.
tf_kernel_library(
    name = "fake_quant_ops",
    srcs = ["fake_quant_ops.cc"],
    hdrs = ["fake_quant_ops_functor.h"],
    gpu_srcs = [
        "fake_quant_ops_gpu.cu.cc",
        "fake_quant_ops_functor.h",
    ],
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//third_party/eigen3",
    ],
    alwayslink = 1,
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "fused_batch_norm_util",
|
||||
gpu_srcs = [
|
||||
|
580
tensorflow/core/kernels/fake_quant_ops.cc
Normal file
580
tensorflow/core/kernels/fake_quant_ops.cc
Normal file
@ -0,0 +1,580 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#ifdef GOOGLE_CUDA
|
||||
#define EIGEN_USE_GPU
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
#define FAKE_QUANT_NO_DEBUG
|
||||
|
||||
#include "tensorflow/core/kernels/fake_quant_ops_functor.h"
|
||||
|
||||
#include "tensorflow/core/framework/numeric_op.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
|
||||
using tensorflow::BinaryElementWiseOp;
|
||||
using tensorflow::DEVICE_CPU;
|
||||
#if GOOGLE_CUDA
|
||||
using tensorflow::DEVICE_GPU;
|
||||
#endif
|
||||
using tensorflow::DT_BOOL;
|
||||
using tensorflow::OpKernel;
|
||||
using tensorflow::OpKernelConstruction;
|
||||
using tensorflow::OpKernelContext;
|
||||
using tensorflow::PersistentTensor;
|
||||
using tensorflow::Tensor;
|
||||
using tensorflow::TensorShape;
|
||||
using tensorflow::TTypes; // NOLINT This is needed in CUDA mode, do not remove.
|
||||
using tensorflow::UnaryElementWiseOp;
|
||||
using tensorflow::errors::InvalidArgument;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Device tag used to instantiate and register the CPU kernels in this file.
using CPUDevice = Eigen::ThreadPoolDevice;
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Implementation of FakeQuantWithMinMaxArgsOp, see its documentation in
|
||||
// core/ops/array_ops.cc.
|
||||
template <typename Device>
|
||||
class FakeQuantWithMinMaxArgsOp
|
||||
: public UnaryElementWiseOp<float, FakeQuantWithMinMaxArgsOp<Device>> {
|
||||
public:
|
||||
typedef UnaryElementWiseOp<float, FakeQuantWithMinMaxArgsOp<Device>> Base;
|
||||
explicit FakeQuantWithMinMaxArgsOp(OpKernelConstruction* context)
|
||||
: Base::UnaryElementWiseOp(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("min", &min_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("max", &max_));
|
||||
OP_REQUIRES(context, min_ < max_,
|
||||
InvalidArgument("min has to be smaller than max, was: ", min_,
|
||||
" >= ", max_));
|
||||
}
|
||||
|
||||
void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
|
||||
FakeQuantWithMinMaxArgsFunctor<Device> functor;
|
||||
functor(context->eigen_device<Device>(), input.flat<float>(), min_, max_,
|
||||
output->flat<float>());
|
||||
}
|
||||
private:
|
||||
float min_;
|
||||
float max_;
|
||||
};
|
||||
|
||||
// Implementation of FakeQuantWithMinMaxArgsGradientOp, see its documentation in
|
||||
// core/ops/array_ops.cc.
|
||||
template <typename Device>
|
||||
class FakeQuantWithMinMaxArgsGradientOp
|
||||
: public BinaryElementWiseOp<float,
|
||||
FakeQuantWithMinMaxArgsGradientOp<Device>> {
|
||||
public:
|
||||
typedef BinaryElementWiseOp<float, FakeQuantWithMinMaxArgsGradientOp<Device>>
|
||||
Base;
|
||||
explicit FakeQuantWithMinMaxArgsGradientOp(OpKernelConstruction* context)
|
||||
: Base::BinaryElementWiseOp(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("min", &min_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("max", &max_));
|
||||
OP_REQUIRES(context, min_ < max_,
|
||||
InvalidArgument("min has to be smaller than max, was: ", min_,
|
||||
" >= ", max_));
|
||||
}
|
||||
|
||||
template <int NDIMS>
|
||||
void Operate(OpKernelContext* context, const Tensor& gradient,
|
||||
const Tensor& input, Tensor* output) {
|
||||
OperateNoTemplate(context, gradient, input, output);
|
||||
}
|
||||
|
||||
void OperateNoTemplate(OpKernelContext* context, const Tensor& gradient,
|
||||
const Tensor& input, Tensor* output) {
|
||||
OP_REQUIRES(context, input.IsSameSize(gradient),
|
||||
InvalidArgument("gradient and input must be the same size"));
|
||||
FakeQuantWithMinMaxArgsGradientFunctor<Device> functor;
|
||||
functor(context->eigen_device<Device>(), gradient.flat<float>(),
|
||||
input.flat<float>(), min_, max_, output->flat<float>());
|
||||
}
|
||||
private:
|
||||
float min_;
|
||||
float max_;
|
||||
};
|
||||
|
||||
// CPU registrations for the attribute-driven fake-quant kernel and its
// gradient.
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxArgs").Device(DEVICE_CPU),
    FakeQuantWithMinMaxArgsOp<CPUDevice>);
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxArgsGradient").Device(DEVICE_CPU),
    FakeQuantWithMinMaxArgsGradientOp<CPUDevice>);
|
||||
|
||||
#if GOOGLE_CUDA
using GPUDevice = Eigen::GpuDevice;

// Declarations of the GPU functor specializations used by the registrations
// below. Their definitions are not in this file; presumably they come from
// fake_quant_ops_gpu.cu.cc (the gpu_srcs entry of the build target).
template <>
void FakeQuantWithMinMaxArgsFunctor<GPUDevice>::operator()(
    const GPUDevice& d,
    ConstFlat<float> inputs,
    const float min, const float max,
    Flat<float> outputs);
extern template struct FakeQuantWithMinMaxArgsFunctor<GPUDevice>;
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxArgs").Device(DEVICE_GPU),
    FakeQuantWithMinMaxArgsOp<GPUDevice>);

// NOTE(review): unlike every other GPU functor in this file, the gradient
// functor below has no matching `extern template struct` declaration.
// Confirm whether that omission is intentional.
template <>
void FakeQuantWithMinMaxArgsGradientFunctor<GPUDevice>::operator()(
    const GPUDevice& d,
    ConstFlat<float> gradients,
    ConstFlat<float> inputs,
    const float min, const float max,
    Flat<float> backprops);
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxArgsGradient").Device(DEVICE_GPU),
    FakeQuantWithMinMaxArgsGradientOp<GPUDevice>);
#endif  // GOOGLE_CUDA
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Implementation of FakeQuantWithMinMaxVarsOp, see its documentation in
|
||||
// core/ops/array_ops.cc.
|
||||
template <typename Device>
|
||||
class FakeQuantWithMinMaxVarsOp : public OpKernel {
|
||||
public:
|
||||
explicit FakeQuantWithMinMaxVarsOp(OpKernelConstruction* context)
|
||||
: OpKernel::OpKernel(context) {
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_persistent(DT_BOOL, {},
|
||||
&check_min_max_handle_,
|
||||
nullptr));
|
||||
#endif
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
CHECK_EQ(3, context->num_inputs());
|
||||
const Tensor& input = context->input(0);
|
||||
const Tensor& min = context->input(1);
|
||||
const Tensor& max = context->input(2);
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
Tensor* check_min_max = check_min_max_handle_.AccessTensor(context);
|
||||
#endif
|
||||
|
||||
Tensor* output;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(0, input.shape(), &output));
|
||||
|
||||
FakeQuantWithMinMaxVarsFunctor<Device> functor;
|
||||
functor(context->eigen_device<Device>(), input.flat<float>(),
|
||||
min.scalar<float>(), max.scalar<float>(),
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
check_min_max->scalar<bool>(),
|
||||
#endif
|
||||
output->flat<float>());
|
||||
}
|
||||
|
||||
private:
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
PersistentTensor check_min_max_handle_;
|
||||
#endif
|
||||
};
|
||||
|
||||
// Implementation of FakeQuantWithMinMaxVarsGradientOp, see its documentation in
|
||||
// core/ops/array_ops.cc.
|
||||
template <typename Device>
|
||||
class FakeQuantWithMinMaxVarsGradientOp : public OpKernel {
|
||||
public:
|
||||
explicit FakeQuantWithMinMaxVarsGradientOp(OpKernelConstruction* context)
|
||||
: OpKernel::OpKernel(context) {
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_persistent(DT_BOOL, {},
|
||||
&check_min_max_handle_,
|
||||
nullptr));
|
||||
#endif
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
CHECK_EQ(4, context->num_inputs());
|
||||
const Tensor& gradient = context->input(0);
|
||||
const Tensor& input = context->input(1);
|
||||
OP_REQUIRES(context, input.IsSameSize(gradient),
|
||||
InvalidArgument("gradient and input must be the same size"));
|
||||
const Tensor& min = context->input(2);
|
||||
const Tensor& max = context->input(3);
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
Tensor* check_min_max = check_min_max_handle_.AccessTensor(context);
|
||||
#endif
|
||||
|
||||
Tensor* grad_wrt_input;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(0, input.shape(), &grad_wrt_input));
|
||||
|
||||
TensorShape scalar_shape;
|
||||
Tensor* grad_wrt_min;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(1, scalar_shape, &grad_wrt_min));
|
||||
|
||||
Tensor* grad_wrt_max;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(2, scalar_shape, &grad_wrt_max));
|
||||
|
||||
FakeQuantWithMinMaxVarsGradientFunctor<Device> functor;
|
||||
functor(context->eigen_device<Device>(), gradient.flat<float>(),
|
||||
input.flat<float>(), min.scalar<float>(), max.scalar<float>(),
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
check_min_max->scalar<bool>(),
|
||||
#endif
|
||||
grad_wrt_input->flat<float>(), grad_wrt_min->scalar<float>(),
|
||||
grad_wrt_max->scalar<float>());
|
||||
}
|
||||
|
||||
private:
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
PersistentTensor check_min_max_handle_;
|
||||
#endif
|
||||
};
|
||||
|
||||
// CPU registrations for the tensor-driven (min/max passed as inputs)
// fake-quant kernel and its gradient.
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxVars").Device(DEVICE_CPU),
    FakeQuantWithMinMaxVarsOp<CPUDevice>);
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxVarsGradient").Device(DEVICE_CPU),
    FakeQuantWithMinMaxVarsGradientOp<CPUDevice>);
|
||||
|
||||
#if GOOGLE_CUDA
// GPU functor specializations are only declared here; the registrations below
// keep the "min" and "max" inputs in host memory.
template <>
void FakeQuantWithMinMaxVarsFunctor<GPUDevice>::operator()(
    const GPUDevice& d,
    ConstFlat<float> inputs,
    ConstScalar<float> min,
    ConstScalar<float> max,
#ifndef FAKE_QUANT_NO_DEBUG
    Scalar<bool> check_min_max,
#endif
    Flat<float> output);
extern template struct FakeQuantWithMinMaxVarsFunctor<GPUDevice>;
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxVars")
        .Device(DEVICE_GPU)
        .HostMemory("min")
        .HostMemory("max"),
    FakeQuantWithMinMaxVarsOp<GPUDevice>);

template <>
void FakeQuantWithMinMaxVarsGradientFunctor<GPUDevice>::operator()(
    const GPUDevice& d,
    ConstFlat<float> gradients,
    ConstFlat<float> inputs,
    ConstScalar<float> min,
    ConstScalar<float> max,
#ifndef FAKE_QUANT_NO_DEBUG
    Scalar<bool> check_min_max,
#endif
    Flat<float> backprops_wrt_input,
    Scalar<float> backprop_wrt_min,
    Scalar<float> backprop_wrt_max);
extern template struct FakeQuantWithMinMaxVarsGradientFunctor<GPUDevice>;
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxVarsGradient")
        .Device(DEVICE_GPU)
        .HostMemory("min")
        .HostMemory("max"),
    FakeQuantWithMinMaxVarsGradientOp<GPUDevice>);
#endif  // GOOGLE_CUDA
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Implementation of FakeQuantWithMinMaxVarsPerChannelOp, see its documentation
|
||||
// in core/ops/array_ops.cc.
|
||||
template <typename Device>
|
||||
class FakeQuantWithMinMaxVarsPerChannelOp : public OpKernel {
|
||||
public:
|
||||
explicit FakeQuantWithMinMaxVarsPerChannelOp(OpKernelConstruction* context)
|
||||
: OpKernel::OpKernel(context) {
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_persistent(DT_BOOL, {},
|
||||
&check_min_max_handle_,
|
||||
nullptr));
|
||||
#endif
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
CHECK_EQ(3, context->num_inputs());
|
||||
const Tensor& input = context->input(0);
|
||||
const int depth = input.dim_size(input.dims() - 1); // last dimension size.
|
||||
const Tensor& min = context->input(1);
|
||||
OP_REQUIRES(context, min.dim_size(0) == depth,
|
||||
InvalidArgument("min has incorrect size, expected ", depth,
|
||||
" was ", min.dim_size(0)));
|
||||
const Tensor& max = context->input(2);
|
||||
OP_REQUIRES(context, max.dim_size(0) == depth,
|
||||
InvalidArgument("max has incorrect size, expected ", depth,
|
||||
" was ", max.dim_size(0)));
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
Tensor* check_min_max = check_min_max_handle_.AccessTensor(context);
|
||||
#endif
|
||||
|
||||
Tensor* output;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(0, input.shape(), &output));
|
||||
|
||||
switch (input.dims()) {
|
||||
case 4: {
|
||||
FakeQuant4WithMinMaxVarsPerChannelFunctor<Device> functor;
|
||||
functor(context->eigen_device<Device>(), input.dim_size(0),
|
||||
input.dim_size(1), input.dim_size(2), input.dim_size(3),
|
||||
input.flat<float>(), min.vec<float>(), max.vec<float>(),
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
check_min_max->scalar<bool>(),
|
||||
#endif
|
||||
output->flat<float>());
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
FakeQuant2WithMinMaxVarsPerChannelFunctor<Device> functor;
|
||||
functor(context->eigen_device<Device>(),
|
||||
input.dim_size(0), input.dim_size(1),
|
||||
input.flat<float>(), min.vec<float>(), max.vec<float>(),
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
check_min_max->scalar<bool>(),
|
||||
#endif
|
||||
output->flat<float>());
|
||||
break;
|
||||
}
|
||||
case 1: {
|
||||
FakeQuant1WithMinMaxVarsPerChannelFunctor<Device> functor;
|
||||
functor(context->eigen_device<Device>(),
|
||||
input.vec<float>(), min.vec<float>(), max.vec<float>(),
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
check_min_max->scalar<bool>(),
|
||||
#endif
|
||||
output->vec<float>());
|
||||
break;
|
||||
}
|
||||
default:
|
||||
context->SetStatus(InvalidArgument("Only inputs of dimensions 1, 2 or "
|
||||
"4 supported, was: ", input.dims()));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
PersistentTensor check_min_max_handle_;
|
||||
#endif
|
||||
};
|
||||
|
||||
// Implementation of FakeQuantWithMinMaxVarsPerChannelGradientOp, see its
|
||||
// documentation in core/ops/array_ops.cc.
|
||||
template <typename Device>
|
||||
class FakeQuantWithMinMaxVarsPerChannelGradientOp : public OpKernel {
|
||||
public:
|
||||
explicit FakeQuantWithMinMaxVarsPerChannelGradientOp(
|
||||
OpKernelConstruction* context) : OpKernel::OpKernel(context) {
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_persistent(DT_BOOL, {},
|
||||
&check_min_max_handle_,
|
||||
nullptr));
|
||||
#endif
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
CHECK_EQ(4, context->num_inputs());
|
||||
const Tensor& gradient = context->input(0);
|
||||
const Tensor& input = context->input(1);
|
||||
OP_REQUIRES(context, input.IsSameSize(gradient),
|
||||
InvalidArgument("gradient and input must be the same size"));
|
||||
const int depth = input.dim_size(input.dims() - 1); // last dimension size.
|
||||
const Tensor& min = context->input(2);
|
||||
OP_REQUIRES(context, min.dim_size(0) == depth,
|
||||
InvalidArgument("min has incorrect size, expected ", depth,
|
||||
" was ", min.dim_size(0)));
|
||||
const Tensor& max = context->input(3);
|
||||
OP_REQUIRES(context, max.dim_size(0) == depth,
|
||||
InvalidArgument("max has incorrect size, expected ", depth,
|
||||
" was ", max.dim_size(0)));
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
Tensor* check_min_max = check_min_max_handle_.AccessTensor(context);
|
||||
#endif
|
||||
|
||||
Tensor* grad_wrt_input;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(0, input.shape(), &grad_wrt_input));
|
||||
|
||||
TensorShape min_max_shape({input.dim_size(input.dims() - 1)});
|
||||
Tensor* grad_wrt_min;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(1, min_max_shape, &grad_wrt_min));
|
||||
|
||||
Tensor* grad_wrt_max;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(2, min_max_shape, &grad_wrt_max));
|
||||
|
||||
switch (input.dims()) {
|
||||
case 4: {
|
||||
FakeQuant4WithMinMaxVarsPerChannelGradientFunctor<Device> functor;
|
||||
functor(context->eigen_device<Device>(), input.dim_size(0),
|
||||
input.dim_size(1), input.dim_size(2), input.dim_size(3),
|
||||
gradient.flat<float>(), input.flat<float>(),
|
||||
min.vec<float>(), max.vec<float>(),
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
check_min_max->scalar<bool>(),
|
||||
#endif
|
||||
grad_wrt_input->flat<float>(),
|
||||
grad_wrt_min->vec<float>(), grad_wrt_max->vec<float>());
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
FakeQuant2WithMinMaxVarsPerChannelGradientFunctor<Device> functor;
|
||||
functor(context->eigen_device<Device>(),
|
||||
input.dim_size(0), input.dim_size(1),
|
||||
gradient.flat<float>(), input.flat<float>(),
|
||||
min.vec<float>(), max.vec<float>(),
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
check_min_max->scalar<bool>(),
|
||||
#endif
|
||||
grad_wrt_input->flat<float>(),
|
||||
grad_wrt_min->vec<float>(), grad_wrt_max->vec<float>());
|
||||
break;
|
||||
}
|
||||
case 1: {
|
||||
FakeQuant1WithMinMaxVarsPerChannelGradientFunctor<Device> functor;
|
||||
functor(context->eigen_device<Device>(),
|
||||
gradient.vec<float>(), input.vec<float>(),
|
||||
min.vec<float>(), max.vec<float>(),
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
check_min_max->scalar<bool>(),
|
||||
#endif
|
||||
grad_wrt_input->vec<float>(),
|
||||
grad_wrt_min->vec<float>(), grad_wrt_max->vec<float>());
|
||||
break;
|
||||
}
|
||||
default:
|
||||
context->SetStatus(InvalidArgument("Only inputs of dimensions 1, 2 or "
|
||||
"4 supported, was: ", input.dims()));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
PersistentTensor check_min_max_handle_;
|
||||
#endif
|
||||
};
|
||||
|
||||
// CPU registrations for the per-channel fake-quant kernel and its gradient.
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxVarsPerChannel").Device(DEVICE_CPU),
    FakeQuantWithMinMaxVarsPerChannelOp<CPUDevice>);
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxVarsPerChannelGradient").Device(DEVICE_CPU),
    FakeQuantWithMinMaxVarsPerChannelGradientOp<CPUDevice>);
|
||||
|
||||
#if GOOGLE_CUDA
// Declarations of the GPU specializations of the per-channel forward
// functors, one per supported input rank (1, 2 and 4).
template <>
void FakeQuant1WithMinMaxVarsPerChannelFunctor<GPUDevice>::operator()(
    const GPUDevice& d,
    ConstVec<float> inputs,
    ConstVec<float> min,
    ConstVec<float> max,
#ifndef FAKE_QUANT_NO_DEBUG
    Scalar<bool> check_min_max,
#endif
    Vec<float> outputs);
extern template struct FakeQuant1WithMinMaxVarsPerChannelFunctor<GPUDevice>;

template <>
void FakeQuant2WithMinMaxVarsPerChannelFunctor<GPUDevice>::operator()(
    const GPUDevice& d, const Index batch_size, const Index depth,
    ConstFlat<float> inputs,
    ConstFlat<float> min,
    ConstFlat<float> max,
#ifndef FAKE_QUANT_NO_DEBUG
    Scalar<bool> check_min_max,
#endif
    Flat<float> outputs);
extern template struct FakeQuant2WithMinMaxVarsPerChannelFunctor<GPUDevice>;

template <>
void FakeQuant4WithMinMaxVarsPerChannelFunctor<GPUDevice>::operator()(
    const GPUDevice& d, const Index batch_size, const Index height,
    const Index width, const Index depth,
    ConstFlat<float> inputs,
    ConstFlat<float> min,
    ConstFlat<float> max,
#ifndef FAKE_QUANT_NO_DEBUG
    Scalar<bool> check_min_max,
#endif
    Flat<float> outputs);
extern template struct FakeQuant4WithMinMaxVarsPerChannelFunctor<GPUDevice>;

// The GPU registrations keep the "min" and "max" inputs in host memory.
REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxVarsPerChannel")
        .Device(DEVICE_GPU)
        .HostMemory("min")
        .HostMemory("max"),
    FakeQuantWithMinMaxVarsPerChannelOp<GPUDevice>);

// Declarations of the GPU specializations of the per-channel gradient
// functors, again one per supported input rank.
template <>
void FakeQuant1WithMinMaxVarsPerChannelGradientFunctor<GPUDevice>::operator()(
    const GPUDevice& d,
    ConstVec<float> gradients,
    ConstVec<float> inputs,
    ConstVec<float> min,
    ConstVec<float> max,
#ifndef FAKE_QUANT_NO_DEBUG
    Scalar<bool> check_min_max,
#endif
    Vec<float> backprops_wrt_input,
    Vec<float> backprop_wrt_min,
    Vec<float> backprop_wrt_max);
extern template struct
    FakeQuant1WithMinMaxVarsPerChannelGradientFunctor<GPUDevice>;

template <>
void FakeQuant2WithMinMaxVarsPerChannelGradientFunctor<GPUDevice>::operator()(
    const GPUDevice& d, const Index batch_size, const Index depth,
    ConstFlat<float> gradients,
    ConstFlat<float> inputs,
    ConstVec<float> min,
    ConstVec<float> max,
#ifndef FAKE_QUANT_NO_DEBUG
    Scalar<bool> check_min_max,
#endif
    Flat<float> backprops_wrt_input,
    Vec<float> backprop_wrt_min,
    Vec<float> backprop_wrt_max);
extern template struct
    FakeQuant2WithMinMaxVarsPerChannelGradientFunctor<GPUDevice>;

template <>
void FakeQuant4WithMinMaxVarsPerChannelGradientFunctor<GPUDevice>::operator()(
    const GPUDevice& d, const Index batch_size, const Index height,
    const Index width, const Index depth,
    ConstFlat<float> gradients,
    ConstFlat<float> inputs,
    ConstVec<float> min,
    ConstVec<float> max,
#ifndef FAKE_QUANT_NO_DEBUG
    Scalar<bool> check_min_max,
#endif
    Flat<float> backprops_wrt_input,
    Vec<float> backprop_wrt_min,
    Vec<float> backprop_wrt_max);
extern template struct
    FakeQuant4WithMinMaxVarsPerChannelGradientFunctor<GPUDevice>;

REGISTER_KERNEL_BUILDER(
    Name("FakeQuantWithMinMaxVarsPerChannelGradient")
        .Device(DEVICE_GPU)
        .HostMemory("min")
        .HostMemory("max"),
    FakeQuantWithMinMaxVarsPerChannelGradientOp<GPUDevice>);
#endif  // GOOGLE_CUDA
|
||||
|
||||
} // namespace tensorflow
|
434
tensorflow/core/kernels/fake_quant_ops_functor.h
Normal file
434
tensorflow/core/kernels/fake_quant_ops_functor.h
Normal file
@ -0,0 +1,434 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_FAKE_QUANT_FUNCTOR_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_FAKE_QUANT_FUNCTOR_H_
|
||||
|
||||
#include <cmath>
#include <tuple>

#define EIGEN_STACK_ALLOCATION_LIMIT 0
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Index of the highest quantization level: the quantized grid has levels
// 0..kSteps (255 = 2^8 - 1).
static constexpr int kSteps = 255;
static constexpr float kStepsFloat = static_cast<float>(kSteps);

// Adjusts ("nudges") the range [min, max] so that real zero falls exactly on
// one of the kSteps + 1 quantization levels. This is required for e.g.
// zero-padding in convolutional layers.
//
// The step size is scale = (max - min) / kSteps. The zero point is the level
// index that real 0.0 maps to, rounded to the nearest integer and clamped to
// [0, kSteps]; the returned bounds are the real values of level 0 and level
// kSteps for that zero point.
//
// The zero point is kept as a float that always holds an integral value in
// [0, kSteps], so no float-to-integer conversion is needed. If the zero point
// cannot be computed (NaN, which happens when min == max == 0) it is treated
// as 0 instead of being converted to an integer, which would be undefined
// behaviour.
//
// Returns (nudged_min, nudged_max, nudged_scale). `Device` is not used by the
// computation.
template <typename Device>
std::tuple<float, float, float> Nudge(const float min, const float max) {
  const float scale = (max - min) / kStepsFloat;
  const float zero_point_from_min = 0.0f - min / scale;
  const float nudged_zero_point = [zero_point_from_min]() -> float {
    // The negated comparison sends NaN, as well as negative values, to level 0.
    if (!(zero_point_from_min >= 0.0f)) {
      return 0.0f;
    }
    if (zero_point_from_min > kStepsFloat) {
      return kStepsFloat;
    }
    return std::round(zero_point_from_min);
  }();

  const float nudged_min = (0.0f - nudged_zero_point) * scale;
  const float nudged_max = (kStepsFloat - nudged_zero_point) * scale;
  return std::make_tuple(nudged_min, nudged_max, scale);
}
|
||||
|
||||
template<typename T> using ConstScalar =
|
||||
typename tensorflow::TTypes<T>::ConstScalar;
|
||||
template<typename T> using Scalar = typename tensorflow::TTypes<T>::Scalar;
|
||||
template<typename T> using ConstVec = typename tensorflow::TTypes<T>::ConstVec;
|
||||
template<typename T> using Vec = typename tensorflow::TTypes<T>::Vec;
|
||||
template<typename T> using ConstFlat =
|
||||
typename tensorflow::TTypes<T>::ConstFlat;
|
||||
template<typename T> using Flat = typename tensorflow::TTypes<T>::Flat;
|
||||
|
||||
// Functor called by FakeQuantWithMinMaxArgsOp to do the work. Compiles both
|
||||
// for CPU and GPU.
|
||||
template <typename Device>
|
||||
struct FakeQuantWithMinMaxArgsFunctor {
|
||||
void operator()(const Device& d, ConstFlat<float> inputs,
|
||||
const float min, const float max, Flat<float> outputs) {
|
||||
eigen_assert(min <= 0.0f && "min should be <= 0.0");
|
||||
eigen_assert(max >= 0.0f && "max should be >= 0.0");
|
||||
eigen_assert(min < max && "min should be < max");
|
||||
|
||||
float nudged_min, nudged_max, nudged_scale;
|
||||
std::tie(nudged_min, nudged_max, nudged_scale) = Nudge<Device>(min, max);
|
||||
const float inv_nudged_scale = 1.0f / nudged_scale;
|
||||
|
||||
auto clamped = inputs.cwiseMin(nudged_max).cwiseMax(nudged_min);
|
||||
auto clamped_shifted = clamped - nudged_min;
|
||||
outputs.device(d) = (clamped_shifted * inv_nudged_scale + 0.5f).floor() *
|
||||
nudged_scale + nudged_min;
|
||||
}
|
||||
};
|
||||
|
||||
// Functor called by FakeQuantWithMinMaxArgsGradientOp to do the work. Compiles
|
||||
// both for CPU and GPU.
|
||||
template <typename Device>
|
||||
struct FakeQuantWithMinMaxArgsGradientFunctor {
|
||||
void operator()(const Device& d, ConstFlat<float> gradients,
|
||||
ConstFlat<float> inputs, const float min, const float max,
|
||||
Flat<float> backprops) {
|
||||
eigen_assert(min <= 0.0f && "min should be <= 0.0");
|
||||
eigen_assert(max >= 0.0f && "max should be >= 0.0");
|
||||
eigen_assert(min < max && "min should be < max");
|
||||
|
||||
float nudged_min, nudged_max, nudged_scale;
|
||||
std::tie(nudged_min, nudged_max, nudged_scale) = Nudge<Device>(min, max);
|
||||
|
||||
auto between_nudged_min_max = (inputs >= nudged_min && inputs <= nudged_max)
|
||||
.select(inputs.constant(1.0f), inputs.constant(0.0f));
|
||||
backprops.device(d) = gradients * between_nudged_min_max;
|
||||
}
|
||||
};
|
||||
|
||||
// Functor called by FakeQuantWithMinMaxVarsOp to do the work. Compiles both
|
||||
// for CPU and GPU.
|
||||
template <typename Device>
|
||||
struct FakeQuantWithMinMaxVarsFunctor {
|
||||
void operator()(const Device& d, ConstFlat<float> inputs,
|
||||
ConstScalar<float> min, ConstScalar<float> max,
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
Scalar<bool> check_min_max,
|
||||
#endif
|
||||
Flat<float> outputs) {
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
check_min_max.device(d) = (min <= 0.0f).all();
|
||||
eigen_assert(check_min_max() && "min should be <= 0.0 coeff-wise");
|
||||
check_min_max.device(d) = (max >= 0.0f).all();
|
||||
eigen_assert(check_min_max() >= 0.0f && "max should be >= 0.0 coeff-wise");
|
||||
check_min_max.device(d) = (min < max).all();
|
||||
eigen_assert(check_min_max() && "min should be < max coeff-wise");
|
||||
#endif
|
||||
|
||||
float nudged_min, nudged_max, nudged_scale;
|
||||
std::tie(nudged_min, nudged_max, nudged_scale) =
|
||||
Nudge<Device>(min(), max());
|
||||
const auto nudged_scale_repl = inputs.constant(nudged_scale);
|
||||
|
||||
const auto clamped = inputs.cwiseMin(nudged_max).cwiseMax(nudged_min);
|
||||
const auto clamped_shifted = clamped - nudged_min;
|
||||
outputs.device(d) = (clamped_shifted / nudged_scale_repl + 0.5f).floor() *
|
||||
nudged_scale_repl + nudged_min;
|
||||
}
|
||||
};
|
||||
|
||||
// Functor called by FakeQuantWithMinMaxVarsGradientOp to do the work. Compiles
|
||||
// both for CPU and GPU.
|
||||
template <typename Device>
|
||||
struct FakeQuantWithMinMaxVarsGradientFunctor {
|
||||
void operator()(const Device& d,
|
||||
ConstFlat<float> gradients, ConstFlat<float> inputs,
|
||||
ConstScalar<float> min, ConstScalar<float> max,
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
Scalar<bool> check_min_max,
|
||||
#endif
|
||||
Flat<float> backprops_wrt_input,
|
||||
Scalar<float> backprop_wrt_min,
|
||||
Scalar<float> backprop_wrt_max) {
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
check_min_max.device(d) = (min <= 0.0f).all();
|
||||
eigen_assert(check_min_max() && "min should be <= 0.0 coeff-wise");
|
||||
check_min_max.device(d) = (max >= 0.0f).all();
|
||||
eigen_assert(check_min_max() >= 0.0f && "max should be >= 0.0 coeff-wise");
|
||||
check_min_max.device(d) = (min < max).all();
|
||||
eigen_assert(check_min_max() && "min should be < max coeff-wise");
|
||||
#endif
|
||||
|
||||
float nudged_min, nudged_max, nudged_scale;
|
||||
std::tie(nudged_min, nudged_max, nudged_scale) =
|
||||
Nudge<Device>(min(), max());
|
||||
|
||||
const auto between_min_max = (inputs >= nudged_min && inputs <= nudged_max)
|
||||
.select(inputs.constant(1.0f), inputs.constant(0.0f));
|
||||
backprops_wrt_input.device(d) = gradients * between_min_max;
|
||||
|
||||
const auto below_min = (inputs < nudged_min)
|
||||
.select(inputs.constant(1.0f), inputs.constant(0.0f));
|
||||
backprop_wrt_min.device(d) = (gradients * below_min).sum();
|
||||
|
||||
const auto above_max = (inputs > nudged_max)
|
||||
.select(inputs.constant(1.0f), inputs.constant(0.0f));
|
||||
backprop_wrt_max.device(d) = (gradients * above_max).sum();
|
||||
}
|
||||
};
|
||||
|
||||
// Index type of TTypes<float>::ConstTensor; used for tensor dimensions and
// channel indices by the per-channel functors.
using Index = typename tensorflow::TTypes<float>::ConstTensor::Index;
|
||||
|
||||
// Functor called by FakeQuantWithMinMaxVarsPerChannelOp to do the work.
|
||||
// Compiles both for CPU and GPU.
|
||||
//
|
||||
// Already verified: inputs, outputs, min, max are of shape [d].
|
||||
template <typename Device>
|
||||
struct FakeQuant1WithMinMaxVarsPerChannelFunctor {
|
||||
void operator()(const Device& d, ConstVec<float> inputs,
|
||||
ConstVec<float> min, ConstVec<float> max,
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
Scalar<bool> check_min_max,
|
||||
#endif
|
||||
Vec<float> outputs) {
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
check_min_max.device(d) = (min <= 0.0f).all();
|
||||
eigen_assert(check_min_max() && "min should be <= 0.0 coeff-wise");
|
||||
check_min_max.device(d) = (max >= 0.0f).all();
|
||||
eigen_assert(check_min_max() >= 0.0f && "max should be >= 0.0 coeff-wise");
|
||||
check_min_max.device(d) = (min < max).all();
|
||||
eigen_assert(check_min_max() && "min should be < max coeff-wise");
|
||||
#endif
|
||||
|
||||
for (Index i = 0; i < min.size(); ++i) {
|
||||
float nudged_min, nudged_max, nudged_scale;
|
||||
std::tie(nudged_min, nudged_max, nudged_scale) =
|
||||
Nudge<Device>(min(i), max(i));
|
||||
const float clamped =
|
||||
std::max(std::min(inputs(i), nudged_max), nudged_min);
|
||||
const float clamped_shifted = clamped - nudged_min;
|
||||
|
||||
outputs(i) = std::round(clamped_shifted / nudged_scale) * nudged_scale +
|
||||
nudged_min;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Already verified: inputs, outputs are of shape [b, d], min, max are of shape
|
||||
// [d].
|
||||
template <typename Device>
|
||||
struct FakeQuant2WithMinMaxVarsPerChannelFunctor {
|
||||
void operator()(const Device& d, const Index batch_size, const Index depth,
|
||||
ConstFlat<float> inputs,
|
||||
ConstVec<float> min, ConstVec<float> max,
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
Scalar<bool> check_min_max,
|
||||
#endif
|
||||
Flat<float> outputs) {
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
check_min_max.device(d) = (min <= 0.0f).all();
|
||||
eigen_assert(check_min_max() && "min should be <= 0.0 coeff-wise");
|
||||
check_min_max.device(d) = (max >= 0.0f).all();
|
||||
eigen_assert(check_min_max() >= 0.0f && "max should be >= 0.0 coeff-wise");
|
||||
check_min_max.device(d) = (min < max).all();
|
||||
eigen_assert(check_min_max() && "min should be < max coeff-wise");
|
||||
#endif
|
||||
|
||||
Eigen::DSizes<Index, 2> restored(batch_size, depth);
|
||||
const auto inputs_restored = inputs.reshape(restored);
|
||||
for (Index i = 0; i < min.size(); ++i) {
|
||||
float nudged_min, nudged_max, nudged_scale;
|
||||
std::tie(nudged_min, nudged_max, nudged_scale) =
|
||||
Nudge<Device>(min(i), max(i));
|
||||
const auto clamped = inputs_restored.chip<1>(i)
|
||||
.cwiseMin(nudged_max).cwiseMax(nudged_min);
|
||||
const auto clamped_shifted = clamped - nudged_min;
|
||||
|
||||
outputs.reshape(restored).chip<1>(i).device(d) =
|
||||
(clamped_shifted / nudged_scale + 0.5f).floor() * nudged_scale +
|
||||
nudged_min;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Already verified: inputs, outputs are of shape [b, h, w, d], min, max are
|
||||
// of shape [d].
|
||||
template <typename Device>
|
||||
struct FakeQuant4WithMinMaxVarsPerChannelFunctor {
|
||||
void operator()(const Device& d, const Index batch_size, const Index height,
|
||||
const Index width, const Index depth,
|
||||
ConstFlat<float> inputs,
|
||||
ConstVec<float> min, ConstVec<float> max,
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
Scalar<bool> check_min_max,
|
||||
#endif
|
||||
Flat<float> outputs) {
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
check_min_max.device(d) = (min <= 0.0f).all();
|
||||
eigen_assert(check_min_max() && "min should be <= 0.0 coeff-wise");
|
||||
check_min_max.device(d) = (max >= 0.0f).all();
|
||||
eigen_assert(check_min_max() >= 0.0f && "max should be >= 0.0 coeff-wise");
|
||||
check_min_max.device(d) = (min < max).all();
|
||||
eigen_assert(check_min_max() && "min should be < max coeff-wise");
|
||||
#endif
|
||||
|
||||
Eigen::DSizes<Index, 4> restored(batch_size, height, width, depth);
|
||||
const auto inputs_restored = inputs.reshape(restored);
|
||||
for (Index i = 0; i < min.size(); ++i) {
|
||||
float nudged_min, nudged_max, nudged_scale;
|
||||
std::tie(nudged_min, nudged_max, nudged_scale) =
|
||||
Nudge<Device>(min(i), max(i));
|
||||
const auto clamped = inputs_restored.chip<3>(i)
|
||||
.cwiseMin(nudged_max).cwiseMax(nudged_min);
|
||||
const auto clamped_shifted = clamped - nudged_min;
|
||||
|
||||
outputs.reshape(restored).chip<3>(i).device(d) =
|
||||
(clamped_shifted / nudged_scale + 0.5f).floor() * nudged_scale +
|
||||
nudged_min;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Functor called by FakeQuantWithMinMaxVarsPerChannelGradientOp to do the work.
|
||||
// Compiles both for CPU and GPU.
|
||||
//
|
||||
// Already verified: gradients, inputs, outputs, min, max, backprops_wrt_input,
|
||||
// backprop_wrt_min, backprop_wrt_max are of shape [d].
|
||||
template <typename Device>
|
||||
struct FakeQuant1WithMinMaxVarsPerChannelGradientFunctor {
|
||||
void operator()(const Device& d,
|
||||
ConstVec<float> gradients, ConstVec<float> inputs,
|
||||
ConstVec<float> min, ConstVec<float> max,
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
Scalar<bool> check_min_max,
|
||||
#endif
|
||||
Vec<float> backprops_wrt_input, Vec<float> backprop_wrt_min,
|
||||
Vec<float> backprop_wrt_max) {
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
check_min_max.device(d) = (min <= 0.0f).all();
|
||||
eigen_assert(check_min_max() && "min should be <= 0.0 coeff-wise");
|
||||
check_min_max.device(d) = (max >= 0.0f).all();
|
||||
eigen_assert(check_min_max() >= 0.0f && "max should be >= 0.0 coeff-wise");
|
||||
check_min_max.device(d) = (min < max).all();
|
||||
eigen_assert(check_min_max() && "min should be < max coeff-wise");
|
||||
#endif
|
||||
|
||||
for (Index i = 0; i < min.size(); ++i) {
|
||||
float nudged_min, nudged_max, nudged_scale;
|
||||
std::tie(nudged_min, nudged_max, nudged_scale) =
|
||||
Nudge<Device>(min(i), max(i));
|
||||
|
||||
const bool between_min_max =
|
||||
inputs(i) >= nudged_min && inputs(i) <= nudged_max;
|
||||
backprops_wrt_input(i) = between_min_max ? gradients(i) : 0.0f;
|
||||
|
||||
const bool below_min = inputs(i) < nudged_min;
|
||||
backprop_wrt_min(i) = below_min ? gradients(i) : 0.0f;
|
||||
|
||||
const bool above_max = inputs(i) > nudged_max;
|
||||
backprop_wrt_max(i) = above_max ? gradients(i) : 0.0f;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Already verified: gradients, inputs, backprops_wrt_input are of shape [b, d],
|
||||
// min, max, backprop_wrt_min, backprop_wrt_max are of shape [d].
|
||||
template <typename Device>
|
||||
struct FakeQuant2WithMinMaxVarsPerChannelGradientFunctor {
|
||||
void operator()(const Device& d, const Index batch_size, const Index depth,
|
||||
ConstFlat<float> gradients, ConstFlat<float> inputs,
|
||||
ConstVec<float> min, ConstVec<float> max,
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
Scalar<bool> check_min_max,
|
||||
#endif
|
||||
Flat<float> backprops_wrt_input,
|
||||
Vec<float> backprop_wrt_min, Vec<float> backprop_wrt_max) {
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
check_min_max.device(d) = (min <= 0.0f).all();
|
||||
eigen_assert(check_min_max() && "min should be <= 0.0 coeff-wise");
|
||||
check_min_max.device(d) = (max >= 0.0f).all();
|
||||
eigen_assert(check_min_max() >= 0.0f && "max should be >= 0.0 coeff-wise");
|
||||
check_min_max.device(d) = (min < max).all();
|
||||
eigen_assert(check_min_max() && "min should be < max coeff-wise");
|
||||
#endif
|
||||
|
||||
Eigen::DSizes<Index, 2> restored(batch_size, depth);
|
||||
const auto gradients_restored = gradients.reshape(restored);
|
||||
const auto inputs_restored = inputs.reshape(restored);
|
||||
for (Index i = 0; i < min.size(); ++i) {
|
||||
float nudged_min, nudged_max, nudged_scale;
|
||||
std::tie(nudged_min, nudged_max, nudged_scale) =
|
||||
Nudge<Device>(min(i), max(i));
|
||||
const auto gradients_chip = gradients_restored.chip<1>(i);
|
||||
const auto inputs_chip = inputs_restored.chip<1>(i);
|
||||
|
||||
const auto between_min_max =
|
||||
(inputs_chip >= nudged_min && inputs_chip <= nudged_max)
|
||||
.select(inputs_chip.constant(1.0f), inputs_chip.constant(0.0f));
|
||||
backprops_wrt_input.reshape(restored).chip<1>(i).device(d) =
|
||||
gradients_chip * between_min_max;
|
||||
|
||||
const auto below_min = (inputs_chip < nudged_min)
|
||||
.select(inputs_chip.constant(1.0f), inputs_chip.constant(0.0f));
|
||||
Eigen::DSizes<Index, 1> reduce(0);
|
||||
backprop_wrt_min.chip<0>(i).device(d) =
|
||||
(gradients_chip * below_min).sum(reduce);
|
||||
|
||||
const auto above_max = (inputs_chip > nudged_max)
|
||||
.select(inputs_chip.constant(1.0f), inputs_chip.constant(0.0f));
|
||||
backprop_wrt_max.chip<0>(i).device(d) =
|
||||
(gradients_chip * above_max).sum(reduce);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Already verified: gradients, inputs, backprops_wrt_input are of shape
|
||||
// [b, h, w, d], min, max, backprop_wrt_min, backprop_wrt_max are of shape [d].
|
||||
template <typename Device>
|
||||
struct FakeQuant4WithMinMaxVarsPerChannelGradientFunctor {
|
||||
void operator()(const Device& d, const Index batch_size, const Index height,
|
||||
const Index width, const Index depth,
|
||||
ConstFlat<float> gradients, ConstFlat<float> inputs,
|
||||
ConstVec<float> min, ConstVec<float> max,
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
Scalar<bool> check_min_max,
|
||||
#endif
|
||||
Flat<float> backprops_wrt_input,
|
||||
Vec<float> backprop_wrt_min, Vec<float> backprop_wrt_max) {
|
||||
#ifndef FAKE_QUANT_NO_DEBUG
|
||||
check_min_max.device(d) = (min <= 0.0f).all();
|
||||
eigen_assert(check_min_max() && "min should be <= 0.0 coeff-wise");
|
||||
check_min_max.device(d) = (max >= 0.0f).all();
|
||||
eigen_assert(check_min_max() >= 0.0f && "max should be >= 0.0 coeff-wise");
|
||||
check_min_max.device(d) = (min < max).all();
|
||||
eigen_assert(check_min_max() && "min should be < max coeff-wise");
|
||||
#endif
|
||||
|
||||
Eigen::DSizes<Index, 4> restored(batch_size, height, width, depth);
|
||||
const auto gradients_restored = gradients.reshape(restored);
|
||||
const auto inputs_restored = inputs.reshape(restored);
|
||||
for (Index i = 0; i < min.size(); ++i) {
|
||||
float nudged_min, nudged_max, nudged_scale;
|
||||
std::tie(nudged_min, nudged_max, nudged_scale) =
|
||||
Nudge<Device>(min(i), max(i));
|
||||
const auto gradients_chip = gradients_restored.chip<3>(i);
|
||||
const auto inputs_chip = inputs_restored.chip<3>(i);
|
||||
|
||||
const auto between_min_max =
|
||||
(inputs_chip >= nudged_min && inputs_chip <= nudged_max)
|
||||
.select(inputs_chip.constant(1.0f), inputs_chip.constant(0.0f));
|
||||
backprops_wrt_input.reshape(restored).chip<3>(i).device(d) =
|
||||
gradients_chip * between_min_max;
|
||||
|
||||
const auto below_min = (inputs_chip < nudged_min)
|
||||
.select(inputs_chip.constant(1.0f), inputs_chip.constant(0.0f));
|
||||
Eigen::DSizes<Index, 3> reduce(0, 1, 2);
|
||||
backprop_wrt_min.chip<0>(i).device(d) =
|
||||
(gradients_chip * below_min).sum(reduce);
|
||||
|
||||
const auto above_max = (inputs_chip > nudged_max)
|
||||
.select(inputs_chip.constant(1.0f), inputs_chip.constant(0.0f));
|
||||
backprop_wrt_max.chip<0>(i).device(d) =
|
||||
(gradients_chip * above_max).sum(reduce);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_FAKE_QUANT_FUNCTOR_H_
|
41
tensorflow/core/kernels/fake_quant_ops_gpu.cu.cc
Normal file
41
tensorflow/core/kernels/fake_quant_ops_gpu.cu.cc
Normal file
@ -0,0 +1,41 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
#define FAKE_QUANT_NO_DEBUG
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
#include "tensorflow/core/kernels/fake_quant_ops_functor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
// Just instantiate GPU functor implementations.
|
||||
template struct FakeQuantWithMinMaxArgsFunctor<GPUDevice>;
|
||||
template struct FakeQuantWithMinMaxArgsGradientFunctor<GPUDevice>;
|
||||
template struct FakeQuantWithMinMaxVarsFunctor<GPUDevice>;
|
||||
template struct FakeQuantWithMinMaxVarsGradientFunctor<GPUDevice>;
|
||||
template struct FakeQuant1WithMinMaxVarsPerChannelFunctor<GPUDevice>;
|
||||
template struct FakeQuant2WithMinMaxVarsPerChannelFunctor<GPUDevice>;
|
||||
template struct FakeQuant4WithMinMaxVarsPerChannelFunctor<GPUDevice>;
|
||||
template struct FakeQuant1WithMinMaxVarsPerChannelGradientFunctor<GPUDevice>;
|
||||
template struct FakeQuant2WithMinMaxVarsPerChannelGradientFunctor<GPUDevice>;
|
||||
template struct FakeQuant4WithMinMaxVarsPerChannelGradientFunctor<GPUDevice>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
821
tensorflow/core/kernels/fake_quant_ops_test.cc
Normal file
821
tensorflow/core/kernels/fake_quant_ops_test.cc
Normal file
@ -0,0 +1,821 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using tensorflow::AllocatorAttributes;
|
||||
using tensorflow::DT_FLOAT;
|
||||
using tensorflow::NodeDefBuilder;
|
||||
using tensorflow::OpsTestBase;
|
||||
using tensorflow::Tensor;
|
||||
using tensorflow::TensorShape;
|
||||
using tensorflow::test::ExpectClose;
|
||||
using tensorflow::test::FillValues;
|
||||
|
||||
class QuantOpsTest : public OpsTestBase {
|
||||
protected:
|
||||
void AddRandomInput(const TensorShape& shape) {
|
||||
CHECK_GT(input_types_.size(), inputs_.size())
|
||||
<< "Adding more inputs than types; perhaps you need to call MakeOp";
|
||||
Tensor* input = new Tensor(device_->GetAllocator(AllocatorAttributes()),
|
||||
DT_FLOAT, shape);
|
||||
input->flat<float>().setRandom();
|
||||
tensors_.push_back(input);
|
||||
bool is_ref = IsRefType(input_types_[inputs_.size()]);
|
||||
if (is_ref) {
|
||||
CHECK_EQ(RemoveRefType(input_types_[inputs_.size()]), DT_FLOAT);
|
||||
inputs_.push_back({&lock_for_refs_, input});
|
||||
} else {
|
||||
CHECK_EQ(input_types_[inputs_.size()], DT_FLOAT);
|
||||
inputs_.push_back({nullptr, input});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(QuantOpsTest, WithArgsNoNudging) {
  // Requested range: [-10 + 0 / 4, -10 + 255 / 4], i.e. a scale of 1/4.
  // The zero point is exactly 40, so the range needs no nudging.
  // Representable values: -10.0, -9.75, ..., 53.75.
  TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxArgs")
                   .Input(FakeInput(DT_FLOAT))  // inputs
                   .Attr("min", -10.0f)
                   .Attr("max", 53.75f)
                   .Finalize(node_def()));
  TF_EXPECT_OK(InitOp());

  // Values to quantize: below, at and above both ends of the range.
  AddInputFromArray<float>(TensorShape({2, 3}),
                           {-10.1f, -10.0f, -9.9f, -9.75f, 53.75f, 53.8f});

  Tensor want(allocator(), DT_FLOAT, TensorShape({2, 3}));
  FillValues<float>(&want, {-10.0f, -10.0f, -10.0f, -9.75f, 53.75f, 53.75f});

  // Run the kernel under test and compare.
  TF_ASSERT_OK(RunOpKernel());
  Tensor* got = GetOutput(0);
  ExpectClose(want, *got);
}
|
||||
|
||||
TEST_F(QuantOpsTest, WithArgsNudgedZeroIs0) {
  // Requested range: [-0.4 / 4 + 0 / 4, -0.4 / 4 + 255 / 4], scale 1/4.
  // The zero point 0.4 is nudged to 0, giving the range [0.0; 63.75].
  // Representable values: 0.0, 0.25, 0.5, ..., 63.75.
  TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxArgs")
                   .Input(FakeInput(DT_FLOAT))  // inputs
                   .Attr("min", -0.1f)
                   .Attr("max", 63.65f)
                   .Finalize(node_def()));
  TF_EXPECT_OK(InitOp());

  // Values to quantize.
  AddInputFromArray<float>(TensorShape({2, 3}),
                           {-0.1f, 0.0f, 0.1f, 0.25f, 63.75f, 63.8f});

  Tensor want(allocator(), DT_FLOAT, TensorShape({2, 3}));
  FillValues<float>(&want, {0.0f, 0.0f, 0.0f, 0.25f, 63.75f, 63.75f});

  // Run the kernel under test and compare.
  TF_ASSERT_OK(RunOpKernel());
  Tensor* got = GetOutput(0);
  ExpectClose(want, *got);
}
|
||||
|
||||
TEST_F(QuantOpsTest, WithArgsNudgedZeroIs1) {
  // Requested range: [-0.5 / 4 + 0 / 4, -0.5 / 4 + 255 / 4], scale 1/4.
  // The zero point 0.5 is nudged to 1, giving the range [-0.25; 63.5].
  // Representable values: -0.25, 0.0, 0.25, ..., 63.5.
  TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxArgs")
                   .Input(FakeInput(DT_FLOAT))  // inputs
                   .Attr("min", -0.125f)
                   .Attr("max", 63.625f)
                   .Finalize(node_def()));
  TF_EXPECT_OK(InitOp());

  // Values to quantize.
  AddInputFromArray<float>(TensorShape({2, 3}),
                           {-0.26f, -0.25f, -0.24f, 0.0f, 63.5f, 63.6f});

  Tensor want(allocator(), DT_FLOAT, TensorShape({2, 3}));
  FillValues<float>(&want, {-0.25f, -0.25f, -0.25f, 0.0f, 63.5f, 63.5f});

  // Run the kernel under test and compare.
  TF_ASSERT_OK(RunOpKernel());
  Tensor* got = GetOutput(0);
  ExpectClose(want, *got);
}
|
||||
|
||||
TEST_F(QuantOpsTest, WithArgsNudgedZeroIs255) {
  // Requested range: [0.4 / 4 - 255 / 4, 0.4 / 4 + 0 / 4], scale 1/4.
  // The zero point 254.6 is nudged to 255, giving the range [-63.75; 0.0].
  // Representable values: -63.75, -63.5, -63.25, ..., 0.0.
  TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxArgs")
                   .Input(FakeInput(DT_FLOAT))  // inputs
                   .Attr("min", -63.65f)
                   .Attr("max", 0.1f)
                   .Finalize(node_def()));
  TF_EXPECT_OK(InitOp());

  // Values to quantize.
  AddInputFromArray<float>(TensorShape({2, 3}),
                           {-63.8f, -63.75f, -63.7f, -63.5f, 0.0f, 0.1f});

  Tensor want(allocator(), DT_FLOAT, TensorShape({2, 3}));
  FillValues<float>(&want, {-63.75f, -63.75f, -63.75f, -63.5f, 0.0f, 0.0f});

  // Run the kernel under test and compare.
  TF_ASSERT_OK(RunOpKernel());
  Tensor* got = GetOutput(0);
  ExpectClose(want, *got);
}
|
||||
|
||||
TEST_F(QuantOpsTest, WithArgsGradient) {
  // Requested range: [-0.5 / 4 + 0 / 4, -0.5 / 4 + 255 / 4], scale 1/4.
  // The zero point 0.5 is nudged to 1, giving the range [-0.25; 63.5].
  // Representable values: -0.25, 0.0, 0.25, ..., 63.5.
  TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxArgsGradient")
                   .Input(FakeInput(DT_FLOAT))  // gradient
                   .Input(FakeInput(DT_FLOAT))  // inputs
                   .Attr("min", -0.125f)
                   .Attr("max", 63.625f)
                   .Finalize(node_def()));
  TF_EXPECT_OK(InitOp());

  // Input 0: random upstream gradients.
  AddRandomInput(TensorShape({2, 3}));
  // Input 1: the values that were quantized in the forward pass.
  AddInputFromArray<float>(TensorShape({2, 3}),
                           {-0.26f, -0.25f, -0.24f, 0.0f, 63.5f, 63.6f});

  // Run the kernel under test.
  TF_ASSERT_OK(RunOpKernel());

  // The gradient passes through unchanged for inputs inside the nudged range
  // and is zeroed for the first and last input, which lie outside it.
  Tensor* got = GetOutput(0);
  auto upstream = GetInput(0).flat<float>();
  Tensor want(allocator(), DT_FLOAT, TensorShape({2, 3}));
  FillValues<float>(&want, {0.0f, upstream(1), upstream(2),
                            upstream(3), upstream(4), 0.0f});
  ExpectClose(want, *got);
}
|
||||
|
||||
TEST_F(QuantOpsTest, WithVarsNoNudging) {
  // Requested range: [-10 + 0 / 4, -10 + 255 / 4], i.e. a scale of 1/4.
  // The zero point is exactly 40, so the range needs no nudging.
  // Representable values: -10.0, -9.75, ..., 53.75.
  TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVars")
                   .Input(FakeInput(DT_FLOAT))  // inputs
                   .Input(FakeInput(DT_FLOAT))  // min
                   .Input(FakeInput(DT_FLOAT))  // max
                   .Finalize(node_def()));
  TF_EXPECT_OK(InitOp());

  // Values to quantize, followed by the scalar min and max.
  AddInputFromArray<float>(TensorShape({2, 3}),
                           {-10.1f, -10.0f, -9.9f, -9.75f, 53.75f, 53.8f});
  AddInputFromArray<float>(TensorShape({}), {-10.0f});  // min
  AddInputFromArray<float>(TensorShape({}), {53.75f});  // max

  Tensor want(allocator(), DT_FLOAT, TensorShape({2, 3}));
  FillValues<float>(&want, {-10.0f, -10.0f, -10.0f, -9.75f, 53.75f, 53.75f});

  // Run the kernel under test and compare.
  TF_ASSERT_OK(RunOpKernel());
  Tensor* got = GetOutput(0);
  ExpectClose(want, *got);
}
|
||||
|
||||
TEST_F(QuantOpsTest, WithVarsNudgedZeroIs0) {
  // Requested range: [-0.4 / 4 + 0 / 4, -0.4 / 4 + 255 / 4], scale 1/4.
  // The zero point 0.4 is nudged to 0, giving the range [0.0; 63.75].
  // Representable values: 0.0, 0.25, 0.5, ..., 63.75.
  TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVars")
                   .Input(FakeInput(DT_FLOAT))  // inputs
                   .Input(FakeInput(DT_FLOAT))  // min
                   .Input(FakeInput(DT_FLOAT))  // max
                   .Finalize(node_def()));
  TF_EXPECT_OK(InitOp());

  // Values to quantize, followed by the scalar min and max.
  AddInputFromArray<float>(TensorShape({2, 3}),
                           {-0.1f, 0.0f, 0.1f, 0.25f, 63.75f, 63.8f});
  AddInputFromArray<float>(TensorShape({}), {-0.1f});   // min
  AddInputFromArray<float>(TensorShape({}), {63.65f});  // max

  Tensor want(allocator(), DT_FLOAT, TensorShape({2, 3}));
  FillValues<float>(&want, {0.0f, 0.0f, 0.0f, 0.25f, 63.75f, 63.75f});

  // Run the kernel under test and compare.
  TF_ASSERT_OK(RunOpKernel());
  Tensor* got = GetOutput(0);
  ExpectClose(want, *got);
}
|
||||
|
||||
TEST_F(QuantOpsTest, WithVarsNudgedZeroIs1) {
  // Requested range: [-0.5 / 4 + 0 / 4, -0.5 / 4 + 255 / 4], scale 1/4.
  // The zero point 0.5 is nudged to 1, giving the range [-0.25; 63.5].
  // Representable values: -0.25, 0.0, 0.25, ..., 63.5.
  TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVars")
                   .Input(FakeInput(DT_FLOAT))  // inputs
                   .Input(FakeInput(DT_FLOAT))  // min
                   .Input(FakeInput(DT_FLOAT))  // max
                   .Finalize(node_def()));
  TF_EXPECT_OK(InitOp());

  // Values to quantize, followed by the scalar min and max.
  AddInputFromArray<float>(TensorShape({2, 3}),
                           {-0.26f, -0.25f, -0.24f, 0.0f, 63.5f, 63.6f});
  AddInputFromArray<float>(TensorShape({}), {-0.125f});  // min
  AddInputFromArray<float>(TensorShape({}), {63.625f});  // max

  Tensor want(allocator(), DT_FLOAT, TensorShape({2, 3}));
  FillValues<float>(&want, {-0.25f, -0.25f, -0.25f, 0.0f, 63.5f, 63.5f});

  // Run the kernel under test and compare.
  TF_ASSERT_OK(RunOpKernel());
  Tensor* got = GetOutput(0);
  ExpectClose(want, *got);
}
|
||||
|
||||
TEST_F(QuantOpsTest, WithVarsGradient) {
  // Requested range: [-0.5 / 4 + 0 / 4, -0.5 / 4 + 255 / 4], scale 1/4.
  // The zero point 0.5 is nudged to 1, giving the range [-0.25; 63.5].
  // Representable values: -0.25, 0.0, 0.25, ..., 63.5.
  TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsGradient")
                   .Input(FakeInput(DT_FLOAT))  // gradients
                   .Input(FakeInput(DT_FLOAT))  // inputs
                   .Input(FakeInput(DT_FLOAT))  // min
                   .Input(FakeInput(DT_FLOAT))  // max
                   .Finalize(node_def()));
  TF_EXPECT_OK(InitOp());

  // Input 0: random upstream gradients.
  AddRandomInput(TensorShape({2, 3}));
  // Input 1: the values that were quantized in the forward pass.
  AddInputFromArray<float>(TensorShape({2, 3}),
                           {-0.26f, -0.25f, -0.24f, 0.0f, 63.5f, 63.6f});
  AddInputFromArray<float>(TensorShape({}), {-0.125f});  // min
  AddInputFromArray<float>(TensorShape({}), {63.625f});  // max

  // Run the kernel under test.
  TF_ASSERT_OK(RunOpKernel());

  auto upstream = GetInput(0).flat<float>();

  // Output 0: gradient w.r.t. the inputs; zero outside the nudged range.
  Tensor* got_wrt_input = GetOutput(0);
  Tensor want_wrt_input(allocator(), DT_FLOAT, TensorShape({2, 3}));
  FillValues<float>(&want_wrt_input, {0.0f, upstream(1), upstream(2),
                                      upstream(3), upstream(4), 0.0f});
  ExpectClose(want_wrt_input, *got_wrt_input);

  // Output 1: gradient w.r.t. min; only the first input is below the range.
  Tensor* got_wrt_min = GetOutput(1);
  Tensor want_wrt_min(allocator(), DT_FLOAT, TensorShape({}));
  want_wrt_min.flat<float>()(0) = upstream(0);
  ExpectClose(want_wrt_min, *got_wrt_min);

  // Output 2: gradient w.r.t. max; only the last input is above the range.
  Tensor* got_wrt_max = GetOutput(2);
  Tensor want_wrt_max(allocator(), DT_FLOAT, TensorShape({}));
  want_wrt_max.flat<float>()(0) = upstream(5);
  ExpectClose(want_wrt_max, *got_wrt_max);
}
|
||||
|
||||
TEST_F(QuantOpsTest, WithVarsPerChannelDim1NudgedZeroIs0) {
  // Requested ranges: [-0.4 / 4 + 0 / 4, -0.4 / 4 + 255 / 4], scale 1/4.
  // The zero point 0.4 is nudged to 0, giving the ranges [0.0; 63.75].
  // Representable values: 0.0, 0.25, 0.5, ..., 63.75.
  TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannel")
                   .Input(FakeInput(DT_FLOAT))  // inputs
                   .Input(FakeInput(DT_FLOAT))  // min
                   .Input(FakeInput(DT_FLOAT))  // max
                   .Finalize(node_def()));
  TF_EXPECT_OK(InitOp());

  // Values to quantize, followed by one min and one max per channel.
  AddInputFromArray<float>(TensorShape({4}), {-0.1f, 0.0f, 63.75f, 63.8f});
  AddInputFromArray<float>(TensorShape({4}),
                           {-0.1f, -0.1f, -0.1f, -0.1f});  // min
  AddInputFromArray<float>(TensorShape({4}),
                           {63.65f, 63.65f, 63.65f, 63.65f});  // max

  Tensor want(allocator(), DT_FLOAT, TensorShape({4}));
  FillValues<float>(&want, {0.0f, 0.0f, 63.75f, 63.75f});

  // Run the kernel under test and compare.
  TF_ASSERT_OK(RunOpKernel());
  Tensor* got = GetOutput(0);
  ExpectClose(want, *got);
}
|
||||
|
||||
TEST_F(QuantOpsTest, WithVarsPerChannelDim1NudgedZeroIs1) {
  // Requested ranges: [-0.5 / 4 + 0 / 4, -0.5 / 4 + 255 / 4], scale 1/4.
  // The zero point 0.5 is nudged to 1, giving the ranges [-0.25; 63.5].
  // Representable values: -0.25, 0.0, 0.25, ..., 63.5.
  TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannel")
                   .Input(FakeInput(DT_FLOAT))  // inputs
                   .Input(FakeInput(DT_FLOAT))  // min
                   .Input(FakeInput(DT_FLOAT))  // max
                   .Finalize(node_def()));
  TF_EXPECT_OK(InitOp());

  // Values to quantize, followed by one min and one max per channel.
  AddInputFromArray<float>(TensorShape({4}), {-0.26f, -0.25f, -0.24f, 63.6f});
  AddInputFromArray<float>(TensorShape({4}),
                           {-0.125f, -0.125f, -0.125f, -0.125f});  // min
  AddInputFromArray<float>(TensorShape({4}),
                           {63.625f, 63.625f, 63.625f, 63.625f});  // max

  Tensor want(allocator(), DT_FLOAT, TensorShape({4}));
  FillValues<float>(&want, {-0.25f, -0.25f, -0.25f, 63.5f});

  // Run the kernel under test and compare.
  TF_ASSERT_OK(RunOpKernel());
  Tensor* got = GetOutput(0);
  ExpectClose(want, *got);
}
|
||||
|
||||
TEST_F(QuantOpsTest, WithVarsPerChannelDim2NudgedZeroIs0) {
  // Requested ranges: [-0.4 / 4 + 0 / 4, -0.4 / 4 + 255 / 4], scale 1/4.
  // The zero point 0.4 is nudged to 0, giving the ranges [0.0; 63.75].
  // Representable values: 0.0, 0.25, 0.5, ..., 63.75.
  TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannel")
                   .Input(FakeInput(DT_FLOAT))  // inputs
                   .Input(FakeInput(DT_FLOAT))  // min
                   .Input(FakeInput(DT_FLOAT))  // max
                   .Finalize(node_def()));
  TF_EXPECT_OK(InitOp());

  // A [2, 3] batch to quantize, followed by one min and max per column.
  AddInputFromArray<float>(TensorShape({2, 3}),
                           {-0.1f, 0.0f, 0.1f,
                            0.25f, 63.75f, 63.8f});
  AddInputFromArray<float>(TensorShape({3}), {-0.1f, -0.1f, -0.1f});  // min
  AddInputFromArray<float>(TensorShape({3}),
                           {63.65f, 63.65f, 63.65f});  // max

  Tensor want(allocator(), DT_FLOAT, TensorShape({2, 3}));
  FillValues<float>(&want, {0.0f, 0.0f, 0.0f,
                            0.25f, 63.75f, 63.75f});

  // Run the kernel under test and compare.
  TF_ASSERT_OK(RunOpKernel());
  Tensor* got = GetOutput(0);
  ExpectClose(want, *got);
}
|
||||
|
||||
TEST_F(QuantOpsTest, WithVarsPerChannelDim2NudgedZeroIs1) {
  // Requested ranges: [-0.5 / 4 + 0 / 4, -0.5 / 4 + 255 / 4], scale 1/4.
  // The zero point 0.5 is nudged to 1, giving the ranges [-0.25; 63.5].
  // Representable values: -0.25, 0.0, 0.25, ..., 63.5.
  TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannel")
                   .Input(FakeInput(DT_FLOAT))  // inputs
                   .Input(FakeInput(DT_FLOAT))  // min
                   .Input(FakeInput(DT_FLOAT))  // max
                   .Finalize(node_def()));
  TF_EXPECT_OK(InitOp());

  // A [2, 3] batch to quantize, followed by one min and max per column.
  AddInputFromArray<float>(TensorShape({2, 3}),
                           {-0.26f, -0.25f, -0.24f,
                            0.0f, 63.5f, 63.6f});
  AddInputFromArray<float>(TensorShape({3}),
                           {-0.125f, -0.125f, -0.125f});  // min
  AddInputFromArray<float>(TensorShape({3}),
                           {63.625f, 63.625f, 63.625f});  // max

  Tensor want(allocator(), DT_FLOAT, TensorShape({2, 3}));
  FillValues<float>(&want, {-0.25f, -0.25f, -0.25f,
                            0.0f, 63.5f, 63.5f});

  // Run the kernel under test and compare.
  TF_ASSERT_OK(RunOpKernel());
  Tensor* got = GetOutput(0);
  ExpectClose(want, *got);
}
|
||||
|
||||
// Forward per-channel fake quantization of a [1, 2, 3, 4] input for ranges
// whose zero point nudges down to 0.
TEST_F(QuantOpsTest, WithVarsPerChannelDim4NudgedZeroIs0) {
  // Each channel requests [-0.1, 63.65] = [-0.4 / 4, -0.4 / 4 + 255 / 4]:
  // scale 1/4 with a fractional zero point of 0.4, which is nudged to 0.
  // Nudged range: [0.0; 63.75]; quantized grid: 0.0, 0.25, 0.5, ..., 63.75.
  const TensorShape data_shape({1, 2, 3, 4});
  const TensorShape channel_shape({4});

  TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannel")
                   .Input(FakeInput(DT_FLOAT))  // inputs
                   .Input(FakeInput(DT_FLOAT))  // min
                   .Input(FakeInput(DT_FLOAT))  // max
                   .Finalize(node_def()));
  TF_EXPECT_OK(InitOp());

  // Values to quantize; the last dimension is the channel.
  AddInputFromArray<float>(data_shape,
                           {-0.1f,   0.0f,   0.1f,   0.25f,
                             0.5f,   0.75f,  1.0f,   1.25f,
                             1.5f,   1.75f,  2.0f,   2.25f,

                            63.0f,  63.25f, 63.5f,  63.7f,
                            63.75f, 63.8f,  63.9f, 100.0f,
                           100.0f, 100.0f, 100.0f, 1000.0f});
  // Per-channel lower bounds.
  AddInputFromArray<float>(channel_shape, {-0.1f, -0.1f, -0.1f, -0.1f});
  // Per-channel upper bounds.
  AddInputFromArray<float>(channel_shape, {63.65f, 63.65f, 63.65f, 63.65f});

  // Run the kernel under test.
  TF_ASSERT_OK(RunOpKernel());

  Tensor* quantized = GetOutput(0);
  Tensor want(allocator(), DT_FLOAT, data_shape);
  FillValues<float>(&want,
                    { 0.0f,   0.0f,   0.0f,   0.25f,
                      0.5f,   0.75f,  1.0f,   1.25f,
                      1.5f,   1.75f,  2.0f,   2.25f,

                     63.0f,  63.25f, 63.5f,  63.75f,
                     63.75f, 63.75f, 63.75f, 63.75f,
                     63.75f, 63.75f, 63.75f, 63.75f});
  ExpectClose(want, *quantized);
}
|
||||
|
||||
// Forward per-channel fake quantization of a [1, 2, 3, 4] input for ranges
// whose zero point nudges up to 1.
TEST_F(QuantOpsTest, WithVarsPerChannelDim4NudgedZeroIs1) {
  // Each channel requests [-0.125, 63.625] = [-0.5 / 4, -0.5 / 4 + 255 / 4]:
  // scale 1/4 with a fractional zero point of 0.5, which is nudged to 1.
  // Nudged range: [-0.25; 63.5]; quantized grid: -0.25, 0.0, 0.25, ..., 63.5.
  const TensorShape data_shape({1, 2, 3, 4});
  const TensorShape channel_shape({4});

  TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannel")
                   .Input(FakeInput(DT_FLOAT))  // inputs
                   .Input(FakeInput(DT_FLOAT))  // min
                   .Input(FakeInput(DT_FLOAT))  // max
                   .Finalize(node_def()));
  TF_EXPECT_OK(InitOp());

  // Values to quantize; the last dimension is the channel.
  AddInputFromArray<float>(data_shape,
                           {-0.3f,  -0.25f, -0.2f,   0.0f,
                             0.25f,  0.5f,   0.75f,  1.0f,
                             1.25f,  1.5f,   1.75f,  2.0f,

                            63.0f,  63.25f, 63.4f,  63.5f,
                            63.6f,  63.7f, 100.0f, 100.0f,
                           100.0f, 100.0f, 100.0f, 1000.0f});
  // Per-channel lower bounds.
  AddInputFromArray<float>(channel_shape,
                           {-0.125f, -0.125f, -0.125f, -0.125f});
  // Per-channel upper bounds.
  AddInputFromArray<float>(channel_shape,
                           {63.625f, 63.625f, 63.625f, 63.625f});

  // Run the kernel under test.
  TF_ASSERT_OK(RunOpKernel());

  Tensor* quantized = GetOutput(0);
  Tensor want(allocator(), DT_FLOAT, data_shape);
  FillValues<float>(&want,
                    {-0.25f, -0.25f, -0.25f,  0.0f,
                      0.25f,  0.5f,   0.75f,  1.0f,
                      1.25f,  1.5f,   1.75f,  2.0f,

                     63.0f,  63.25f, 63.5f,  63.5f,
                     63.5f,  63.5f,  63.5f,  63.5f,
                     63.5f,  63.5f,  63.5f,  63.5f});
  ExpectClose(want, *quantized);
}
|
||||
|
||||
// Per-channel gradient on a rank-1 input of 4 channels for ranges whose zero
// point nudges down to 0.
TEST_F(QuantOpsTest, WithVarsPerChannelDim1GradientNudgedZeroIs0) {
  // Each channel requests [-0.1, 63.65] = [-0.4 / 4, -0.4 / 4 + 255 / 4]:
  // scale 1/4 with a fractional zero point of 0.4, which is nudged to 0.
  // Nudged range: [0.0; 63.75]; quantized grid: 0.0, 0.25, 0.5, ..., 63.75.
  const TensorShape vec_shape({4});

  TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannelGradient")
                   .Input(FakeInput(DT_FLOAT))  // gradients
                   .Input(FakeInput(DT_FLOAT))  // inputs
                   .Input(FakeInput(DT_FLOAT))  // min
                   .Input(FakeInput(DT_FLOAT))  // max
                   .Finalize(node_def()));
  TF_EXPECT_OK(InitOp());

  // Incoming gradients (random).
  AddRandomInput(vec_shape);
  // Forward-pass inputs: below range, at both range ends, above range.
  AddInputFromArray<float>(vec_shape, {-0.1f, 0.0f, 63.75f, 63.8f});
  // Per-channel lower bounds.
  AddInputFromArray<float>(vec_shape, {-0.1f, -0.1f, -0.1f, -0.1f});
  // Per-channel upper bounds.
  AddInputFromArray<float>(vec_shape, {63.65f, 63.65f, 63.65f, 63.65f});

  // Run the kernel under test.
  TF_ASSERT_OK(RunOpKernel());

  auto upstream = GetInput(0).flat<float>();

  // In-range elements pass their gradient through to the inputs.
  Tensor* got_wrt_input = GetOutput(0);
  Tensor want_wrt_input(allocator(), DT_FLOAT, vec_shape);
  FillValues<float>(&want_wrt_input,
                    {0.0f, upstream(1), upstream(2), 0.0f});
  ExpectClose(want_wrt_input, *got_wrt_input);

  // Elements below the range send their gradient to min.
  Tensor* got_wrt_min = GetOutput(1);
  Tensor want_wrt_min(allocator(), DT_FLOAT, vec_shape);
  FillValues<float>(&want_wrt_min, {upstream(0), 0.0f, 0.0f, 0.0f});
  ExpectClose(want_wrt_min, *got_wrt_min);

  // Elements above the range send their gradient to max.
  Tensor* got_wrt_max = GetOutput(2);
  Tensor want_wrt_max(allocator(), DT_FLOAT, vec_shape);
  FillValues<float>(&want_wrt_max, {0.0f, 0.0f, 0.0f, upstream(3)});
  ExpectClose(want_wrt_max, *got_wrt_max);
}
|
||||
|
||||
// Per-channel gradient on a rank-1 input of 4 channels for ranges whose zero
// point nudges up to 1.
TEST_F(QuantOpsTest, WithVarsPerChannelDim1GradientNudgedZeroIs1) {
  // Each channel requests [-0.125, 63.625] = [-0.5 / 4, -0.5 / 4 + 255 / 4]:
  // scale 1/4 with a fractional zero point of 0.5, which is nudged to 1.
  // Nudged range: [-0.25; 63.5]; quantized grid: -0.25, 0.0, 0.25, ..., 63.5.
  const TensorShape vec_shape({4});

  TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannelGradient")
                   .Input(FakeInput(DT_FLOAT))  // gradients
                   .Input(FakeInput(DT_FLOAT))  // inputs
                   .Input(FakeInput(DT_FLOAT))  // min
                   .Input(FakeInput(DT_FLOAT))  // max
                   .Finalize(node_def()));
  TF_EXPECT_OK(InitOp());

  // Incoming gradients (random).
  AddRandomInput(vec_shape);
  // Forward-pass inputs: below range, at both range ends, above range.
  AddInputFromArray<float>(vec_shape, {-0.3f, -0.25f, 63.5f, 63.6f});
  // Per-channel lower bounds.
  AddInputFromArray<float>(vec_shape,
                           {-0.125f, -0.125f, -0.125f, -0.125f});
  // Per-channel upper bounds.
  AddInputFromArray<float>(vec_shape,
                           {63.625f, 63.625f, 63.625f, 63.625f});

  // Run the kernel under test.
  TF_ASSERT_OK(RunOpKernel());

  auto upstream = GetInput(0).flat<float>();

  // In-range elements pass their gradient through to the inputs.
  Tensor* got_wrt_input = GetOutput(0);
  Tensor want_wrt_input(allocator(), DT_FLOAT, vec_shape);
  FillValues<float>(&want_wrt_input,
                    {0.0f, upstream(1), upstream(2), 0.0f});
  ExpectClose(want_wrt_input, *got_wrt_input);

  // Elements below the range send their gradient to min.
  Tensor* got_wrt_min = GetOutput(1);
  Tensor want_wrt_min(allocator(), DT_FLOAT, vec_shape);
  FillValues<float>(&want_wrt_min, {upstream(0), 0.0f, 0.0f, 0.0f});
  ExpectClose(want_wrt_min, *got_wrt_min);

  // Elements above the range send their gradient to max.
  Tensor* got_wrt_max = GetOutput(2);
  Tensor want_wrt_max(allocator(), DT_FLOAT, vec_shape);
  FillValues<float>(&want_wrt_max, {0.0f, 0.0f, 0.0f, upstream(3)});
  ExpectClose(want_wrt_max, *got_wrt_max);
}
|
||||
|
||||
// Per-channel gradient on a [2, 3] input for ranges whose zero point nudges
// down to 0.
TEST_F(QuantOpsTest, WithVarsPerChannelDim2GradientNudgedZeroIs0) {
  // Each channel requests [-0.1, 63.65] = [-0.4 / 4, -0.4 / 4 + 255 / 4]:
  // scale 1/4 with a fractional zero point of 0.4, which is nudged to 0.
  // Nudged range: [0.0; 63.75]; quantized grid: 0.0, 0.25, 0.5, ..., 63.75.
  const TensorShape data_shape({2, 3});
  const TensorShape channel_shape({3});

  TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannelGradient")
                   .Input(FakeInput(DT_FLOAT))  // gradients
                   .Input(FakeInput(DT_FLOAT))  // inputs
                   .Input(FakeInput(DT_FLOAT))  // min
                   .Input(FakeInput(DT_FLOAT))  // max
                   .Finalize(node_def()));
  TF_EXPECT_OK(InitOp());

  // Incoming gradients (random).
  AddRandomInput(data_shape);
  // Forward-pass inputs; the last dimension is the channel.
  AddInputFromArray<float>(data_shape,
                           {-0.1f,  0.0f,   0.1f,
                            0.25f, 63.75f, 63.8f});
  // Per-channel lower bounds.
  AddInputFromArray<float>(channel_shape, {-0.1f, -0.1f, -0.1f});
  // Per-channel upper bounds.
  AddInputFromArray<float>(channel_shape, {63.65f, 63.65f, 63.65f});

  // Run the kernel under test.
  TF_ASSERT_OK(RunOpKernel());

  auto upstream = GetInput(0).flat<float>();

  // Only the first element (below range) and the last (above range) are
  // blocked from the input gradient.
  Tensor* got_wrt_input = GetOutput(0);
  Tensor want_wrt_input(allocator(), DT_FLOAT, data_shape);
  FillValues<float>(&want_wrt_input,
                    {0.0f,        upstream(1), upstream(2),
                     upstream(3), upstream(4), 0.0f});
  ExpectClose(want_wrt_input, *got_wrt_input);

  // Channel 0 holds the single below-range element.
  Tensor* got_wrt_min = GetOutput(1);
  Tensor want_wrt_min(allocator(), DT_FLOAT, channel_shape);
  FillValues<float>(&want_wrt_min, {upstream(0), 0.0f, 0.0f});
  ExpectClose(want_wrt_min, *got_wrt_min);

  // Channel 2 holds the single above-range element.
  Tensor* got_wrt_max = GetOutput(2);
  Tensor want_wrt_max(allocator(), DT_FLOAT, channel_shape);
  FillValues<float>(&want_wrt_max, {0.0f, 0.0f, upstream(5)});
  ExpectClose(want_wrt_max, *got_wrt_max);
}
|
||||
|
||||
// Per-channel gradient on a [2, 3] input for ranges whose zero point nudges
// up to 1.
TEST_F(QuantOpsTest, WithVarsPerChannelDim2GradientNudgedZeroIs1) {
  // Each channel requests [-0.125, 63.625] = [-0.5 / 4, -0.5 / 4 + 255 / 4]:
  // scale 1/4 with a fractional zero point of 0.5, which is nudged to 1.
  // Nudged range: [-0.25; 63.5]; quantized grid: -0.25, 0.0, 0.25, ..., 63.5.
  const TensorShape data_shape({2, 3});
  const TensorShape channel_shape({3});

  TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannelGradient")
                   .Input(FakeInput(DT_FLOAT))  // gradients
                   .Input(FakeInput(DT_FLOAT))  // inputs
                   .Input(FakeInput(DT_FLOAT))  // min
                   .Input(FakeInput(DT_FLOAT))  // max
                   .Finalize(node_def()));
  TF_EXPECT_OK(InitOp());

  // Incoming gradients (random).
  AddRandomInput(data_shape);
  // Forward-pass inputs; the last dimension is the channel.
  AddInputFromArray<float>(data_shape,
                           {-0.3f, -0.25f, -0.2f,
                             0.0f, 63.5f,  63.6f});
  // Per-channel lower bounds.
  AddInputFromArray<float>(channel_shape, {-0.125f, -0.125f, -0.125f});
  // Per-channel upper bounds.
  AddInputFromArray<float>(channel_shape, {63.625f, 63.625f, 63.625f});

  // Run the kernel under test.
  TF_ASSERT_OK(RunOpKernel());

  auto upstream = GetInput(0).flat<float>();

  // Only the first element (below range) and the last (above range) are
  // blocked from the input gradient.
  Tensor* got_wrt_input = GetOutput(0);
  Tensor want_wrt_input(allocator(), DT_FLOAT, data_shape);
  FillValues<float>(&want_wrt_input,
                    {0.0f,        upstream(1), upstream(2),
                     upstream(3), upstream(4), 0.0f});
  ExpectClose(want_wrt_input, *got_wrt_input);

  // Channel 0 holds the single below-range element.
  Tensor* got_wrt_min = GetOutput(1);
  Tensor want_wrt_min(allocator(), DT_FLOAT, channel_shape);
  FillValues<float>(&want_wrt_min, {upstream(0), 0.0f, 0.0f});
  ExpectClose(want_wrt_min, *got_wrt_min);

  // Channel 2 holds the single above-range element.
  Tensor* got_wrt_max = GetOutput(2);
  Tensor want_wrt_max(allocator(), DT_FLOAT, channel_shape);
  FillValues<float>(&want_wrt_max, {0.0f, 0.0f, upstream(5)});
  ExpectClose(want_wrt_max, *got_wrt_max);
}
|
||||
|
||||
// Per-channel gradient on a [1, 2, 3, 4] input for ranges whose zero point
// nudges down to 0.
TEST_F(QuantOpsTest, WithVarsPerChannelDim4GradientNudgedZeroIs0) {
  // Each channel requests [-0.1, 63.65] = [-0.4 / 4, -0.4 / 4 + 255 / 4]:
  // scale 1/4 with a fractional zero point of 0.4, which is nudged to 0.
  // Nudged range: [0.0; 63.75]; quantized grid: 0.0, 0.25, 0.5, ..., 63.75.
  const TensorShape data_shape({1, 2, 3, 4});
  const TensorShape channel_shape({4});

  TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannelGradient")
                   .Input(FakeInput(DT_FLOAT))  // gradients
                   .Input(FakeInput(DT_FLOAT))  // inputs
                   .Input(FakeInput(DT_FLOAT))  // min
                   .Input(FakeInput(DT_FLOAT))  // max
                   .Finalize(node_def()));
  TF_EXPECT_OK(InitOp());

  // Incoming gradients (random).
  AddRandomInput(data_shape);
  // Forward-pass inputs. Every row of 4 channels repeats the same pattern:
  // channel 0 below range, channels 1-2 at the range ends, channel 3 above.
  AddInputFromArray<float>(data_shape,
                           {-0.1f, 0.0f, 63.75f, 63.8f,
                            -0.1f, 0.0f, 63.75f, 63.8f,
                            -0.1f, 0.0f, 63.75f, 63.8f,

                            -0.1f, 0.0f, 63.75f, 63.8f,
                            -0.1f, 0.0f, 63.75f, 63.8f,
                            -0.1f, 0.0f, 63.75f, 63.8f});
  // Per-channel lower bounds.
  AddInputFromArray<float>(channel_shape, {-0.1f, -0.1f, -0.1f, -0.1f});
  // Per-channel upper bounds.
  AddInputFromArray<float>(channel_shape, {63.65f, 63.65f, 63.65f, 63.65f});

  // Run the kernel under test.
  TF_ASSERT_OK(RunOpKernel());

  auto upstream = GetInput(0).flat<float>();

  // Channels 1 and 2 are in range, so only they keep their gradient.
  Tensor* got_wrt_input = GetOutput(0);
  Tensor want_wrt_input(allocator(), DT_FLOAT, data_shape);
  FillValues<float>(
      &want_wrt_input,
      {0.0f, upstream(1),  upstream(2),  0.0f,
       0.0f, upstream(5),  upstream(6),  0.0f,
       0.0f, upstream(9),  upstream(10), 0.0f,

       0.0f, upstream(13), upstream(14), 0.0f,
       0.0f, upstream(17), upstream(18), 0.0f,
       0.0f, upstream(21), upstream(22), 0.0f});
  ExpectClose(want_wrt_input, *got_wrt_input);

  // Channel 0 is always below range: its gradients accumulate into min.
  const float below_range_total = upstream(0) + upstream(4) + upstream(8) +
                                  upstream(12) + upstream(16) + upstream(20);
  Tensor* got_wrt_min = GetOutput(1);
  Tensor want_wrt_min(allocator(), DT_FLOAT, channel_shape);
  FillValues<float>(&want_wrt_min, {below_range_total, 0.0f, 0.0f, 0.0f});
  ExpectClose(want_wrt_min, *got_wrt_min);

  // Channel 3 is always above range: its gradients accumulate into max.
  const float above_range_total = upstream(3) + upstream(7) + upstream(11) +
                                  upstream(15) + upstream(19) + upstream(23);
  Tensor* got_wrt_max = GetOutput(2);
  Tensor want_wrt_max(allocator(), DT_FLOAT, channel_shape);
  FillValues<float>(&want_wrt_max, {0.0f, 0.0f, 0.0f, above_range_total});
  ExpectClose(want_wrt_max, *got_wrt_max);
}
|
||||
|
||||
// Per-channel gradient on a [1, 2, 3, 4] input for ranges whose zero point
// nudges up to 1.
TEST_F(QuantOpsTest, WithVarsPerChannelDim4GradientNudgedZeroIs1) {
  // Each channel requests [-0.125, 63.625] = [-0.5 / 4, -0.5 / 4 + 255 / 4]:
  // scale 1/4 with a fractional zero point of 0.5, which is nudged to 1.
  // Nudged range: [-0.25; 63.5]; quantized grid: -0.25, 0.0, 0.25, ..., 63.5.
  const TensorShape data_shape({1, 2, 3, 4});
  const TensorShape channel_shape({4});

  TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannelGradient")
                   .Input(FakeInput(DT_FLOAT))  // gradients
                   .Input(FakeInput(DT_FLOAT))  // inputs
                   .Input(FakeInput(DT_FLOAT))  // min
                   .Input(FakeInput(DT_FLOAT))  // max
                   .Finalize(node_def()));
  TF_EXPECT_OK(InitOp());

  // Incoming gradients (random).
  AddRandomInput(data_shape);
  // Forward-pass inputs. Every row of 4 channels repeats the same pattern:
  // channel 0 below range, channels 1-2 at the range ends, channel 3 above.
  AddInputFromArray<float>(data_shape,
                           {-0.3f, -0.25f, 63.5f, 63.6f,
                            -0.3f, -0.25f, 63.5f, 63.6f,
                            -0.3f, -0.25f, 63.5f, 63.6f,

                            -0.3f, -0.25f, 63.5f, 63.6f,
                            -0.3f, -0.25f, 63.5f, 63.6f,
                            -0.3f, -0.25f, 63.5f, 63.6f});
  // Per-channel lower bounds.
  AddInputFromArray<float>(channel_shape,
                           {-0.125f, -0.125f, -0.125f, -0.125f});
  // Per-channel upper bounds.
  AddInputFromArray<float>(channel_shape,
                           {63.625f, 63.625f, 63.625f, 63.625f});

  // Run the kernel under test.
  TF_ASSERT_OK(RunOpKernel());

  auto upstream = GetInput(0).flat<float>();

  // Channels 1 and 2 are in range, so only they keep their gradient.
  Tensor* got_wrt_input = GetOutput(0);
  Tensor want_wrt_input(allocator(), DT_FLOAT, data_shape);
  FillValues<float>(&want_wrt_input,
                    {0.0f, upstream(1),  upstream(2),  0.0f,
                     0.0f, upstream(5),  upstream(6),  0.0f,
                     0.0f, upstream(9),  upstream(10), 0.0f,

                     0.0f, upstream(13), upstream(14), 0.0f,
                     0.0f, upstream(17), upstream(18), 0.0f,
                     0.0f, upstream(21), upstream(22), 0.0f});
  ExpectClose(want_wrt_input, *got_wrt_input);

  // Channel 0 is always below range: its gradients accumulate into min.
  const float below_range_total = upstream(0) + upstream(4) + upstream(8) +
                                  upstream(12) + upstream(16) + upstream(20);
  Tensor* got_wrt_min = GetOutput(1);
  Tensor want_wrt_min(allocator(), DT_FLOAT, channel_shape);
  FillValues<float>(&want_wrt_min, {below_range_total, 0.0f, 0.0f, 0.0f});
  ExpectClose(want_wrt_min, *got_wrt_min);

  // Channel 3 is always above range: its gradients accumulate into max.
  const float above_range_total = upstream(3) + upstream(7) + upstream(11) +
                                  upstream(15) + upstream(19) + upstream(23);
  Tensor* got_wrt_max = GetOutput(2);
  Tensor want_wrt_max(allocator(), DT_FLOAT, channel_shape);
  FillValues<float>(&want_wrt_max, {0.0f, 0.0f, 0.0f, above_range_total});
  ExpectClose(want_wrt_max, *got_wrt_max);
}
|
||||
|
||||
} // namespace tensorflow
|
@ -4383,6 +4383,117 @@ output_min: This value is copied from input_min.
|
||||
output_max: This value is copied from input_max.
|
||||
)Doc");
|
||||
|
||||
// Registers the FakeQuantWithMinMaxArgs op: one float input (`inputs`), one
// float output (`outputs`), and the clamping range supplied as float attrs
// `min` and `max`, which default to -6.0 and 6.0.
REGISTER_OP("FakeQuantWithMinMaxArgs")
|
||||
.Attr("min: float = -6.0")
|
||||
.Attr("max: float = 6.0")
|
||||
.Input("inputs: float")
|
||||
.Output("outputs: float")
|
||||
.Doc(R"doc(
|
||||
Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type.
|
||||
|
||||
Attributes [min; max] define the clamping range for the 'inputs' data. Op
|
||||
divides this range into 255 steps (total of 256 values), then replaces each
|
||||
'inputs' value with the closest of the quantized step values.
|
||||
|
||||
Quantization is called fake since the output is still in floating point.
|
||||
)doc");
|
||||
|
||||
// Registers the gradient op for FakeQuantWithMinMaxArgs: float inputs
// `gradients` and `inputs`, one float output `backprops`.
// The op declares its own `min`/`max` attrs (defaults -6.0 / 6.0).
// NOTE(review): presumably these must be set to the forward op's attrs by
// whoever builds the gradient node -- confirm against the gradient function.
REGISTER_OP("FakeQuantWithMinMaxArgsGradient")
|
||||
.Attr("min: float = -6.0")
|
||||
.Attr("max: float = 6.0")
|
||||
.Input("gradients: float")
|
||||
.Input("inputs: float")
|
||||
.Output("backprops: float")
|
||||
.Doc(R"doc(
|
||||
Compute gradients for a FakeQuantWithMinMaxArgs operation.
|
||||
|
||||
gradients: Backpropagated gradients above the FakeQuantWithMinMaxArgs operation.
|
||||
inputs: Values passed as inputs to the FakeQuantWithMinMaxArgs operation.
|
||||
backprops: Backpropagated gradients below the FakeQuantWithMinMaxArgs operation:
|
||||
`gradients * (inputs >= min && inputs <= max)`.
|
||||
)doc");
|
||||
|
||||
// Registers the FakeQuantWithMinMaxVars op: float `inputs` in, float
// `outputs` out. Unlike FakeQuantWithMinMaxArgs, the clamping range is passed
// as float inputs `min` and `max` rather than as attrs.
REGISTER_OP("FakeQuantWithMinMaxVars")
|
||||
.Input("inputs: float")
|
||||
.Input("min: float")
|
||||
.Input("max: float")
|
||||
.Output("outputs: float")
|
||||
.Doc(R"doc(
|
||||
Fake-quantize the 'inputs' tensor of type float and shape `[b, h, w, d]` via
|
||||
global float scalars `min` and `max` to 'outputs' tensor of same shape as
|
||||
`inputs`.
|
||||
|
||||
[min; max] is the clamping range for the 'inputs' data. Op divides this range
|
||||
into 255 steps (total of 256 values), then replaces each 'inputs' value with the
|
||||
closest of the quantized step values.
|
||||
|
||||
This operation has a gradient and thus allows for training `min` and `max` values.
|
||||
)doc");
|
||||
|
||||
// Registers the gradient op for FakeQuantWithMinMaxVars: four float inputs
// (`gradients`, `inputs`, `min`, `max`) and three float outputs, one gradient
// each for the forward op's inputs, min and max.
REGISTER_OP("FakeQuantWithMinMaxVarsGradient")
|
||||
.Input("gradients: float")
|
||||
.Input("inputs: float")
|
||||
.Input("min: float")
|
||||
.Input("max: float")
|
||||
.Output("backprops_wrt_input: float")
|
||||
.Output("backprop_wrt_min: float")
|
||||
.Output("backprop_wrt_max: float")
|
||||
.Doc(R"doc(
|
||||
Compute gradients for a FakeQuantWithMinMaxVars operation.
|
||||
|
||||
gradients: Backpropagated gradients above the FakeQuantWithMinMaxVars operation.
|
||||
inputs: Values passed as inputs to the FakeQuantWithMinMaxVars operation.
|
||||
min, max: Quantization interval, scalar floats.
|
||||
backprops_wrt_input: Backpropagated gradients w.r.t. inputs:
|
||||
`gradients * (inputs >= min && inputs <= max)`.
|
||||
backprop_wrt_min: Backpropagated gradients w.r.t. min parameter:
|
||||
`sum(gradients * (inputs < min))`.
|
||||
backprop_wrt_max: Backpropagated gradients w.r.t. max parameter:
|
||||
`sum(gradients * (inputs > max))`.
|
||||
)doc");
|
||||
|
||||
// Registers the FakeQuantWithMinMaxVarsPerChannel op: float `inputs` in,
// float `outputs` out, with the clamping range passed as float inputs `min`
// and `max` (documented below as one value per depth channel).
REGISTER_OP("FakeQuantWithMinMaxVarsPerChannel")
    .Input("inputs: float")
    .Input("min: float")
    .Input("max: float")
    .Output("outputs: float")
    .Doc(R"doc(
Fake-quantize the 'inputs' tensor of type float and one of the shapes: `[d]`,
`[b, d]`, `[b, h, w, d]` via per-channel floats `min` and `max` of shape `[d]`
to 'outputs' tensor of same shape as `inputs`.

[min; max] is the clamping range for the 'inputs' data in the corresponding
depth channel. Op divides this range into 255 steps (total of 256 values), then
replaces each 'inputs' value with the closest of the quantized step values.

This operation has a gradient and thus allows for training `min` and `max` values.
)doc");
|
||||
|
||||
// Registers the gradient op for FakeQuantWithMinMaxVarsPerChannel: four float
// inputs (`gradients`, `inputs`, `min`, `max`) and three float outputs, one
// gradient each for the forward op's inputs, min and max.
REGISTER_OP("FakeQuantWithMinMaxVarsPerChannelGradient")
    .Input("gradients: float")
    .Input("inputs: float")
    .Input("min: float")
    .Input("max: float")
    .Output("backprops_wrt_input: float")
    .Output("backprop_wrt_min: float")
    .Output("backprop_wrt_max: float")
    .Doc(R"doc(
Compute gradients for a FakeQuantWithMinMaxVarsPerChannel operation.

gradients: Backpropagated gradients above the FakeQuantWithMinMaxVarsPerChannel
operation, shape one of: `[d]`, `[b, d]`, `[b, h, w, d]`.
inputs: Values passed as inputs to the FakeQuantWithMinMaxVarsPerChannel
operation, shape same as `gradients`.
min, max: Quantization interval, floats of shape `[d]`.
backprops_wrt_input: Backpropagated gradients w.r.t. inputs, shape same as
`inputs`:
`gradients * (inputs >= min && inputs <= max)`.
backprop_wrt_min: Backpropagated gradients w.r.t. min parameter, shape `[d]`:
`sum_per_d(gradients * (inputs < min))`.
backprop_wrt_max: Backpropagated gradients w.r.t. max parameter, shape `[d]`:
`sum_per_d(gradients * (inputs > max))`.
)doc");
|
||||
|
||||
// Deprecated op registrations:
|
||||
|
||||
// The following can be deleted after 10mar2017.
|
||||
|
@ -1905,7 +1905,6 @@ def _EditDistanceShape(op):
|
||||
return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[2, 5])
|
||||
|
||||
|
||||
# The remaining ops do not change the shape of their inputs.
|
||||
@ops.RegisterShape("Quantize")
|
||||
@ops.RegisterShape("Dequantize")
|
||||
def _QuantizeDequantizeShape(op):
|
||||
@ -1914,6 +1913,45 @@ def _QuantizeDequantizeShape(op):
|
||||
return common_shapes.unchanged_shape(op)
|
||||
|
||||
|
||||
@ops.RegisterShape("FakeQuantWithMinMaxArgs")
def _FakeQuantWithMinMaxArgsShape(op):
  """Returns the single output shape of FakeQuantWithMinMaxArgs.

  The op is elementwise, so the output takes the shape of input 0 unchanged.
  """
  input_shape = op.inputs[0].get_shape()
  return [input_shape]
|
||||
|
||||
|
||||
@ops.RegisterGradient("FakeQuantWithMinMaxArgs")
def _FakeQuantWithMinMaxArgsGradient(op, grad):
  """Gradient for FakeQuantWithMinMaxArgs op.

  Args:
    op: The forward FakeQuantWithMinMaxArgs operation.
    grad: Gradient with respect to the forward op's output.

  Returns:
    Gradient with respect to the forward op's `inputs`.
  """
  # The gradient op has its own `min`/`max` attrs defaulting to [-6, 6].
  # Forward the attrs of the op being differentiated; otherwise a non-default
  # clamping range would be silently replaced by the defaults.
  return fake_quant_with_min_max_args_gradient(
      grad,
      op.inputs[0],
      min=op.get_attr("min"),
      max=op.get_attr("max"))
|
||||
|
||||
|
||||
@ops.RegisterShape("FakeQuantWithMinMaxVars")
def _FakeQuantWithMinMaxVarsShape(op):
  """Returns the single output shape of FakeQuantWithMinMaxVars.

  The output takes the shape of input 0 (`inputs`) unchanged.
  """
  input_shape = op.inputs[0].get_shape()
  return [input_shape]
|
||||
|
||||
|
||||
@ops.RegisterGradient("FakeQuantWithMinMaxVars")
def _FakeQuantWithMinMaxVarsGradient(op, grad):
  """Builds the gradient node for a FakeQuantWithMinMaxVars op.

  Passes the incoming gradient together with the forward op's three inputs
  (inputs, min, max) to the dedicated gradient op and returns its result.
  """
  inputs = op.inputs[0]
  range_min = op.inputs[1]
  range_max = op.inputs[2]
  return fake_quant_with_min_max_vars_gradient(grad, inputs, range_min,
                                               range_max)
|
||||
|
||||
|
||||
@ops.RegisterShape("FakeQuantWithMinMaxVarsPerChannel")
def _FakeQuantWithMinMaxVarsPerChannelShape(op):
  """Returns the single output shape of FakeQuantWithMinMaxVarsPerChannel.

  The output takes the shape of input 0 (`inputs`) unchanged.
  """
  input_shape = op.inputs[0].get_shape()
  return [input_shape]
|
||||
|
||||
|
||||
@ops.RegisterGradient("FakeQuantWithMinMaxVarsPerChannel")
def _FakeQuantWithMinMaxVarsPerChannelGradient(op, grad):
  """Builds the gradient node for a FakeQuantWithMinMaxVarsPerChannel op.

  Passes the incoming gradient together with the forward op's three inputs
  (inputs, min, max) to the dedicated gradient op and returns its result.
  """
  inputs = op.inputs[0]
  range_min = op.inputs[1]
  range_max = op.inputs[2]
  return fake_quant_with_min_max_vars_per_channel_gradient(
      grad, inputs, range_min, range_max)
|
||||
|
||||
|
||||
ops.RegisterShape("ExtractImagePatches")(common_shapes.call_cpp_shape_fn)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user