From 9fb15ea28bc7ba713fb7745d60336d7a9a8f89a6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 24 Oct 2016 13:35:38 -0800 Subject: [PATCH] Add FakeQuant ops and kernels for use with quantized training. Change: 137081182 --- tensorflow/core/BUILD | 1 + tensorflow/core/kernels/BUILD | 34 + tensorflow/core/kernels/fake_quant_ops.cc | 580 +++++++++++++ .../core/kernels/fake_quant_ops_functor.h | 434 +++++++++ .../core/kernels/fake_quant_ops_gpu.cu.cc | 41 + .../core/kernels/fake_quant_ops_test.cc | 821 ++++++++++++++++++ tensorflow/core/ops/array_ops.cc | 111 +++ tensorflow/python/ops/array_ops.py | 40 +- 8 files changed, 2061 insertions(+), 1 deletion(-) create mode 100644 tensorflow/core/kernels/fake_quant_ops.cc create mode 100644 tensorflow/core/kernels/fake_quant_ops_functor.h create mode 100644 tensorflow/core/kernels/fake_quant_ops_gpu.cu.cc create mode 100644 tensorflow/core/kernels/fake_quant_ops_test.cc diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index a2a998cf4dc..0845028b5b7 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -521,6 +521,7 @@ cc_library( "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:ctc_ops", "//tensorflow/core/kernels:data_flow", + "//tensorflow/core/kernels:fake_quant_ops", "//tensorflow/core/kernels:function_ops", "//tensorflow/core/kernels:image", "//tensorflow/core/kernels:io", diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 5e90ac885bd..b31f92c22e9 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -563,6 +563,24 @@ tf_cc_test( ], ) +tf_cc_test( + name = "fake_quant_ops_test", + size = "small", + srcs = ["fake_quant_ops_test.cc"], + deps = [ + ":fake_quant_ops", + ":ops_testutil", + ":ops_util", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_cc_test( name = "fused_batch_norm_op_test", size = "small", @@ -1710,6 +1728,22 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "fake_quant_ops", + srcs = ["fake_quant_ops.cc"], + hdrs = ["fake_quant_ops_functor.h"], + gpu_srcs = [ + "fake_quant_ops_gpu.cu.cc", + "fake_quant_ops_functor.h", + ], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//third_party/eigen3", + ], + alwayslink = 1, +) + tf_kernel_library( name = "fused_batch_norm_util", gpu_srcs = [ diff --git a/tensorflow/core/kernels/fake_quant_ops.cc b/tensorflow/core/kernels/fake_quant_ops.cc new file mode 100644 index 00000000000..41f9c218437 --- /dev/null +++ b/tensorflow/core/kernels/fake_quant_ops.cc @@ -0,0 +1,580 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#ifdef GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA + +#define FAKE_QUANT_NO_DEBUG + +#include "tensorflow/core/kernels/fake_quant_ops_functor.h" + +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/protobuf.h" + +using tensorflow::BinaryElementWiseOp; +using tensorflow::DEVICE_CPU; +#if GOOGLE_CUDA +using tensorflow::DEVICE_GPU; +#endif +using tensorflow::DT_BOOL; +using tensorflow::OpKernel; +using tensorflow::OpKernelConstruction; +using tensorflow::OpKernelContext; +using tensorflow::PersistentTensor; +using tensorflow::Tensor; +using tensorflow::TensorShape; +using tensorflow::TTypes; // NOLINT This is needed in CUDA mode, do not remove. +using tensorflow::UnaryElementWiseOp; +using tensorflow::errors::InvalidArgument; + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +// ----------------------------------------------------------------------------- +// Implementation of FakeQuantWithMinMaxArgsOp, see its documentation in +// core/ops/array_ops.cc. +template +class FakeQuantWithMinMaxArgsOp + : public UnaryElementWiseOp> { + public: + typedef UnaryElementWiseOp> Base; + explicit FakeQuantWithMinMaxArgsOp(OpKernelConstruction* context) + : Base::UnaryElementWiseOp(context) { + OP_REQUIRES_OK(context, context->GetAttr("min", &min_)); + OP_REQUIRES_OK(context, context->GetAttr("max", &max_)); + OP_REQUIRES(context, min_ < max_, + InvalidArgument("min has to be smaller than max, was: ", min_, + " >= ", max_)); + } + + void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { + FakeQuantWithMinMaxArgsFunctor functor; + functor(context->eigen_device(), input.flat(), min_, max_, + output->flat()); + } + private: + float min_; + float max_; +}; + +// Implementation of FakeQuantWithMinMaxArgsGradientOp, see its documentation in +// core/ops/array_ops.cc. +template +class FakeQuantWithMinMaxArgsGradientOp + : public BinaryElementWiseOp> { + public: + typedef BinaryElementWiseOp> + Base; + explicit FakeQuantWithMinMaxArgsGradientOp(OpKernelConstruction* context) + : Base::BinaryElementWiseOp(context) { + OP_REQUIRES_OK(context, context->GetAttr("min", &min_)); + OP_REQUIRES_OK(context, context->GetAttr("max", &max_)); + OP_REQUIRES(context, min_ < max_, + InvalidArgument("min has to be smaller than max, was: ", min_, + " >= ", max_)); + } + + template + void Operate(OpKernelContext* context, const Tensor& gradient, + const Tensor& input, Tensor* output) { + OperateNoTemplate(context, gradient, input, output); + } + + void OperateNoTemplate(OpKernelContext* context, const Tensor& gradient, + const Tensor& input, Tensor* output) { + OP_REQUIRES(context, input.IsSameSize(gradient), + InvalidArgument("gradient and input must be the same size")); + FakeQuantWithMinMaxArgsGradientFunctor functor; + functor(context->eigen_device(), gradient.flat(), + input.flat(), min_, max_, output->flat()); + } + private: + float min_; + float max_; +}; + +REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxArgs").Device(DEVICE_CPU), + FakeQuantWithMinMaxArgsOp); +REGISTER_KERNEL_BUILDER( + Name("FakeQuantWithMinMaxArgsGradient").Device(DEVICE_CPU), + FakeQuantWithMinMaxArgsGradientOp); + +#if GOOGLE_CUDA +typedef Eigen::GpuDevice GPUDevice; + +// Forward declarations for functor specializations for GPU. +template <> +void FakeQuantWithMinMaxArgsFunctor::operator()( + const GPUDevice& d, + typename TTypes::ConstFlat inputs, + const float min, const float max, + typename TTypes::Flat outputs); +extern template struct FakeQuantWithMinMaxArgsFunctor; +REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxArgs").Device(DEVICE_GPU), + FakeQuantWithMinMaxArgsOp); + +template <> +void FakeQuantWithMinMaxArgsGradientFunctor::operator()( + const GPUDevice& d, + typename TTypes::ConstFlat gradients, + typename TTypes::ConstFlat inputs, + const float min, const float max, + typename TTypes::Flat backprops); +REGISTER_KERNEL_BUILDER( + Name("FakeQuantWithMinMaxArgsGradient").Device(DEVICE_GPU), + FakeQuantWithMinMaxArgsGradientOp); +#endif // GOOGLE_CUDA + +// ----------------------------------------------------------------------------- +// Implementation of FakeQuantWithMinMaxVarsOp, see its documentation in +// core/ops/array_ops.cc. +template +class FakeQuantWithMinMaxVarsOp : public OpKernel { + public: + explicit FakeQuantWithMinMaxVarsOp(OpKernelConstruction* context) + : OpKernel::OpKernel(context) { +#ifndef FAKE_QUANT_NO_DEBUG + OP_REQUIRES_OK(context, + context->allocate_persistent(DT_BOOL, {}, + &check_min_max_handle_, + nullptr)); +#endif + } + + void Compute(OpKernelContext* context) override { + CHECK_EQ(3, context->num_inputs()); + const Tensor& input = context->input(0); + const Tensor& min = context->input(1); + const Tensor& max = context->input(2); +#ifndef FAKE_QUANT_NO_DEBUG + Tensor* check_min_max = check_min_max_handle_.AccessTensor(context); +#endif + + Tensor* output; + OP_REQUIRES_OK(context, + context->allocate_output(0, input.shape(), &output)); + + FakeQuantWithMinMaxVarsFunctor functor; + functor(context->eigen_device(), input.flat(), + min.scalar(), max.scalar(), +#ifndef FAKE_QUANT_NO_DEBUG + check_min_max->scalar(), +#endif + output->flat()); + } + + private: +#ifndef FAKE_QUANT_NO_DEBUG + PersistentTensor check_min_max_handle_; +#endif +}; + +// Implementation of FakeQuantWithMinMaxVarsGradientOp, see its documentation in +// core/ops/array_ops.cc. +template +class FakeQuantWithMinMaxVarsGradientOp : public OpKernel { + public: + explicit FakeQuantWithMinMaxVarsGradientOp(OpKernelConstruction* context) + : OpKernel::OpKernel(context) { +#ifndef FAKE_QUANT_NO_DEBUG + OP_REQUIRES_OK(context, + context->allocate_persistent(DT_BOOL, {}, + &check_min_max_handle_, + nullptr)); +#endif + } + + void Compute(OpKernelContext* context) override { + CHECK_EQ(4, context->num_inputs()); + const Tensor& gradient = context->input(0); + const Tensor& input = context->input(1); + OP_REQUIRES(context, input.IsSameSize(gradient), + InvalidArgument("gradient and input must be the same size")); + const Tensor& min = context->input(2); + const Tensor& max = context->input(3); +#ifndef FAKE_QUANT_NO_DEBUG + Tensor* check_min_max = check_min_max_handle_.AccessTensor(context); +#endif + + Tensor* grad_wrt_input; + OP_REQUIRES_OK(context, + context->allocate_output(0, input.shape(), &grad_wrt_input)); + + TensorShape scalar_shape; + Tensor* grad_wrt_min; + OP_REQUIRES_OK(context, + context->allocate_output(1, scalar_shape, &grad_wrt_min)); + + Tensor* grad_wrt_max; + OP_REQUIRES_OK(context, + context->allocate_output(2, scalar_shape, &grad_wrt_max)); + + FakeQuantWithMinMaxVarsGradientFunctor functor; + functor(context->eigen_device(), gradient.flat(), + input.flat(), min.scalar(), max.scalar(), +#ifndef FAKE_QUANT_NO_DEBUG + check_min_max->scalar(), +#endif + grad_wrt_input->flat(), grad_wrt_min->scalar(), + grad_wrt_max->scalar()); + } + + private: +#ifndef FAKE_QUANT_NO_DEBUG + PersistentTensor check_min_max_handle_; +#endif +}; + +REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVars").Device(DEVICE_CPU), + FakeQuantWithMinMaxVarsOp); +REGISTER_KERNEL_BUILDER( + Name("FakeQuantWithMinMaxVarsGradient").Device(DEVICE_CPU), + FakeQuantWithMinMaxVarsGradientOp); + +#if GOOGLE_CUDA +template <> +void FakeQuantWithMinMaxVarsFunctor::operator()( + const GPUDevice& d, + typename TTypes::ConstFlat inputs, + typename TTypes::ConstScalar min, + typename TTypes::ConstScalar max, +#ifndef FAKE_QUANT_NO_DEBUG + typename TTypes::Scalar check_min_max, +#endif + typename TTypes::Flat output); +extern template struct FakeQuantWithMinMaxVarsFunctor; +REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVars") + .Device(DEVICE_GPU) + .HostMemory("min") + .HostMemory("max"), + FakeQuantWithMinMaxVarsOp); + +template <> +void FakeQuantWithMinMaxVarsGradientFunctor::operator()( + const GPUDevice& d, + typename TTypes::ConstFlat gradients, + typename TTypes::ConstFlat inputs, + typename TTypes::ConstScalar min, + typename TTypes::ConstScalar max, +#ifndef FAKE_QUANT_NO_DEBUG + typename TTypes::Scalar check_min_max, +#endif + typename TTypes::Flat backprops_wrt_input, + typename TTypes::Scalar backprop_wrt_min, + typename TTypes::Scalar backprop_wrt_max); +extern template struct FakeQuantWithMinMaxVarsGradientFunctor; +REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsGradient") + .Device(DEVICE_GPU) + .HostMemory("min") + .HostMemory("max"), + FakeQuantWithMinMaxVarsGradientOp); +#endif // GOOGLE_CUDA + +// ----------------------------------------------------------------------------- +// Implementation of FakeQuantWithMinMaxVarsPerChannelOp, see its documentation +// in core/ops/array_ops.cc. +template +class FakeQuantWithMinMaxVarsPerChannelOp : public OpKernel { + public: + explicit FakeQuantWithMinMaxVarsPerChannelOp(OpKernelConstruction* context) + : OpKernel::OpKernel(context) { +#ifndef FAKE_QUANT_NO_DEBUG + OP_REQUIRES_OK(context, + context->allocate_persistent(DT_BOOL, {}, + &check_min_max_handle_, + nullptr)); +#endif + } + + void Compute(OpKernelContext* context) override { + CHECK_EQ(3, context->num_inputs()); + const Tensor& input = context->input(0); + const int depth = input.dim_size(input.dims() - 1); // last dimension size. + const Tensor& min = context->input(1); + OP_REQUIRES(context, min.dim_size(0) == depth, + InvalidArgument("min has incorrect size, expected ", depth, + " was ", min.dim_size(0))); + const Tensor& max = context->input(2); + OP_REQUIRES(context, max.dim_size(0) == depth, + InvalidArgument("max has incorrect size, expected ", depth, + " was ", max.dim_size(0))); +#ifndef FAKE_QUANT_NO_DEBUG + Tensor* check_min_max = check_min_max_handle_.AccessTensor(context); +#endif + + Tensor* output; + OP_REQUIRES_OK(context, + context->allocate_output(0, input.shape(), &output)); + + switch (input.dims()) { + case 4: { + FakeQuant4WithMinMaxVarsPerChannelFunctor functor; + functor(context->eigen_device(), input.dim_size(0), + input.dim_size(1), input.dim_size(2), input.dim_size(3), + input.flat(), min.vec(), max.vec(), +#ifndef FAKE_QUANT_NO_DEBUG + check_min_max->scalar(), +#endif + output->flat()); + break; + } + case 2: { + FakeQuant2WithMinMaxVarsPerChannelFunctor functor; + functor(context->eigen_device(), + input.dim_size(0), input.dim_size(1), + input.flat(), min.vec(), max.vec(), +#ifndef FAKE_QUANT_NO_DEBUG + check_min_max->scalar(), +#endif + output->flat()); + break; + } + case 1: { + FakeQuant1WithMinMaxVarsPerChannelFunctor functor; + functor(context->eigen_device(), + input.vec(), min.vec(), max.vec(), +#ifndef FAKE_QUANT_NO_DEBUG + check_min_max->scalar(), +#endif + output->vec()); + break; + } + default: + context->SetStatus(InvalidArgument("Only inputs of dimensions 1, 2 or " + "4 supported, was: ", input.dims())); + break; + } + } + + private: +#ifndef FAKE_QUANT_NO_DEBUG + PersistentTensor check_min_max_handle_; +#endif +}; + +// Implementation of FakeQuantWithMinMaxVarsPerChannelGradientOp, see its +// documentation in core/ops/array_ops.cc. +template +class FakeQuantWithMinMaxVarsPerChannelGradientOp : public OpKernel { + public: + explicit FakeQuantWithMinMaxVarsPerChannelGradientOp( + OpKernelConstruction* context) : OpKernel::OpKernel(context) { +#ifndef FAKE_QUANT_NO_DEBUG + OP_REQUIRES_OK(context, + context->allocate_persistent(DT_BOOL, {}, + &check_min_max_handle_, + nullptr)); +#endif + } + + void Compute(OpKernelContext* context) override { + CHECK_EQ(4, context->num_inputs()); + const Tensor& gradient = context->input(0); + const Tensor& input = context->input(1); + OP_REQUIRES(context, input.IsSameSize(gradient), + InvalidArgument("gradient and input must be the same size")); + const int depth = input.dim_size(input.dims() - 1); // last dimension size. + const Tensor& min = context->input(2); + OP_REQUIRES(context, min.dim_size(0) == depth, + InvalidArgument("min has incorrect size, expected ", depth, + " was ", min.dim_size(0))); + const Tensor& max = context->input(3); + OP_REQUIRES(context, max.dim_size(0) == depth, + InvalidArgument("max has incorrect size, expected ", depth, + " was ", max.dim_size(0))); +#ifndef FAKE_QUANT_NO_DEBUG + Tensor* check_min_max = check_min_max_handle_.AccessTensor(context); +#endif + + Tensor* grad_wrt_input; + OP_REQUIRES_OK(context, + context->allocate_output(0, input.shape(), &grad_wrt_input)); + + TensorShape min_max_shape({input.dim_size(input.dims() - 1)}); + Tensor* grad_wrt_min; + OP_REQUIRES_OK(context, + context->allocate_output(1, min_max_shape, &grad_wrt_min)); + + Tensor* grad_wrt_max; + OP_REQUIRES_OK(context, + context->allocate_output(2, min_max_shape, &grad_wrt_max)); + + switch (input.dims()) { + case 4: { + FakeQuant4WithMinMaxVarsPerChannelGradientFunctor functor; + functor(context->eigen_device(), input.dim_size(0), + input.dim_size(1), input.dim_size(2), input.dim_size(3), + gradient.flat(), input.flat(), + min.vec(), max.vec(), +#ifndef FAKE_QUANT_NO_DEBUG + check_min_max->scalar(), +#endif + grad_wrt_input->flat(), + grad_wrt_min->vec(), grad_wrt_max->vec()); + break; + } + case 2: { + FakeQuant2WithMinMaxVarsPerChannelGradientFunctor functor; + functor(context->eigen_device(), + input.dim_size(0), input.dim_size(1), + gradient.flat(), input.flat(), + min.vec(), max.vec(), +#ifndef FAKE_QUANT_NO_DEBUG + check_min_max->scalar(), +#endif + grad_wrt_input->flat(), + grad_wrt_min->vec(), grad_wrt_max->vec()); + break; + } + case 1: { + FakeQuant1WithMinMaxVarsPerChannelGradientFunctor functor; + functor(context->eigen_device(), + gradient.vec(), input.vec(), + min.vec(), max.vec(), +#ifndef FAKE_QUANT_NO_DEBUG + check_min_max->scalar(), +#endif + grad_wrt_input->vec(), + grad_wrt_min->vec(), grad_wrt_max->vec()); + break; + } + default: + context->SetStatus(InvalidArgument("Only inputs of dimensions 1, 2 or " + "4 supported, was: ", input.dims())); + break; + } + } + + private: +#ifndef FAKE_QUANT_NO_DEBUG + PersistentTensor check_min_max_handle_; +#endif +}; + +REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsPerChannel") + .Device(DEVICE_CPU), + FakeQuantWithMinMaxVarsPerChannelOp); +REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsPerChannelGradient") + .Device(DEVICE_CPU), + FakeQuantWithMinMaxVarsPerChannelGradientOp); + +#if GOOGLE_CUDA +template <> +void FakeQuant1WithMinMaxVarsPerChannelFunctor::operator()( + const GPUDevice& d, + typename TTypes::ConstVec inputs, + typename TTypes::ConstVec min, + typename TTypes::ConstVec max, +#ifndef FAKE_QUANT_NO_DEBUG + typename TTypes::Scalar check_min_max, +#endif + typename TTypes::Vec outputs); +extern template struct FakeQuant1WithMinMaxVarsPerChannelFunctor; + +template <> +void FakeQuant2WithMinMaxVarsPerChannelFunctor::operator()( + const GPUDevice& d, const Index batch_size, const Index depth, + typename TTypes::ConstFlat inputs, + typename TTypes::ConstFlat min, + typename TTypes::ConstFlat max, +#ifndef FAKE_QUANT_NO_DEBUG + typename TTypes::Scalar check_min_max, +#endif + typename TTypes::Flat outputs); +extern template struct FakeQuant2WithMinMaxVarsPerChannelFunctor; + +template <> +void FakeQuant4WithMinMaxVarsPerChannelFunctor::operator()( + const GPUDevice& d, const Index batch_size, const Index height, + const Index width, const Index depth, + typename TTypes::ConstFlat inputs, + typename TTypes::ConstFlat min, + typename TTypes::ConstFlat max, +#ifndef FAKE_QUANT_NO_DEBUG + typename TTypes::Scalar check_min_max, +#endif + typename TTypes::Flat outputs); +extern template struct FakeQuant4WithMinMaxVarsPerChannelFunctor; + +REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsPerChannel") + .Device(DEVICE_GPU) + .HostMemory("min") + .HostMemory("max"), + FakeQuantWithMinMaxVarsPerChannelOp); + +template <> +void FakeQuant1WithMinMaxVarsPerChannelGradientFunctor::operator()( + const GPUDevice& d, + typename TTypes::ConstVec gradients, + typename TTypes::ConstVec inputs, + typename TTypes::ConstVec min, + typename TTypes::ConstVec max, +#ifndef FAKE_QUANT_NO_DEBUG + typename TTypes::Scalar check_min_max, +#endif + typename TTypes::Vec backprops_wrt_input, + typename TTypes::Vec backprop_wrt_min, + typename TTypes::Vec backprop_wrt_max); +extern template struct + FakeQuant1WithMinMaxVarsPerChannelGradientFunctor; + +template <> +void FakeQuant2WithMinMaxVarsPerChannelGradientFunctor::operator()( + const GPUDevice& d, const Index batch_size, const Index depth, + typename TTypes::ConstFlat gradients, + typename TTypes::ConstFlat inputs, + typename TTypes::ConstVec min, + typename TTypes::ConstVec max, +#ifndef FAKE_QUANT_NO_DEBUG + typename TTypes::Scalar check_min_max, +#endif + typename TTypes::Flat backprops_wrt_input, + typename TTypes::Vec backprop_wrt_min, + typename TTypes::Vec backprop_wrt_max); +extern template struct + FakeQuant2WithMinMaxVarsPerChannelGradientFunctor; + +template <> +void FakeQuant4WithMinMaxVarsPerChannelGradientFunctor::operator()( + const GPUDevice& d, const Index batch_size, const Index height, + const Index width, const Index depth, + typename TTypes::ConstFlat gradients, + typename TTypes::ConstFlat inputs, + typename TTypes::ConstVec min, + typename TTypes::ConstVec max, +#ifndef FAKE_QUANT_NO_DEBUG + typename TTypes::Scalar check_min_max, +#endif + typename TTypes::Flat backprops_wrt_input, + typename TTypes::Vec backprop_wrt_min, + typename TTypes::Vec backprop_wrt_max); +extern template struct + FakeQuant4WithMinMaxVarsPerChannelGradientFunctor; + +REGISTER_KERNEL_BUILDER(Name("FakeQuantWithMinMaxVarsPerChannelGradient") + .Device(DEVICE_GPU) + .HostMemory("min") + .HostMemory("max"), + FakeQuantWithMinMaxVarsPerChannelGradientOp); +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/fake_quant_ops_functor.h b/tensorflow/core/kernels/fake_quant_ops_functor.h new file mode 100644 index 00000000000..d3f600cd824 --- /dev/null +++ b/tensorflow/core/kernels/fake_quant_ops_functor.h @@ -0,0 +1,434 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_FAKE_QUANT_FUNCTOR_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_FAKE_QUANT_FUNCTOR_H_ + +#include + +#define EIGEN_STACK_ALLOCATION_LIMIT 0 +#define EIGEN_USE_THREADS +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +static constexpr int kSteps = 255; +static constexpr float kStepsFloat = static_cast(kSteps); + +// Gymnastics with nudged zero point is to ensure that real zero maps to +// an integer, which is required for e.g. zero-padding in convolutional layers. +// Returns (nudged_min, nudged_max, nudged_scale). +template +std::tuple Nudge(const float min, const float max) { + const float scale = (max - min) / (kStepsFloat - 0.0f); + const float zero_point_from_min = 0.0f - min / scale; + const uint8 nudged_zero_point = [zero_point_from_min] { + if (zero_point_from_min < 0.0f) { + return static_cast(0); + } else if (zero_point_from_min > kStepsFloat) { + return static_cast(kSteps); + } else { + return static_cast(std::round(zero_point_from_min)); + } + }(); + + const float nudged_min = (0.0f - nudged_zero_point) * scale; + const float nudged_max = (kStepsFloat - nudged_zero_point) * scale; + return std::make_tuple(nudged_min, nudged_max, scale); +} + +template using ConstScalar = + typename tensorflow::TTypes::ConstScalar; +template using Scalar = typename tensorflow::TTypes::Scalar; +template using ConstVec = typename tensorflow::TTypes::ConstVec; +template using Vec = typename tensorflow::TTypes::Vec; +template using ConstFlat = + typename tensorflow::TTypes::ConstFlat; +template using Flat = typename tensorflow::TTypes::Flat; + +// Functor called by FakeQuantWithMinMaxArgsOp to do the work. Compiles both +// for CPU and GPU. +template +struct FakeQuantWithMinMaxArgsFunctor { + void operator()(const Device& d, ConstFlat inputs, + const float min, const float max, Flat outputs) { + eigen_assert(min <= 0.0f && "min should be <= 0.0"); + eigen_assert(max >= 0.0f && "max should be >= 0.0"); + eigen_assert(min < max && "min should be < max"); + + float nudged_min, nudged_max, nudged_scale; + std::tie(nudged_min, nudged_max, nudged_scale) = Nudge(min, max); + const float inv_nudged_scale = 1.0f / nudged_scale; + + auto clamped = inputs.cwiseMin(nudged_max).cwiseMax(nudged_min); + auto clamped_shifted = clamped - nudged_min; + outputs.device(d) = (clamped_shifted * inv_nudged_scale + 0.5f).floor() * + nudged_scale + nudged_min; + } +}; + +// Functor called by FakeQuantWithMinMaxArgsGradientOp to do the work. Compiles +// both for CPU and GPU. +template +struct FakeQuantWithMinMaxArgsGradientFunctor { + void operator()(const Device& d, ConstFlat gradients, + ConstFlat inputs, const float min, const float max, + Flat backprops) { + eigen_assert(min <= 0.0f && "min should be <= 0.0"); + eigen_assert(max >= 0.0f && "max should be >= 0.0"); + eigen_assert(min < max && "min should be < max"); + + float nudged_min, nudged_max, nudged_scale; + std::tie(nudged_min, nudged_max, nudged_scale) = Nudge(min, max); + + auto between_nudged_min_max = (inputs >= nudged_min && inputs <= nudged_max) + .select(inputs.constant(1.0f), inputs.constant(0.0f)); + backprops.device(d) = gradients * between_nudged_min_max; + } +}; + +// Functor called by FakeQuantWithMinMaxVarsOp to do the work. Compiles both +// for CPU and GPU. +template +struct FakeQuantWithMinMaxVarsFunctor { + void operator()(const Device& d, ConstFlat inputs, + ConstScalar min, ConstScalar max, +#ifndef FAKE_QUANT_NO_DEBUG + Scalar check_min_max, +#endif + Flat outputs) { +#ifndef FAKE_QUANT_NO_DEBUG + check_min_max.device(d) = (min <= 0.0f).all(); + eigen_assert(check_min_max() && "min should be <= 0.0 coeff-wise"); + check_min_max.device(d) = (max >= 0.0f).all(); + eigen_assert(check_min_max() >= 0.0f && "max should be >= 0.0 coeff-wise"); + check_min_max.device(d) = (min < max).all(); + eigen_assert(check_min_max() && "min should be < max coeff-wise"); +#endif + + float nudged_min, nudged_max, nudged_scale; + std::tie(nudged_min, nudged_max, nudged_scale) = + Nudge(min(), max()); + const auto nudged_scale_repl = inputs.constant(nudged_scale); + + const auto clamped = inputs.cwiseMin(nudged_max).cwiseMax(nudged_min); + const auto clamped_shifted = clamped - nudged_min; + outputs.device(d) = (clamped_shifted / nudged_scale_repl + 0.5f).floor() * + nudged_scale_repl + nudged_min; + } +}; + +// Functor called by FakeQuantWithMinMaxVarsGradientOp to do the work. Compiles +// both for CPU and GPU. +template +struct FakeQuantWithMinMaxVarsGradientFunctor { + void operator()(const Device& d, + ConstFlat gradients, ConstFlat inputs, + ConstScalar min, ConstScalar max, +#ifndef FAKE_QUANT_NO_DEBUG + Scalar check_min_max, +#endif + Flat backprops_wrt_input, + Scalar backprop_wrt_min, + Scalar backprop_wrt_max) { +#ifndef FAKE_QUANT_NO_DEBUG + check_min_max.device(d) = (min <= 0.0f).all(); + eigen_assert(check_min_max() && "min should be <= 0.0 coeff-wise"); + check_min_max.device(d) = (max >= 0.0f).all(); + eigen_assert(check_min_max() >= 0.0f && "max should be >= 0.0 coeff-wise"); + check_min_max.device(d) = (min < max).all(); + eigen_assert(check_min_max() && "min should be < max coeff-wise"); +#endif + + float nudged_min, nudged_max, nudged_scale; + std::tie(nudged_min, nudged_max, nudged_scale) = + Nudge(min(), max()); + + const auto between_min_max = (inputs >= nudged_min && inputs <= nudged_max) + .select(inputs.constant(1.0f), inputs.constant(0.0f)); + backprops_wrt_input.device(d) = gradients * between_min_max; + + const auto below_min = (inputs < nudged_min) + .select(inputs.constant(1.0f), inputs.constant(0.0f)); + backprop_wrt_min.device(d) = (gradients * below_min).sum(); + + const auto above_max = (inputs > nudged_max) + .select(inputs.constant(1.0f), inputs.constant(0.0f)); + backprop_wrt_max.device(d) = (gradients * above_max).sum(); + } +}; + +using Index = typename tensorflow::TTypes::ConstTensor::Index; + +// Functor called by FakeQuantWithMinMaxVarsPerChannelOp to do the work. +// Compiles both for CPU and GPU. +// +// Already verified: inputs, outputs, min, max are of shape [d]. +template +struct FakeQuant1WithMinMaxVarsPerChannelFunctor { + void operator()(const Device& d, ConstVec inputs, + ConstVec min, ConstVec max, +#ifndef FAKE_QUANT_NO_DEBUG + Scalar check_min_max, +#endif + Vec outputs) { +#ifndef FAKE_QUANT_NO_DEBUG + check_min_max.device(d) = (min <= 0.0f).all(); + eigen_assert(check_min_max() && "min should be <= 0.0 coeff-wise"); + check_min_max.device(d) = (max >= 0.0f).all(); + eigen_assert(check_min_max() >= 0.0f && "max should be >= 0.0 coeff-wise"); + check_min_max.device(d) = (min < max).all(); + eigen_assert(check_min_max() && "min should be < max coeff-wise"); +#endif + + for (Index i = 0; i < min.size(); ++i) { + float nudged_min, nudged_max, nudged_scale; + std::tie(nudged_min, nudged_max, nudged_scale) = + Nudge(min(i), max(i)); + const float clamped = + std::max(std::min(inputs(i), nudged_max), nudged_min); + const float clamped_shifted = clamped - nudged_min; + + outputs(i) = std::round(clamped_shifted / nudged_scale) * nudged_scale + + nudged_min; + } + } +}; + +// Already verified: inputs, outputs are of shape [b, d], min, max are of shape +// [d]. +template +struct FakeQuant2WithMinMaxVarsPerChannelFunctor { + void operator()(const Device& d, const Index batch_size, const Index depth, + ConstFlat inputs, + ConstVec min, ConstVec max, +#ifndef FAKE_QUANT_NO_DEBUG + Scalar check_min_max, +#endif + Flat outputs) { +#ifndef FAKE_QUANT_NO_DEBUG + check_min_max.device(d) = (min <= 0.0f).all(); + eigen_assert(check_min_max() && "min should be <= 0.0 coeff-wise"); + check_min_max.device(d) = (max >= 0.0f).all(); + eigen_assert(check_min_max() >= 0.0f && "max should be >= 0.0 coeff-wise"); + check_min_max.device(d) = (min < max).all(); + eigen_assert(check_min_max() && "min should be < max coeff-wise"); +#endif + + Eigen::DSizes restored(batch_size, depth); + const auto inputs_restored = inputs.reshape(restored); + for (Index i = 0; i < min.size(); ++i) { + float nudged_min, nudged_max, nudged_scale; + std::tie(nudged_min, nudged_max, nudged_scale) = + Nudge(min(i), max(i)); + const auto clamped = inputs_restored.chip<1>(i) + .cwiseMin(nudged_max).cwiseMax(nudged_min); + const auto clamped_shifted = clamped - nudged_min; + + outputs.reshape(restored).chip<1>(i).device(d) = + (clamped_shifted / nudged_scale + 0.5f).floor() * nudged_scale + + nudged_min; + } + } +}; + +// Already verified: inputs, outputs are of shape [b, h, w, d], min, max are +// of shape [d]. +template +struct FakeQuant4WithMinMaxVarsPerChannelFunctor { + void operator()(const Device& d, const Index batch_size, const Index height, + const Index width, const Index depth, + ConstFlat inputs, + ConstVec min, ConstVec max, +#ifndef FAKE_QUANT_NO_DEBUG + Scalar check_min_max, +#endif + Flat outputs) { +#ifndef FAKE_QUANT_NO_DEBUG + check_min_max.device(d) = (min <= 0.0f).all(); + eigen_assert(check_min_max() && "min should be <= 0.0 coeff-wise"); + check_min_max.device(d) = (max >= 0.0f).all(); + eigen_assert(check_min_max() >= 0.0f && "max should be >= 0.0 coeff-wise"); + check_min_max.device(d) = (min < max).all(); + eigen_assert(check_min_max() && "min should be < max coeff-wise"); +#endif + + Eigen::DSizes restored(batch_size, height, width, depth); + const auto inputs_restored = inputs.reshape(restored); + for (Index i = 0; i < min.size(); ++i) { + float nudged_min, nudged_max, nudged_scale; + std::tie(nudged_min, nudged_max, nudged_scale) = + Nudge(min(i), max(i)); + const auto clamped = inputs_restored.chip<3>(i) + .cwiseMin(nudged_max).cwiseMax(nudged_min); + const auto clamped_shifted = clamped - nudged_min; + + outputs.reshape(restored).chip<3>(i).device(d) = + (clamped_shifted / nudged_scale + 0.5f).floor() * nudged_scale + + nudged_min; + } + } +}; + +// Functor called by FakeQuantWithMinMaxVarsPerChannelGradientOp to do the work. +// Compiles both for CPU and GPU. +// +// Already verified: gradients, inputs, outputs, min, max, backprops_wrt_input, +// backprop_wrt_min, backprop_wrt_max are of shape [d]. +template +struct FakeQuant1WithMinMaxVarsPerChannelGradientFunctor { + void operator()(const Device& d, + ConstVec gradients, ConstVec inputs, + ConstVec min, ConstVec max, +#ifndef FAKE_QUANT_NO_DEBUG + Scalar check_min_max, +#endif + Vec backprops_wrt_input, Vec backprop_wrt_min, + Vec backprop_wrt_max) { +#ifndef FAKE_QUANT_NO_DEBUG + check_min_max.device(d) = (min <= 0.0f).all(); + eigen_assert(check_min_max() && "min should be <= 0.0 coeff-wise"); + check_min_max.device(d) = (max >= 0.0f).all(); + eigen_assert(check_min_max() >= 0.0f && "max should be >= 0.0 coeff-wise"); + check_min_max.device(d) = (min < max).all(); + eigen_assert(check_min_max() && "min should be < max coeff-wise"); +#endif + + for (Index i = 0; i < min.size(); ++i) { + float nudged_min, nudged_max, nudged_scale; + std::tie(nudged_min, nudged_max, nudged_scale) = + Nudge(min(i), max(i)); + + const bool between_min_max = + inputs(i) >= nudged_min && inputs(i) <= nudged_max; + backprops_wrt_input(i) = between_min_max ? gradients(i) : 0.0f; + + const bool below_min = inputs(i) < nudged_min; + backprop_wrt_min(i) = below_min ? gradients(i) : 0.0f; + + const bool above_max = inputs(i) > nudged_max; + backprop_wrt_max(i) = above_max ? gradients(i) : 0.0f; + } + } +}; + +// Already verified: gradients, inputs, backprops_wrt_input are of shape [b, d], +// min, max, backprop_wrt_min, backprop_wrt_max are of shape [d]. +template +struct FakeQuant2WithMinMaxVarsPerChannelGradientFunctor { + void operator()(const Device& d, const Index batch_size, const Index depth, + ConstFlat gradients, ConstFlat inputs, + ConstVec min, ConstVec max, +#ifndef FAKE_QUANT_NO_DEBUG + Scalar check_min_max, +#endif + Flat backprops_wrt_input, + Vec backprop_wrt_min, Vec backprop_wrt_max) { +#ifndef FAKE_QUANT_NO_DEBUG + check_min_max.device(d) = (min <= 0.0f).all(); + eigen_assert(check_min_max() && "min should be <= 0.0 coeff-wise"); + check_min_max.device(d) = (max >= 0.0f).all(); + eigen_assert(check_min_max() >= 0.0f && "max should be >= 0.0 coeff-wise"); + check_min_max.device(d) = (min < max).all(); + eigen_assert(check_min_max() && "min should be < max coeff-wise"); +#endif + + Eigen::DSizes restored(batch_size, depth); + const auto gradients_restored = gradients.reshape(restored); + const auto inputs_restored = inputs.reshape(restored); + for (Index i = 0; i < min.size(); ++i) { + float nudged_min, nudged_max, nudged_scale; + std::tie(nudged_min, nudged_max, nudged_scale) = + Nudge(min(i), max(i)); + const auto gradients_chip = gradients_restored.chip<1>(i); + const auto inputs_chip = inputs_restored.chip<1>(i); + + const auto between_min_max = + (inputs_chip >= nudged_min && inputs_chip <= nudged_max) + .select(inputs_chip.constant(1.0f), inputs_chip.constant(0.0f)); + backprops_wrt_input.reshape(restored).chip<1>(i).device(d) = + gradients_chip * between_min_max; + + const auto below_min = (inputs_chip < nudged_min) + .select(inputs_chip.constant(1.0f), inputs_chip.constant(0.0f)); + Eigen::DSizes reduce(0); + backprop_wrt_min.chip<0>(i).device(d) = + (gradients_chip * below_min).sum(reduce); + + const auto above_max = (inputs_chip > nudged_max) + .select(inputs_chip.constant(1.0f), inputs_chip.constant(0.0f)); + backprop_wrt_max.chip<0>(i).device(d) = + (gradients_chip * above_max).sum(reduce); + } + } +}; + +// Already verified: gradients, inputs, backprops_wrt_input are of shape +// [b, h, w, d], min, max, backprop_wrt_min, backprop_wrt_max are of shape [d]. +template +struct FakeQuant4WithMinMaxVarsPerChannelGradientFunctor { + void operator()(const Device& d, const Index batch_size, const Index height, + const Index width, const Index depth, + ConstFlat gradients, ConstFlat inputs, + ConstVec min, ConstVec max, +#ifndef FAKE_QUANT_NO_DEBUG + Scalar check_min_max, +#endif + Flat backprops_wrt_input, + Vec backprop_wrt_min, Vec backprop_wrt_max) { +#ifndef FAKE_QUANT_NO_DEBUG + check_min_max.device(d) = (min <= 0.0f).all(); + eigen_assert(check_min_max() && "min should be <= 0.0 coeff-wise"); + check_min_max.device(d) = (max >= 0.0f).all(); + eigen_assert(check_min_max() >= 0.0f && "max should be >= 0.0 coeff-wise"); + check_min_max.device(d) = (min < max).all(); + eigen_assert(check_min_max() && "min should be < max coeff-wise"); +#endif + + Eigen::DSizes restored(batch_size, height, width, depth); + const auto gradients_restored = gradients.reshape(restored); + const auto inputs_restored = inputs.reshape(restored); + for (Index i = 0; i < min.size(); ++i) { + float nudged_min, nudged_max, nudged_scale; + std::tie(nudged_min, nudged_max, nudged_scale) = + Nudge(min(i), max(i)); + const auto gradients_chip = gradients_restored.chip<3>(i); + const auto inputs_chip = inputs_restored.chip<3>(i); + + const auto between_min_max = + (inputs_chip >= nudged_min && inputs_chip <= nudged_max) + .select(inputs_chip.constant(1.0f), inputs_chip.constant(0.0f)); + backprops_wrt_input.reshape(restored).chip<3>(i).device(d) = + gradients_chip * between_min_max; + + const auto below_min = (inputs_chip < nudged_min) + .select(inputs_chip.constant(1.0f), inputs_chip.constant(0.0f)); + Eigen::DSizes reduce(0, 1, 2); + backprop_wrt_min.chip<0>(i).device(d) = + (gradients_chip * below_min).sum(reduce); + + const auto above_max = (inputs_chip > nudged_max) + .select(inputs_chip.constant(1.0f), inputs_chip.constant(0.0f)); + backprop_wrt_max.chip<0>(i).device(d) = + (gradients_chip * above_max).sum(reduce); + } + } +}; + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_FAKE_QUANT_FUNCTOR_H_ diff --git a/tensorflow/core/kernels/fake_quant_ops_gpu.cu.cc b/tensorflow/core/kernels/fake_quant_ops_gpu.cu.cc new file mode 100644 index 00000000000..ad327937877 --- /dev/null +++ b/tensorflow/core/kernels/fake_quant_ops_gpu.cu.cc @@ -0,0 +1,41 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA + +#define FAKE_QUANT_NO_DEBUG + +#define EIGEN_USE_GPU +#include "tensorflow/core/kernels/fake_quant_ops_functor.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +// Just instantiate GPU functor implementations. +template struct FakeQuantWithMinMaxArgsFunctor; +template struct FakeQuantWithMinMaxArgsGradientFunctor; +template struct FakeQuantWithMinMaxVarsFunctor; +template struct FakeQuantWithMinMaxVarsGradientFunctor; +template struct FakeQuant1WithMinMaxVarsPerChannelFunctor; +template struct FakeQuant2WithMinMaxVarsPerChannelFunctor; +template struct FakeQuant4WithMinMaxVarsPerChannelFunctor; +template struct FakeQuant1WithMinMaxVarsPerChannelGradientFunctor; +template struct FakeQuant2WithMinMaxVarsPerChannelGradientFunctor; +template struct FakeQuant4WithMinMaxVarsPerChannelGradientFunctor; + +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/fake_quant_ops_test.cc b/tensorflow/core/kernels/fake_quant_ops_test.cc new file mode 100644 index 00000000000..38ad345f0d3 --- /dev/null +++ b/tensorflow/core/kernels/fake_quant_ops_test.cc @@ -0,0 +1,821 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/kernels/ops_testutil.h" + +namespace tensorflow { + +using tensorflow::AllocatorAttributes; +using tensorflow::DT_FLOAT; +using tensorflow::NodeDefBuilder; +using tensorflow::OpsTestBase; +using tensorflow::Tensor; +using tensorflow::TensorShape; +using tensorflow::test::ExpectClose; +using tensorflow::test::FillValues; + +class QuantOpsTest : public OpsTestBase { + protected: + void AddRandomInput(const TensorShape& shape) { + CHECK_GT(input_types_.size(), inputs_.size()) + << "Adding more inputs than types; perhaps you need to call MakeOp"; + Tensor* input = new Tensor(device_->GetAllocator(AllocatorAttributes()), + DT_FLOAT, shape); + input->flat().setRandom(); + tensors_.push_back(input); + bool is_ref = IsRefType(input_types_[inputs_.size()]); + if (is_ref) { + CHECK_EQ(RemoveRefType(input_types_[inputs_.size()]), DT_FLOAT); + inputs_.push_back({&lock_for_refs_, input}); + } else { + CHECK_EQ(input_types_[inputs_.size()], DT_FLOAT); + inputs_.push_back({nullptr, input}); + } + } +}; + +TEST_F(QuantOpsTest, WithArgsNoNudging) { + // Original quantization range: [-10 + 0 / 4, -10 + 255 / 4], scale: 1/4. + // Original zero point: 40, no nudging necessary. + // Expected quantized values: -10.0, -10.25, ..., 53.75. + TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxArgs") + .Input(FakeInput(DT_FLOAT)) // inputs + .Attr("min", -10.0f) + .Attr("max", 53.75f) + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + // Downstream inputs. + AddInputFromArray(TensorShape({2, 3}), + {-10.1f, -10.0f, -9.9f, -9.75f, 53.75f, 53.8f}); + + // Tested code. + TF_ASSERT_OK(RunOpKernel()); + + Tensor* output = GetOutput(0); + Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3})); + FillValues(&expected, + {-10.0f, -10.0f, -10.0f, -9.75f, 53.75f, 53.75f}); + ExpectClose(expected, *output); +} + +TEST_F(QuantOpsTest, WithArgsNudgedZeroIs0) { + // Original quantization range: [-0.4 / 4 + 0 / 4, -0.4 / 4 + 255 / 4]. + // Scale: 1/4, original zero point: 0.4, nudged to 0. + // Nudged range: [0.0; 63.75]. + // Expected quantized values: 0.0, 0.25, 0.5, ..., 63.75. + TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxArgs") + .Input(FakeInput(DT_FLOAT)) // inputs + .Attr("min", -0.1f) + .Attr("max", 63.65f) + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + // Downstream inputs. + AddInputFromArray(TensorShape({2, 3}), + {-0.1f, 0.0f, 0.1f, 0.25f, 63.75f, 63.8f}); + + // Tested code. + TF_ASSERT_OK(RunOpKernel()); + + Tensor* output = GetOutput(0); + Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3})); + FillValues(&expected, {0.0f, 0.0f, 0.0f, 0.25f, 63.75f, 63.75f}); + ExpectClose(expected, *output); +} + +TEST_F(QuantOpsTest, WithArgsNudgedZeroIs1) { + // Original quantization range: [-0.5 / 4 + 0 / 4, -0.5 / 4 + 255 / 4]. + // Scale: 1/4, original zero point: 0.5, nudged to 1. + // Nudged range: [-0.25; 63.5]. + // Expected quantized values: -0.25, 0.0, 0.25, ..., 63.5. + TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxArgs") + .Input(FakeInput(DT_FLOAT)) // inputs + .Attr("min", -0.125f) + .Attr("max", 63.625f) + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + // Downstream inputs. + AddInputFromArray(TensorShape({2, 3}), + {-0.26f, -0.25f, -0.24f, 0.0f, 63.5f, 63.6f}); + + // Tested code. + TF_ASSERT_OK(RunOpKernel()); + + Tensor* output = GetOutput(0); + Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3})); + FillValues(&expected, {-0.25f, -0.25f, -0.25f, 0.0f, 63.5f, 63.5f}); + ExpectClose(expected, *output); +} + +TEST_F(QuantOpsTest, WithArgsNudgedZeroIs255) { + // Original quantization range: [0.4 / 4 - 255 / 4, 0.4 / 4 + 0 / 4]. + // Scale: 1/4, original zero point: 254.6, nudged to 255. + // Nudged range: [-63.75; 0.0]. + // Expected quantized values: -63.75, -63.5, -63.25, ..., 0.0. + TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxArgs") + .Input(FakeInput(DT_FLOAT)) // inputs + .Attr("min", -63.65f) + .Attr("max", 0.1f) + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + // Downstream inputs. + AddInputFromArray(TensorShape({2, 3}), + {-63.8f, -63.75f, -63.7f, -63.5f, 0.0f, 0.1f}); + + // Tested code. + TF_ASSERT_OK(RunOpKernel()); + + Tensor* output = GetOutput(0); + Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3})); + FillValues(&expected, {-63.75f, -63.75f, -63.75f, -63.5f, 0.0f, 0.0f}); + ExpectClose(expected, *output); +} + +TEST_F(QuantOpsTest, WithArgsGradient) { + // Original quantization range: [-0.5 / 4 + 0 / 4, -0.5 / 4 + 255 / 4]. + // Scale: 1/4, original zero point: 0.5, nudged to 1. + // Nudged range: [-0.25; 63.5]. + // Expected quantized values: -0.25, 0.0, 0.25, ..., 63.5. + TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxArgsGradient") + .Input(FakeInput(DT_FLOAT)) // gradient + .Input(FakeInput(DT_FLOAT)) // inputs + .Attr("min", -0.125f) + .Attr("max", 63.625f) + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + // Upstream gradients. + AddRandomInput(TensorShape({2, 3})); + // Downstream inputs. + AddInputFromArray(TensorShape({2, 3}), + {-0.26f, -0.25f, -0.24f, 0.0f, 63.5f, 63.6f}); + + // Tested code. + TF_ASSERT_OK(RunOpKernel()); + + Tensor* output = GetOutput(0); + auto input_flat = GetInput(0).flat(); + Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3})); + FillValues(&expected, + {0.0f, input_flat(1), input_flat(2), + input_flat(3), input_flat(4), 0.0f}); + ExpectClose(expected, *output); +} + +TEST_F(QuantOpsTest, WithVarsNoNudging) { + // Original quantization range: [-10 + 0 / 4, -10 + 255 / 4], scale: 1/4. + // Original zero point: 40, no nudging necessary. + // Expected quantized values: -10.0, -10.25, ..., 53.75. + TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVars") + .Input(FakeInput(DT_FLOAT)) // inputs + .Input(FakeInput(DT_FLOAT)) // min + .Input(FakeInput(DT_FLOAT)) // max + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + // Downstream inputs. + AddInputFromArray(TensorShape({2, 3}), + {-10.1f, -10.0f, -9.9f, -9.75f, 53.75f, 53.8f}); + // Min. + AddInputFromArray(TensorShape({}), {-10.0f}); + // Max. + AddInputFromArray(TensorShape({}), {53.75f}); + + // Tested code. + TF_ASSERT_OK(RunOpKernel()); + + Tensor* output = GetOutput(0); + Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3})); + FillValues(&expected, + {-10.0f, -10.0f, -10.0f, -9.75f, 53.75f, 53.75f}); + ExpectClose(expected, *output); +} + +TEST_F(QuantOpsTest, WithVarsNudgedZeroIs0) { + // Original quantization range: [-0.4 / 4 + 0 / 4, -0.4 / 4 + 255 / 4]. + // Scale: 1/4, original zero point: 0.4, nudged to 0. + // Nudged range: [0.0; 63.75]. + // Expected quantized values: 0.0, 0.25, 0.5, ..., 63.75. + TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVars") + .Input(FakeInput(DT_FLOAT)) // inputs + .Input(FakeInput(DT_FLOAT)) // min + .Input(FakeInput(DT_FLOAT)) // max + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + // Downstream inputs. + AddInputFromArray(TensorShape({2, 3}), + {-0.1f, 0.0f, 0.1f, 0.25f, 63.75f, 63.8f}); + // Min. + AddInputFromArray(TensorShape({}), {-0.1f}); + // Max. + AddInputFromArray(TensorShape({}), {63.65f}); + + // Tested code. + TF_ASSERT_OK(RunOpKernel()); + + Tensor* output = GetOutput(0); + Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3})); + FillValues(&expected, + {0.0f, 0.0f, 0.0f, 0.25f, 63.75f, 63.75f}); + ExpectClose(expected, *output); +} + +TEST_F(QuantOpsTest, WithVarsNudgedZeroIs1) { + // Original quantization range: [-0.5 / 4 + 0 / 4, -0.5 / 4 + 255 / 4]. + // Scale: 1/4, original zero point: 0.5, nudged to 1. + // Nudged range: [-0.25; 63.5]. + // Expected quantized values: -0.25, 0.0, 0.25, ..., 63.5. + TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVars") + .Input(FakeInput(DT_FLOAT)) // inputs + .Input(FakeInput(DT_FLOAT)) // min + .Input(FakeInput(DT_FLOAT)) // max + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + // Downstream inputs. + AddInputFromArray(TensorShape({2, 3}), + {-0.26f, -0.25f, -0.24f, 0.0f, 63.5f, 63.6f}); + // Min. + AddInputFromArray(TensorShape({}), {-0.125f}); + // Max. + AddInputFromArray(TensorShape({}), {63.625f}); + + // Tested code. + TF_ASSERT_OK(RunOpKernel()); + + Tensor* output = GetOutput(0); + Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3})); + FillValues(&expected, + {-0.25f, -0.25f, -0.25f, 0.0f, 63.5f, 63.5f}); + ExpectClose(expected, *output); +} + +TEST_F(QuantOpsTest, WithVarsGradient) { + // Original quantization range: [-0.5 / 4 + 0 / 4, -0.5 / 4 + 255 / 4]. + // Scale: 1/4, original zero point: 0.5, nudged to 1. + // Nudged range: [-0.25; 63.5]. + // Expected quantized values: -0.25, 0.0, 0.25, ..., 63.5. + TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsGradient") + .Input(FakeInput(DT_FLOAT)) // gradients + .Input(FakeInput(DT_FLOAT)) // inputs + .Input(FakeInput(DT_FLOAT)) // min + .Input(FakeInput(DT_FLOAT)) // max + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + // Upstream gradients. + AddRandomInput(TensorShape({2, 3})); + // Downstream inputs. + AddInputFromArray(TensorShape({2, 3}), + {-0.26f, -0.25f, -0.24f, 0.0f, 63.5f, 63.6f}); + // Min. + AddInputFromArray(TensorShape({}), {-0.125f}); + // Max. + AddInputFromArray(TensorShape({}), {63.625f}); + + // Tested code. + TF_ASSERT_OK(RunOpKernel()); + + Tensor* output_bprop_wrt_input = GetOutput(0); + Tensor expected_bprop_wrt_input(allocator(), DT_FLOAT, TensorShape({2, 3})); + auto in_flat = GetInput(0).flat(); + FillValues(&expected_bprop_wrt_input, + {0.0f, in_flat(1), + in_flat(2), in_flat(3), + in_flat(4), 0.0f}); + ExpectClose(expected_bprop_wrt_input, *output_bprop_wrt_input); + + Tensor* output_bprop_wrt_min = GetOutput(1); + Tensor expected_bprop_wrt_min(allocator(), DT_FLOAT, TensorShape({})); + expected_bprop_wrt_min.flat()(0) = in_flat(0); + ExpectClose(expected_bprop_wrt_min, *output_bprop_wrt_min); + + Tensor* output_bprop_wrt_max = GetOutput(2); + Tensor expected_bprop_wrt_max(allocator(), DT_FLOAT, TensorShape({})); + expected_bprop_wrt_max.flat()(0) = in_flat(5); + ExpectClose(expected_bprop_wrt_max, *output_bprop_wrt_max); +} + +TEST_F(QuantOpsTest, WithVarsPerChannelDim1NudgedZeroIs0) { + // Original quantization ranges: [-0.4 / 4 + 0 / 4, -0.4 / 4 + 255 / 4]. + // Scale: 1/4, original zero point: 0.4, nudged to 0. + // Nudged ranges: [0.0; 63.75]. + // Expected quantized values: 0.0, 0.25, 0.5, ..., 63.75. + TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannel") + .Input(FakeInput(DT_FLOAT)) // inputs + .Input(FakeInput(DT_FLOAT)) // min + .Input(FakeInput(DT_FLOAT)) // max + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + // Downstream inputs. + AddInputFromArray(TensorShape({4}), {-0.1f, 0.0f, 63.75f, 63.8f}); + // Min. + AddInputFromArray(TensorShape({4}), {-0.1f, -0.1f, -0.1f, -0.1f}); + // Max. + AddInputFromArray(TensorShape({4}), {63.65f, 63.65f, 63.65f, 63.65f}); + + // Tested code. + TF_ASSERT_OK(RunOpKernel()); + + Tensor* output = GetOutput(0); + Tensor expected(allocator(), DT_FLOAT, TensorShape({4})); + FillValues(&expected, {0.0f, 0.0f, 63.75f, 63.75f}); + ExpectClose(expected, *output); +} + +TEST_F(QuantOpsTest, WithVarsPerChannelDim1NudgedZeroIs1) { + // Original quantization ranges: [-0.5 / 4 + 0 / 4, -0.5 / 4 + 255 / 4]. + // Scale: 1/4, original zero point: 0.5, nudged to 1. + // Nudged ranges: [-0.25; 63.5]. + // Expected quantized values: -0.25, 0.0, 0.25, ..., 63.5. + TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannel") + .Input(FakeInput(DT_FLOAT)) // inputs + .Input(FakeInput(DT_FLOAT)) // min + .Input(FakeInput(DT_FLOAT)) // max + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + // Downstream inputs. + AddInputFromArray(TensorShape({4}), {-0.26f, -0.25f, -0.24f, 63.6f}); + // Min. + AddInputFromArray(TensorShape({4}), + {-0.125f, -0.125f, -0.125f, -0.125f}); + // Max. + AddInputFromArray(TensorShape({4}), + {63.625f, 63.625f, 63.625f, 63.625f}); + + // Tested code. + TF_ASSERT_OK(RunOpKernel()); + + Tensor* output = GetOutput(0); + Tensor expected(allocator(), DT_FLOAT, TensorShape({4})); + FillValues(&expected, {-0.25f, -0.25f, -0.25f, 63.5f}); + ExpectClose(expected, *output); +} + +TEST_F(QuantOpsTest, WithVarsPerChannelDim2NudgedZeroIs0) { + // Original quantization ranges: [-0.4 / 4 + 0 / 4, -0.4 / 4 + 255 / 4]. + // Scale: 1/4, original zero point: 0.4, nudged to 0. + // Nudged ranges: [0.0; 63.75]. + // Expected quantized values: 0.0, 0.25, 0.5, ..., 63.75. + TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannel") + .Input(FakeInput(DT_FLOAT)) // inputs + .Input(FakeInput(DT_FLOAT)) // min + .Input(FakeInput(DT_FLOAT)) // max + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + // Downstream inputs. + AddInputFromArray(TensorShape({2, 3}), + {-0.1f, 0.0f, 0.1f, + 0.25f, 63.75f, 63.8f}); + // Min. + AddInputFromArray(TensorShape({3}), {-0.1f, -0.1f, -0.1f}); + // Max. + AddInputFromArray(TensorShape({3}), {63.65f, 63.65f, 63.65f}); + + // Tested code. + TF_ASSERT_OK(RunOpKernel()); + + Tensor* output = GetOutput(0); + Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3})); + FillValues(&expected, {0.0f, 0.0f, 0.0f, + 0.25f, 63.75f, 63.75f}); + ExpectClose(expected, *output); +} + +TEST_F(QuantOpsTest, WithVarsPerChannelDim2NudgedZeroIs1) { + // Original quantization ranges: [-0.5 / 4 + 0 / 4, -0.5 / 4 + 255 / 4]. + // Scale: 1/4, original zero point: 0.5, nudged to 1. + // Nudged ranges: [-0.25; 63.5]. + // Expected quantized values: -0.25, 0.0, 0.25, ..., 63.5. + TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannel") + .Input(FakeInput(DT_FLOAT)) // inputs + .Input(FakeInput(DT_FLOAT)) // min + .Input(FakeInput(DT_FLOAT)) // max + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + // Downstream inputs. + AddInputFromArray(TensorShape({2, 3}), + {-0.26f, -0.25f, -0.24f, + 0.0f, 63.5f, 63.6f}); + // Min. + AddInputFromArray(TensorShape({3}), {-0.125f, -0.125f, -0.125f}); + // Max. + AddInputFromArray(TensorShape({3}), {63.625f, 63.625f, 63.625f}); + + // Tested code. + TF_ASSERT_OK(RunOpKernel()); + + Tensor* output = GetOutput(0); + Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3})); + FillValues(&expected, {-0.25f, -0.25f, -0.25f, + 0.0f, 63.5f, 63.5f}); + ExpectClose(expected, *output); +} + +TEST_F(QuantOpsTest, WithVarsPerChannelDim4NudgedZeroIs0) { + // Original quantization ranges: [-0.4 / 4 + 0 / 4, -0.4 / 4 + 255 / 4]. + // Scale: 1/4, original zero point: 0.4, nudged to 0. + // Nudged ranges: [0.0; 63.75]. + // Expected quantized values: 0.0, 0.25, 0.5, ..., 63.75. + TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannel") + .Input(FakeInput(DT_FLOAT)) // inputs + .Input(FakeInput(DT_FLOAT)) // min + .Input(FakeInput(DT_FLOAT)) // max + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + // Downstream inputs. + AddInputFromArray(TensorShape({1, 2, 3, 4}), + {-0.1f, 0.0f, 0.1f, 0.25f, + 0.5f, 0.75f, 1.0f, 1.25f, + 1.5f, 1.75f, 2.0f, 2.25f, + + 63.0f, 63.25f, 63.5f, 63.7f, + 63.75f, 63.8f, 63.9f, 100.0f, + 100.0f, 100.0f, 100.0f, 1000.0f}); + // Min. + AddInputFromArray(TensorShape({4}), {-0.1f, -0.1f, -0.1f, -0.1f}); + // Max. + AddInputFromArray(TensorShape({4}), {63.65f, 63.65f, 63.65f, 63.65f}); + + // Tested code. + TF_ASSERT_OK(RunOpKernel()); + + Tensor* output = GetOutput(0); + Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 2, 3, 4})); + FillValues(&expected, + {0.0f, 0.0f, 0.0f, 0.25f, + 0.5f, 0.75f, 1.0f, 1.25f, + 1.5f, 1.75f, 2.0f, 2.25f, + + 63.0f, 63.25f, 63.5f, 63.75f, + 63.75f, 63.75f, 63.75f, 63.75f, + 63.75f, 63.75f, 63.75f, 63.75f}); + ExpectClose(expected, *output); +} + +TEST_F(QuantOpsTest, WithVarsPerChannelDim4NudgedZeroIs1) { + // Original quantization ranges: [-0.5 / 4 + 0 / 4, -0.5 / 4 + 255 / 4]. + // Scale: 1/4, original zero point: 0.5, nudged to 1. + // Nudged ranges: [-0.25; 63.5]. + // Expected quantized values: -0.25, 0.0, 0.25, ..., 63.5. + TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannel") + .Input(FakeInput(DT_FLOAT)) // inputs + .Input(FakeInput(DT_FLOAT)) // min + .Input(FakeInput(DT_FLOAT)) // max + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + // Downstream inputs. + AddInputFromArray(TensorShape({1, 2, 3, 4}), + {-0.3f, -0.25f, -0.2f, 0.0f, + 0.25f, 0.5f, 0.75f, 1.0f, + 1.25f, 1.5f, 1.75f, 2.0f, + + 63.0f, 63.25f, 63.4f, 63.5f, + 63.6f, 63.7f, 100.0f, 100.0f, + 100.0f, 100.0f, 100.0f, 1000.0f}); + // Min. + AddInputFromArray(TensorShape({4}), + {-0.125f, -0.125f, -0.125f, -0.125f}); + // Max. + AddInputFromArray(TensorShape({4}), + {63.625f, 63.625f, 63.625f, 63.625f}); + + // Tested code. + TF_ASSERT_OK(RunOpKernel()); + + Tensor* output = GetOutput(0); + Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 2, 3, 4})); + FillValues(&expected, + {-0.25f, -0.25f, -0.25f, 0.0f, + 0.25f, 0.5f, 0.75f, 1.0f, + 1.25f, 1.5f, 1.75f, 2.0f, + + 63.0f, 63.25f, 63.5f, 63.5f, + 63.5f, 63.5f, 63.5f, 63.5f, + 63.5f, 63.5f, 63.5f, 63.5f}); + ExpectClose(expected, *output); +} + +TEST_F(QuantOpsTest, WithVarsPerChannelDim1GradientNudgedZeroIs0) { + // Original quantization ranges: [-0.4 / 4 + 0 / 4, -0.4 / 4 + 255 / 4]. + // Scale: 1/4, original zero point: 0.4, nudged to 0. + // Nudged ranges: [0.0; 63.75]. + // Expected quantized values: 0.0, 0.25, 0.5, ..., 63.75. + TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannelGradient") + .Input(FakeInput(DT_FLOAT)) // gradients + .Input(FakeInput(DT_FLOAT)) // inputs + .Input(FakeInput(DT_FLOAT)) // min + .Input(FakeInput(DT_FLOAT)) // max + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + // Upstream gradients. + AddRandomInput(TensorShape({4})); + // Downstream inputs. + AddInputFromArray(TensorShape({4}), {-0.1f, 0.0f, 63.75f, 63.8f}); + // Min. + AddInputFromArray(TensorShape({4}), {-0.1f, -0.1f, -0.1f, -0.1f}); + // Max. + AddInputFromArray(TensorShape({4}), {63.65f, 63.65f, 63.65f, 63.65f}); + + // Tested code. + TF_ASSERT_OK(RunOpKernel()); + + Tensor* output_bprop_wrt_input = GetOutput(0); + Tensor expected_bprop_wrt_input(allocator(), DT_FLOAT, TensorShape({4})); + auto grad_flat = GetInput(0).flat(); + FillValues(&expected_bprop_wrt_input, + {0.0f, grad_flat(1), grad_flat(2), 0.0f}); + ExpectClose(expected_bprop_wrt_input, *output_bprop_wrt_input); + + Tensor* output_bprop_wrt_min = GetOutput(1); + Tensor expected_bprop_wrt_min(allocator(), DT_FLOAT, TensorShape({4})); + FillValues(&expected_bprop_wrt_min, + {grad_flat(0), 0.0f, 0.0f, 0.0f}); + ExpectClose(expected_bprop_wrt_min, *output_bprop_wrt_min); + + Tensor* output_bprop_wrt_max = GetOutput(2); + Tensor expected_bprop_wrt_max(allocator(), DT_FLOAT, TensorShape({4})); + FillValues(&expected_bprop_wrt_max, + {0.0f, 0.0f, 0.0f, grad_flat(3)}); + ExpectClose(expected_bprop_wrt_max, *output_bprop_wrt_max); +} + +TEST_F(QuantOpsTest, WithVarsPerChannelDim1GradientNudgedZeroIs1) { + // Original quantization ranges: [-0.5 / 4 + 0 / 4, -0.5 / 4 + 255 / 4]. + // Scale: 1/4, original zero point: 0.5, nudged to 1. + // Nudged ranges: [-0.25; 63.5]. + // Expected quantized values: -0.25, 0.0, 0.25, ..., 63.5. + TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannelGradient") + .Input(FakeInput(DT_FLOAT)) // gradients + .Input(FakeInput(DT_FLOAT)) // inputs + .Input(FakeInput(DT_FLOAT)) // min + .Input(FakeInput(DT_FLOAT)) // max + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + // Upstream gradients. + AddRandomInput(TensorShape({4})); + // Downstream inputs. + AddInputFromArray(TensorShape({4}), {-0.3f, -0.25f, 63.5f, 63.6f}); + // Min. + AddInputFromArray(TensorShape({4}), + {-0.125f, -0.125f, -0.125f, -0.125f}); + // Max. + AddInputFromArray(TensorShape({4}), + {63.625f, 63.625f, 63.625f, 63.625f}); + + // Tested code. + TF_ASSERT_OK(RunOpKernel()); + + Tensor* output_bprop_wrt_input = GetOutput(0); + Tensor expected_bprop_wrt_input(allocator(), DT_FLOAT, TensorShape({4})); + auto grad_flat = GetInput(0).flat(); + FillValues(&expected_bprop_wrt_input, + {0.0f, grad_flat(1), grad_flat(2), 0.0f}); + ExpectClose(expected_bprop_wrt_input, *output_bprop_wrt_input); + + Tensor* output_bprop_wrt_min = GetOutput(1); + Tensor expected_bprop_wrt_min(allocator(), DT_FLOAT, TensorShape({4})); + FillValues(&expected_bprop_wrt_min, + {grad_flat(0), 0.0f, 0.0f, 0.0f}); + ExpectClose(expected_bprop_wrt_min, *output_bprop_wrt_min); + + Tensor* output_bprop_wrt_max = GetOutput(2); + Tensor expected_bprop_wrt_max(allocator(), DT_FLOAT, TensorShape({4})); + FillValues(&expected_bprop_wrt_max, + {0.0f, 0.0f, 0.0f, grad_flat(3)}); + ExpectClose(expected_bprop_wrt_max, *output_bprop_wrt_max); +} + +TEST_F(QuantOpsTest, WithVarsPerChannelDim2GradientNudgedZeroIs0) { + // Original quantization ranges: [-0.4 / 4 + 0 / 4, -0.4 / 4 + 255 / 4]. + // Scale: 1/4, original zero point: 0.4, nudged to 0. + // Nudged ranges: [0.0; 63.75]. + // Expected quantized values: 0.0, 0.25, 0.5, ..., 63.75. + TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannelGradient") + .Input(FakeInput(DT_FLOAT)) // gradients + .Input(FakeInput(DT_FLOAT)) // inputs + .Input(FakeInput(DT_FLOAT)) // min + .Input(FakeInput(DT_FLOAT)) // max + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + // Upstream gradients. + AddRandomInput(TensorShape({2, 3})); + // Downstream inputs. + AddInputFromArray(TensorShape({2, 3}), + {-0.1f, 0.0f, 0.1f, + 0.25f, 63.75f, 63.8f}); + // Min. + AddInputFromArray(TensorShape({3}), {-0.1f, -0.1f, -0.1f}); + // Max. + AddInputFromArray(TensorShape({3}), {63.65f, 63.65f, 63.65f}); + + // Tested code. + TF_ASSERT_OK(RunOpKernel()); + + Tensor* output_bprop_wrt_input = GetOutput(0); + Tensor expected_bprop_wrt_input(allocator(), DT_FLOAT, TensorShape({2, 3})); + auto grad_flat = GetInput(0).flat(); + FillValues(&expected_bprop_wrt_input, + {0.0f, grad_flat(1), grad_flat(2), + grad_flat(3), grad_flat(4), 0.0f}); + ExpectClose(expected_bprop_wrt_input, *output_bprop_wrt_input); + + Tensor* output_bprop_wrt_min = GetOutput(1); + Tensor expected_bprop_wrt_min(allocator(), DT_FLOAT, TensorShape({3})); + FillValues(&expected_bprop_wrt_min, + {grad_flat(0), 0.0f, 0.0f}); + ExpectClose(expected_bprop_wrt_min, *output_bprop_wrt_min); + + Tensor* output_bprop_wrt_max = GetOutput(2); + Tensor expected_bprop_wrt_max(allocator(), DT_FLOAT, TensorShape({3})); + FillValues(&expected_bprop_wrt_max, + {0.0f, 0.0f, grad_flat(5)}); + ExpectClose(expected_bprop_wrt_max, *output_bprop_wrt_max); +} + +TEST_F(QuantOpsTest, WithVarsPerChannelDim2GradientNudgedZeroIs1) { + // Original quantization ranges: [-0.5 / 4 + 0 / 4, -0.5 / 4 + 255 / 4]. + // Scale: 1/4, original zero point: 0.5, nudged to 1. + // Nudged ranges: [-0.25; 63.5]. + // Expected quantized values: -0.25, 0.0, 0.25, ..., 63.5. + TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannelGradient") + .Input(FakeInput(DT_FLOAT)) // gradients + .Input(FakeInput(DT_FLOAT)) // inputs + .Input(FakeInput(DT_FLOAT)) // min + .Input(FakeInput(DT_FLOAT)) // max + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + // Upstream gradients. + AddRandomInput(TensorShape({2, 3})); + // Downstream inputs. + AddInputFromArray(TensorShape({2, 3}), + {-0.3f, -0.25f, -0.2f, + 0.0f, 63.5f, 63.6f}); + // Min. + AddInputFromArray(TensorShape({3}), {-0.125f, -0.125f, -0.125f}); + // Max. + AddInputFromArray(TensorShape({3}), {63.625f, 63.625f, 63.625f}); + + // Tested code. + TF_ASSERT_OK(RunOpKernel()); + + Tensor* output_bprop_wrt_input = GetOutput(0); + Tensor expected_bprop_wrt_input(allocator(), DT_FLOAT, TensorShape({2, 3})); + auto grad_flat = GetInput(0).flat(); + FillValues(&expected_bprop_wrt_input, + {0.0f, grad_flat(1), grad_flat(2), + grad_flat(3), grad_flat(4), 0.0f}); + ExpectClose(expected_bprop_wrt_input, *output_bprop_wrt_input); + + Tensor* output_bprop_wrt_min = GetOutput(1); + Tensor expected_bprop_wrt_min(allocator(), DT_FLOAT, TensorShape({3})); + FillValues(&expected_bprop_wrt_min, + {grad_flat(0), 0.0f, 0.0f}); + ExpectClose(expected_bprop_wrt_min, *output_bprop_wrt_min); + + Tensor* output_bprop_wrt_max = GetOutput(2); + Tensor expected_bprop_wrt_max(allocator(), DT_FLOAT, TensorShape({3})); + FillValues(&expected_bprop_wrt_max, + {0.0f, 0.0f, grad_flat(5)}); + ExpectClose(expected_bprop_wrt_max, *output_bprop_wrt_max); +} + +TEST_F(QuantOpsTest, WithVarsPerChannelDim4GradientNudgedZeroIs0) { + // Original quantization ranges: [-0.4 / 4 + 0 / 4, -0.4 / 4 + 255 / 4]. + // Scale: 1/4, original zero point: 0.4, nudged to 0. + // Nudged ranges: [0.0; 63.75]. + // Expected quantized values: 0.0, 0.25, 0.5, ..., 63.75. + TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannelGradient") + .Input(FakeInput(DT_FLOAT)) // gradients + .Input(FakeInput(DT_FLOAT)) // inputs + .Input(FakeInput(DT_FLOAT)) // min + .Input(FakeInput(DT_FLOAT)) // max + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + // Upstream gradients. + AddRandomInput(TensorShape({1, 2, 3, 4})); + // Downstream inputs. + AddInputFromArray(TensorShape({1, 2, 3, 4}), + {-0.1f, 0.0f, 63.75f, 63.8f, + -0.1f, 0.0f, 63.75f, 63.8f, + -0.1f, 0.0f, 63.75f, 63.8f, + + -0.1f, 0.0f, 63.75f, 63.8f, + -0.1f, 0.0f, 63.75f, 63.8f, + -0.1f, 0.0f, 63.75f, 63.8f}); + // Min. + AddInputFromArray(TensorShape({4}), {-0.1f, -0.1f, -0.1f, -0.1f}); + // Max. + AddInputFromArray(TensorShape({4}), {63.65f, 63.65f, 63.65f, 63.65f}); + + // Tested code. + TF_ASSERT_OK(RunOpKernel()); + + Tensor* output_bprop_wrt_input = GetOutput(0); + Tensor expected_bprop_wrt_input(allocator(), DT_FLOAT, + TensorShape({1, 2, 3, 4})); + auto grad_flat = GetInput(0).flat(); + FillValues( + &expected_bprop_wrt_input, + {0.0f, grad_flat(1), grad_flat(2), 0.0f, + 0.0f, grad_flat(5), grad_flat(6), 0.0f, + 0.0f, grad_flat(9), grad_flat(10), 0.0f, + + 0.0f, grad_flat(13), grad_flat(14), 0.0f, + 0.0f, grad_flat(17), grad_flat(18), 0.0f, + 0.0f, grad_flat(21), grad_flat(22), 0.0f}); + ExpectClose(expected_bprop_wrt_input, *output_bprop_wrt_input); + + Tensor* output_bprop_wrt_min = GetOutput(1); + Tensor expected_bprop_wrt_min(allocator(), DT_FLOAT, TensorShape({4})); + FillValues(&expected_bprop_wrt_min, + {grad_flat(0) + grad_flat(4) + grad_flat(8) + + grad_flat(12) + grad_flat(16) + grad_flat(20), + 0.0f, 0.0f, 0.0f}); + ExpectClose(expected_bprop_wrt_min, *output_bprop_wrt_min); + + Tensor* output_bprop_wrt_max = GetOutput(2); + Tensor expected_bprop_wrt_max(allocator(), DT_FLOAT, TensorShape({4})); + FillValues(&expected_bprop_wrt_max, + {0.0f, 0.0f, 0.0f, + grad_flat(3) + grad_flat(7) + grad_flat(11) + + grad_flat(15) + grad_flat(19) + grad_flat(23)}); + ExpectClose(expected_bprop_wrt_max, *output_bprop_wrt_max); +} + +TEST_F(QuantOpsTest, WithVarsPerChannelDim4GradientNudgedZeroIs1) { + // Original quantization ranges: [-0.5 / 4 + 0 / 4, -0.5 / 4 + 255 / 4]. + // Scale: 1/4, original zero point: 0.5, nudged to 1. + // Nudged ranges: [-0.25; 63.5]. + // Expected quantized values: -0.25, 0.0, 0.25, ..., 63.5. + TF_EXPECT_OK(NodeDefBuilder("op", "FakeQuantWithMinMaxVarsPerChannelGradient") + .Input(FakeInput(DT_FLOAT)) // gradients + .Input(FakeInput(DT_FLOAT)) // inputs + .Input(FakeInput(DT_FLOAT)) // min + .Input(FakeInput(DT_FLOAT)) // max + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + // Upstream gradients. + AddRandomInput(TensorShape({1, 2, 3, 4})); + // Downstream inputs. + AddInputFromArray(TensorShape({1, 2, 3, 4}), + {-0.3f, -0.25f, 63.5f, 63.6f, + -0.3f, -0.25f, 63.5f, 63.6f, + -0.3f, -0.25f, 63.5f, 63.6f, + + -0.3f, -0.25f, 63.5f, 63.6f, + -0.3f, -0.25f, 63.5f, 63.6f, + -0.3f, -0.25f, 63.5f, 63.6f}); + // Min. + AddInputFromArray(TensorShape({4}), + {-0.125f, -0.125f, -0.125f, -0.125f}); + // Max. + AddInputFromArray(TensorShape({4}), + {63.625f, 63.625f, 63.625f, 63.625f}); + + // Tested code. + TF_ASSERT_OK(RunOpKernel()); + + Tensor* output_bprop_wrt_input = GetOutput(0); + Tensor expected_bprop_wrt_input(allocator(), DT_FLOAT, + TensorShape({1, 2, 3, 4})); + auto grad_flat = GetInput(0).flat(); + FillValues(&expected_bprop_wrt_input, + {0.0f, grad_flat(1), grad_flat(2), 0.0f, + 0.0f, grad_flat(5), grad_flat(6), 0.0f, + 0.0f, grad_flat(9), grad_flat(10), 0.0f, + + 0.0f, grad_flat(13), grad_flat(14), 0.0f, + 0.0f, grad_flat(17), grad_flat(18), 0.0f, + 0.0f, grad_flat(21), grad_flat(22), 0.0f}); + ExpectClose(expected_bprop_wrt_input, *output_bprop_wrt_input); + + Tensor* output_bprop_wrt_min = GetOutput(1); + Tensor expected_bprop_wrt_min(allocator(), DT_FLOAT, TensorShape({4})); + FillValues(&expected_bprop_wrt_min, + {grad_flat(0) + grad_flat(4) + grad_flat(8) + + grad_flat(12) + grad_flat(16) + grad_flat(20), + 0.0f, 0.0f, 0.0f}); + ExpectClose(expected_bprop_wrt_min, *output_bprop_wrt_min); + + Tensor* output_bprop_wrt_max = GetOutput(2); + Tensor expected_bprop_wrt_max(allocator(), DT_FLOAT, TensorShape({4})); + FillValues(&expected_bprop_wrt_max, + {0.0f, 0.0f, 0.0f, + grad_flat(3) + grad_flat(7) + grad_flat(11) + + grad_flat(15) + grad_flat(19) + grad_flat(23)}); + ExpectClose(expected_bprop_wrt_max, *output_bprop_wrt_max); +} + +} // namespace tensorflow diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index cdf9fd4341f..b1b553ec8c2 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -4383,6 +4383,117 @@ output_min: This value is copied from input_min. output_max: This value is copied from input_max. )Doc"); +REGISTER_OP("FakeQuantWithMinMaxArgs") + .Attr("min: float = -6.0") + .Attr("max: float = 6.0") + .Input("inputs: float") + .Output("outputs: float") + .Doc(R"doc( +Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type. + +Attributes [min; max] define the clamping range for the 'inputs' data. Op +divides this range into 255 steps (total of 256 values), then replaces each +'inputs' value with the closest of the quantized step values. + +Quantization is called fake since the output is still in floating point. +)doc"); + +REGISTER_OP("FakeQuantWithMinMaxArgsGradient") + .Attr("min: float = -6.0") + .Attr("max: float = 6.0") + .Input("gradients: float") + .Input("inputs: float") + .Output("backprops: float") + .Doc(R"doc( +Compute gradients for a FakeQuantWithMinMaxArgs operation. + +gradients: Backpropagated gradients above the FakeQuantWithMinMaxArgs operation. +inputs: Values passed as inputs to the FakeQuantWithMinMaxArgs operation. +backprops: Backpropagated gradients below the FakeQuantWithMinMaxArgs operation: + `gradients * (inputs >= min && inputs <= max)`. +)doc"); + +REGISTER_OP("FakeQuantWithMinMaxVars") + .Input("inputs: float") + .Input("min: float") + .Input("max: float") + .Output("outputs: float") + .Doc(R"doc( +Fake-quantize the 'inputs' tensor of type float and shape `[b, h, w, d]` via +global float scalars `min` and `max` to 'outputs' tensor of same shape as +`inputs`. + +[min; max] is the clamping range for the 'inputs' data. Op divides this range +into 255 steps (total of 256 values), then replaces each 'inputs' value with the +closest of the quantized step values. + +This operation has a gradient and thus allows for training `min` and `max` values. +)doc"); + +REGISTER_OP("FakeQuantWithMinMaxVarsGradient") + .Input("gradients: float") + .Input("inputs: float") + .Input("min: float") + .Input("max: float") + .Output("backprops_wrt_input: float") + .Output("backprop_wrt_min: float") + .Output("backprop_wrt_max: float") + .Doc(R"doc( +Compute gradients for a FakeQuantWithMinMaxVars operation. + +gradients: Backpropagated gradients above the FakeQuantWithMinMaxVars operation. +inputs: Values passed as inputs to the FakeQuantWithMinMaxVars operation. +min, max: Quantization interval, scalar floats. +backprops_wrt_input: Backpropagated gradients w.r.t. inputs: + `gradients * (inputs >= min && inputs <= max)`. +backprop_wrt_min: Backpropagated gradients w.r.t. min parameter: + `sum(gradients * (inputs < min))`. +backprop_wrt_max: Backpropagated gradients w.r.t. max parameter: + `sum(gradients * (inputs > max))`. +)doc"); + +REGISTER_OP("FakeQuantWithMinMaxVarsPerChannel") + .Input("inputs: float") + .Input("min: float") + .Input("max: float") + .Output("outputs: float") + .Doc(R"doc( +Fake-quantize the 'inputs' tensor of type float and one of the shapes: `[d]`, +`[b, d]` `[b, h, w, d]` via per-channel floats `min` and `max` of shape `[d]` +to 'outputs' tensor of same shape as `inputs`. + +[min; max] is the clamping range for the 'inputs' data in the corresponding +depth channel. Op divides this range into 255 steps (total of 256 values), then +replaces each 'inputs' value with the closest of the quantized step values. + +This operation has a gradient and thus allows for training `min` and `max` values. +)doc"); + +REGISTER_OP("FakeQuantWithMinMaxVarsPerChannelGradient") + .Input("gradients: float") + .Input("inputs: float") + .Input("min: float") + .Input("max: float") + .Output("backprops_wrt_input: float") + .Output("backprop_wrt_min: float") + .Output("backprop_wrt_max: float") + .Doc(R"doc( +Compute gradients for a FakeQuantWithMinMaxVarsPerChannel operation. + +gradients: Backpropagated gradients above the FakeQuantWithMinMaxVars operation, + shape one of: `[d]`, `[b, d]`, `[b, h, w, d]`. +inputs: Values passed as inputs to the FakeQuantWithMinMaxVars operation, shape + same as `gradients`. +min, max: Quantization interval, floats of shape `[d]`. +backprops_wrt_input: Backpropagated gradients w.r.t. inputs, shape same as + `inputs`: + `gradients * (inputs >= min && inputs <= max)`. +backprop_wrt_min: Backpropagated gradients w.r.t. min parameter, shape `[d]`: + `sum_per_d(gradients * (inputs < min))`. +backprop_wrt_max: Backpropagated gradients w.r.t. max parameter, shape `[d]`: + `sum_per_d(gradients * (inputs > max))`. +)doc"); + // Deprecated op registrations: // The following can be deleted after 10mar2017. diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 6abce62ecc2..dcb57d7e0c3 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -1905,7 +1905,6 @@ def _EditDistanceShape(op): return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[2, 5]) -# The remaining ops do not change the shape of their inputs. @ops.RegisterShape("Quantize") @ops.RegisterShape("Dequantize") def _QuantizeDequantizeShape(op): @@ -1914,6 +1913,45 @@ def _QuantizeDequantizeShape(op): return common_shapes.unchanged_shape(op) +@ops.RegisterShape("FakeQuantWithMinMaxArgs") +def _FakeQuantWithMinMaxArgsShape(op): + """Shape function for FakeQuantWithMinMaxArgs op: preserve the input shape.""" + return [op.inputs[0].get_shape()] + + +@ops.RegisterGradient("FakeQuantWithMinMaxArgs") +def _FakeQuantWithMinMaxArgsGradient(op, grad): + """Gradient for FakeQuantWithMinMaxArgs op.""" + return fake_quant_with_min_max_args_gradient(grad, op.inputs[0]) + + +@ops.RegisterShape("FakeQuantWithMinMaxVars") +def _FakeQuantWithMinMaxVarsShape(op): + """Shape function for FakeQuantWithMinMaxVars op: preserve the input shape.""" + return [op.inputs[0].get_shape()] + + +@ops.RegisterGradient("FakeQuantWithMinMaxVars") +def _FakeQuantWithMinMaxVarsGradient(op, grad): + """Gradient for FakeQuantWithMinMaxVars op.""" + return fake_quant_with_min_max_vars_gradient(grad, op.inputs[0], op.inputs[1], + op.inputs[2]) + + +@ops.RegisterShape("FakeQuantWithMinMaxVarsPerChannel") +def _FakeQuantWithMinMaxVarsPerChannelShape(op): + """Shape function for FakeQuantWithMinMaxVarsPerChannel op: input shape.""" + return [op.inputs[0].get_shape()] + + +@ops.RegisterGradient("FakeQuantWithMinMaxVarsPerChannel") +def _FakeQuantWithMinMaxVarsPerChannelGradient(op, grad): + """Gradient for FakeQuantWithMinMaxVarsPerChannel op.""" + return fake_quant_with_min_max_vars_per_channel_gradient(grad, op.inputs[0], + op.inputs[1], + op.inputs[2]) + + ops.RegisterShape("ExtractImagePatches")(common_shapes.call_cpp_shape_fn)