Temporary rollback of QuantizeAndDequantizeV2Grad.
PiperOrigin-RevId: 296507138
Change-Id: Iab04845a3e2a760073aa3c88ef3e10272a2094c1
This commit is contained in:
parent
8a72c4466a
commit
8f04d92daf
@@ -15,12 +15,13 @@ limitations under the License.
 
 #include <vector>
 
-#include "tensorflow/cc/framework/grad_op_registry.h"
-#include "tensorflow/cc/framework/gradients.h"
 #include "tensorflow/cc/ops/array_ops_internal.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 
+#include "tensorflow/cc/framework/grad_op_registry.h"
+#include "tensorflow/cc/framework/gradients.h"
+
 namespace tensorflow {
 namespace ops {
 namespace {
@@ -89,25 +90,15 @@ Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op,
 }
 REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad);
 
-Status QuantizeAndDequantizeV2GradHelper(const Scope& scope,
-                                         const Operation& op,
-                                         const std::vector<Output>& grad_inputs,
-                                         std::vector<Output>* grad_outputs) {
-  Input input = Shape(scope, op.input(0));
-  Input input_min = op.input(1);
-  Input input_max = op.input(2);
-  int64 axis;
-  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
-  auto qdq_v2_grad = QuantizeAndDequantizeV2Grad(
-      scope, grad_inputs[0], input, input_min, input_max,
-      QuantizeAndDequantizeV2Grad::Axis(axis));
-  grad_outputs->push_back(qdq_v2_grad.input_backprop);
-  grad_outputs->push_back(qdq_v2_grad.input_min_backprop);
-  grad_outputs->push_back(qdq_v2_grad.input_max_backprop);
+Status QuantizeAndDequantizeV2Grad(const Scope& scope, const Operation& op,
+                                   const std::vector<Output>& grad_inputs,
+                                   std::vector<Output>* grad_outputs) {
+  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
+  grad_outputs->push_back(NoGradient());
+  grad_outputs->push_back(NoGradient());
   return scope.status();
 }
-REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2",
-                     QuantizeAndDequantizeV2GradHelper);
+REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2", QuantizeAndDequantizeV2Grad);
 
 Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op,
                                    const std::vector<Output>& grad_inputs,
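For context when reading this hunk: the removed helper wired QuantizeAndDequantizeV2's backward pass to the dedicated QuantizeAndDequantizeV2Grad op, while the restored code forwards the incoming gradient to `input` unchanged and reports NoGradient() for `input_min` and `input_max`. A minimal standalone sketch of the two policies, in plain C++ with illustrative names (not the TensorFlow implementation):

#include <cstdio>
#include <vector>

// "PassThrough" matches the restored behavior: the upstream gradient flows
// to `input` unchanged, and input_min/input_max receive no gradient.
std::vector<float> PassThroughGrad(const std::vector<float>& upstream) {
  return upstream;  // Identity: d(output)/d(input) treated as 1 everywhere.
}

// "Masked" matches the removed QuantizeAndDequantizeV2Grad kernel: the
// gradient is zeroed wherever the input falls outside [input_min, input_max].
std::vector<float> MaskedGrad(const std::vector<float>& upstream,
                              const std::vector<float>& input,
                              float input_min, float input_max) {
  std::vector<float> out(upstream.size());
  for (size_t i = 0; i < upstream.size(); ++i) {
    const bool in_range = input[i] >= input_min && input[i] <= input_max;
    out[i] = in_range ? upstream[i] : 0.0f;  // Clip gradient outside range.
  }
  return out;
}

int main() {
  const std::vector<float> grad = {1, 1, 1, 1};
  const std::vector<float> x = {-2.0f, -0.5f, 0.5f, 2.0f};
  for (float g : MaskedGrad(grad, x, -1.0f, 1.0f)) std::printf("%g ", g);
  std::printf("\n");  // Prints: 0 1 1 0 (pass-through would print 1 1 1 1)
  return 0;
}

Under the rollback, elements outside [input_min, input_max] keep receiving gradient; the kernel removed later in this diff zeroed them.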
@@ -1,8 +0,0 @@
-op {
-  graph_op_name: "QuantizeAndDequantizeV2Grad"
-  summary: "Returns the gradient of `QuantizeAndDequantizeV2`."
-  description: <<END
-Returns a gradient of 1 for inputs that are within the quantization range,
-or 0 otherwise.
-END
-}
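The deleted summary and description specify a clipped straight-through estimator. Written out (a sketch using the op's input names, with g the incoming gradient and L the loss):

\frac{\partial L}{\partial \text{input}_i} =
\begin{cases}
  g_i & \text{if } \text{input\_min} \le \text{input}_i \le \text{input\_max} \\
  0 & \text{otherwise}
\end{cases}
\qquad
\frac{\partial L}{\partial \text{input\_min}} =
\frac{\partial L}{\partial \text{input\_max}} = 0

This matches the removed kernel further down, which also returns identically zero gradients for input_min and input_max.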
@@ -1,3 +0,0 @@
-op {
-  graph_op_name: "QuantizeAndDequantizeV2Grad"
-}
@@ -1,4 +0,0 @@
-op {
-  graph_op_name: "QuantizeAndDequantizeV2Grad"
-  visibility: HIDDEN
-}
@@ -131,75 +131,6 @@ class QuantizeAndDequantizeV2Op : public OpKernel {
   bool narrow_range_;
 };
 
-// Implementation of QuantizeAndDequantizeV2GradientOp.
-// When back-propagating the error through a quantized layer, the following
-// paper gives evidence that clipped-ReLU is better than non-clipped:
-// "Deep Learning with Low Precision by Half-wave Gaussian Quantization"
-// http://zpascal.net/cvpr2017/Cai_Deep_Learning_With_CVPR_2017_paper.pdf
-template <typename Device, typename T>
-class QuantizeAndDequantizeV2GradientOp : public OpKernel {
- public:
-  explicit QuantizeAndDequantizeV2GradientOp(OpKernelConstruction* ctx)
-      : OpKernel::OpKernel(ctx) {
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
-  }
-
-  void Compute(OpKernelContext* ctx) override {
-    const Tensor& gradient = ctx->input(0);
-    const Tensor& input = ctx->input(1);
-    Tensor* input_backprop = nullptr;
-    OP_REQUIRES_OK(ctx,
-                   ctx->allocate_output(0, input.shape(), &input_backprop));
-
-    OP_REQUIRES(
-        ctx, input.IsSameSize(gradient),
-        errors::InvalidArgument("gradient and input must be the same size"));
-    const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_);
-    const Tensor& input_min_tensor = ctx->input(2);
-    const Tensor& input_max_tensor = ctx->input(3);
-    if (axis_ != -1) {
-      OP_REQUIRES(
-          ctx, input_min_tensor.dim_size(0) == depth,
-          errors::InvalidArgument("min has incorrect size, expected ", depth,
-                                  " was ", input_min_tensor.dim_size(0)));
-      OP_REQUIRES(
-          ctx, input_max_tensor.dim_size(0) == depth,
-          errors::InvalidArgument("max has incorrect size, expected ", depth,
-                                  " was ", input_max_tensor.dim_size(0)));
-    }
-
-    TensorShape min_max_shape(input_min_tensor.shape());
-    Tensor* input_min_backprop;
-    OP_REQUIRES_OK(ctx,
-                   ctx->allocate_output(1, min_max_shape, &input_min_backprop));
-
-    Tensor* input_max_backprop;
-    OP_REQUIRES_OK(ctx,
-                   ctx->allocate_output(2, min_max_shape, &input_max_backprop));
-
-    if (axis_ == -1) {
-      functor::QuantizeAndDequantizeOneScaleGradientFunctor<Device, T> f;
-      f(ctx->eigen_device<Device>(), gradient.template flat<T>(),
-        input.template flat<T>(), input_min_tensor.scalar<T>(),
-        input_max_tensor.scalar<T>(), input_backprop->template flat<T>(),
-        input_min_backprop->template scalar<T>(),
-        input_max_backprop->template scalar<T>());
-    } else {
-      functor::QuantizeAndDequantizePerChannelGradientFunctor<Device, T> f;
-      f(ctx->eigen_device<Device>(),
-        gradient.template flat_inner_outer_dims<T, 3>(axis_ - 1),
-        input.template flat_inner_outer_dims<T, 3>(axis_ - 1),
-        &input_min_tensor, &input_max_tensor,
-        input_backprop->template flat_inner_outer_dims<T, 3>(axis_ - 1),
-        input_min_backprop->template flat<T>(),
-        input_max_backprop->template flat<T>());
-    }
-  }
-
- private:
-  int axis_;
-};
-
 // Simulate quantization precision loss in a float tensor by:
 // 1. Quantize the tensor to fixed point numbers, which should match the target
 //    quantization method when it is used in inference.
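One detail worth noting in the removed kernel: for axis_ != -1 it processes the tensors through flat_inner_outer_dims<T, 3>(axis_ - 1), which views an N-D tensor as 3-D with the quantization axis in the middle: everything before the axis collapses into the outer dimension, everything after into the inner one. A standalone sketch of that view (hypothetical helper, plain C++, not TF's method):

#include <cstdio>
#include <vector>

// Collapse dims before `axis` into dim 0, keep the channel dim as dim 1,
// and collapse dims after `axis` into dim 2.
std::vector<long long> AsThreeDims(const std::vector<long long>& dims,
                                   int axis) {
  long long outer = 1, inner = 1;
  for (int i = 0; i < axis; ++i) outer *= dims[i];
  for (size_t i = axis + 1; i < dims.size(); ++i) inner *= dims[i];
  return {outer, dims[axis], inner};
}

int main() {
  // The removed GradientV2_op test used shape {2, 3, 4, 5}; for axis = 1 the
  // per-channel view is {2, 3, 20}, giving 3 channel slices.
  auto v = AsThreeDims({2, 3, 4, 5}, 1);
  std::printf("{%lld, %lld, %lld}\n", v[0], v[1], v[2]);  // {2, 3, 20}
  return 0;
}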
@@ -364,43 +295,6 @@ struct QuantizeAndDequantizePerChannelFunctor<CPUDevice, T> {
                       input_max_tensor, round_mode, narrow_range, out);
   }
 };
-
-template <typename T>
-struct QuantizeAndDequantizeOneScaleGradientFunctor<CPUDevice, T> {
-  void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat gradient,
-                  typename TTypes<T>::ConstFlat input,
-                  typename TTypes<T>::ConstScalar input_min_tensor,
-                  typename TTypes<T>::ConstScalar input_max_tensor,
-                  typename TTypes<T>::Flat input_backprop,
-                  typename TTypes<T>::Scalar input_min_backprop,
-                  typename TTypes<T>::Scalar input_max_backprop) {
-    QuantizeAndDequantizeOneScaleGradientImpl<CPUDevice, T>::Compute(
-        d, gradient, input, input_min_tensor, input_max_tensor, input_backprop,
-        input_min_backprop, input_max_backprop);
-  }
-};
-
-template <typename T>
-struct QuantizeAndDequantizePerChannelGradientFunctor<CPUDevice, T> {
-  void operator()(const CPUDevice& d,
-                  typename TTypes<T, 3>::ConstTensor gradient,
-                  typename TTypes<T, 3>::ConstTensor input,
-                  const Tensor* input_min_tensor,
-                  const Tensor* input_max_tensor,
-                  typename TTypes<T, 3>::Tensor input_backprop,
-                  typename TTypes<T>::Flat input_min_backprop,
-                  typename TTypes<T>::Flat input_max_backprop) {
-    QuantizeAndDequantizePerChannelGradientImpl<CPUDevice, T>::Compute(
-        d, gradient, input, input_min_tensor, input_max_tensor, input_backprop,
-        input_min_backprop, input_max_backprop);
-  }
-};
-
-template struct functor::QuantizeAndDequantizeOneScaleGradientFunctor<CPUDevice,
-                                                                      float>;
-template struct functor::QuantizeAndDequantizePerChannelGradientFunctor<
-    CPUDevice, double>;
-
 }  // namespace functor
 
 #define REGISTER_CPU_KERNEL(T)                                              \
@@ -408,10 +302,6 @@ template struct functor::QuantizeAndDequantizePerChannelGradientFunctor<
                               .Device(DEVICE_CPU)                           \
                               .TypeConstraint<T>("T"),                      \
                           QuantizeAndDequantizeV2Op<CPUDevice, T>);         \
-  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV2Grad")               \
-                              .Device(DEVICE_CPU)                           \
-                              .TypeConstraint<T>("T"),                      \
-                          QuantizeAndDequantizeV2GradientOp<CPUDevice, T>); \
   REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV3")                   \
                               .Device(DEVICE_CPU)                           \
                               .TypeConstraint<T>("T"),                      \
@@ -432,12 +322,6 @@ TF_CALL_double(REGISTER_CPU_KERNEL);
                               .HostMemory("input_max")                      \
                               .TypeConstraint<T>("T"),                      \
                           QuantizeAndDequantizeV2Op<GPUDevice, T>);         \
-  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV2Grad")               \
-                              .Device(DEVICE_GPU)                           \
-                              .HostMemory("input_min")                      \
-                              .HostMemory("input_max")                      \
-                              .TypeConstraint<T>("T"),                      \
-                          QuantizeAndDequantizeV2GradientOp<GPUDevice, T>); \
   REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV3")                   \
                               .Device(DEVICE_GPU)                           \
                               .HostMemory("input_min")                      \
@@ -60,28 +60,6 @@ struct QuantizeAndDequantizePerChannelFunctor {
                   typename TTypes<T, 3>::Tensor output);
 };
 
-template <typename Device, typename T>
-struct QuantizeAndDequantizeOneScaleGradientFunctor {
-  void operator()(const Device& d, typename TTypes<T>::ConstFlat gradient,
-                  typename TTypes<T>::ConstFlat input,
-                  typename TTypes<T>::ConstScalar input_min,
-                  typename TTypes<T>::ConstScalar input_max,
-                  typename TTypes<T>::Flat input_backprop,
-                  typename TTypes<T>::Scalar input_min_backprop,
-                  typename TTypes<T>::Scalar input_max_backprop);
-};
-
-template <typename Device, typename T>
-struct QuantizeAndDequantizePerChannelGradientFunctor {
-  void operator()(const Device& d, typename TTypes<T, 3>::ConstTensor gradient,
-                  typename TTypes<T, 3>::ConstTensor input,
-                  const Tensor* input_min_tensor,
-                  const Tensor* input_max_tensor,
-                  typename TTypes<T, 3>::Tensor input_backprop,
-                  typename TTypes<T>::Flat input_min_backprop,
-                  typename TTypes<T>::Flat input_max_backprop);
-};
-
 // The implementation below runs on both CPU and GPU.
 template <typename Device, typename T, typename Func,
           typename Vec = typename TTypes<T>::Vec,
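The declarations above follow the usual TensorFlow functor pattern: a device-templated functor forwards to a shared Impl so one body can be instantiated separately for CPU and GPU. A toy, self-contained illustration of the pattern (all names here are invented for the example):

#include <cstdio>

struct CPUDevice {};  // Stand-in for Eigen's device types.

// Shared implementation, parameterized on the device.
template <typename Device, typename T>
struct ScaleImpl {
  static void Compute(const Device&, T* data, int n, T factor) {
    for (int i = 0; i < n; ++i) data[i] *= factor;
  }
};

// Thin functor that kernels call; specializations per device just forward.
template <typename Device, typename T>
struct ScaleFunctor {
  void operator()(const Device& d, T* data, int n, T factor) {
    ScaleImpl<Device, T>::Compute(d, data, n, factor);
  }
};

int main() {
  float v[3] = {1, 2, 3};
  ScaleFunctor<CPUDevice, float>()(CPUDevice{}, v, 3, 2.0f);
  std::printf("%g %g %g\n", v[0], v[1], v[2]);  // 2 4 6
  return 0;
}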
@@ -271,55 +249,6 @@ struct QuantizeAndDequantizePerChannelImpl {
   }
 };
 
-template <typename Device, typename T>
-struct QuantizeAndDequantizeOneScaleGradientImpl {
-  static void Compute(const Device& d, typename TTypes<T>::ConstFlat gradient,
-                      typename TTypes<T>::ConstFlat input,
-                      typename TTypes<T>::ConstScalar input_min,
-                      typename TTypes<T>::ConstScalar input_max,
-                      typename TTypes<T>::Flat input_backprop,
-                      typename TTypes<T>::Scalar input_min_backprop,
-                      typename TTypes<T>::Scalar input_max_backprop) {
-    const T min_val = input_min();
-    const T max_val = input_max();
-    const auto in_range =
-        (input >= min_val && input <= max_val)
-            .select(input.constant(1.0f), input.constant(0.0f));
-    input_backprop.device(d) = gradient * in_range;
-    input_min_backprop.device(d) = input_min_backprop.constant(0.0f);
-    input_max_backprop.device(d) = input_max_backprop.constant(0.0f);
-  }
-};
-
-template <typename Device, typename T>
-struct QuantizeAndDequantizePerChannelGradientImpl {
-  static void Compute(const Device& d,
-                      typename TTypes<T, 3>::ConstTensor gradient,
-                      typename TTypes<T, 3>::ConstTensor input,
-                      const Tensor* input_min_tensor,
-                      const Tensor* input_max_tensor,
-                      typename TTypes<T, 3>::Tensor input_backprop,
-                      typename TTypes<T>::Flat input_min_backprop,
-                      typename TTypes<T>::Flat input_max_backprop) {
-    using Index = typename tensorflow::TTypes<T>::ConstTensor::Index;
-    auto input_min = input_min_tensor->vec<T>();
-    auto input_max = input_max_tensor->vec<T>();
-    int num_channels = input.dimension(1);
-    for (Index i = 0; i < num_channels; ++i) {
-      const auto gradient_chip = gradient.template chip<1>(i);
-      const auto input_chip = input.template chip<1>(i);
-      const T min_val = input_min(i);
-      const T max_val = input_max(i);
-      const auto in_range =
-          (input_chip >= min_val && input_chip <= max_val)
-              .select(input_chip.constant(1.0f), input_chip.constant(0.0f));
-      input_backprop.template chip<1>(i).device(d) = gradient_chip * in_range;
-    }
-    input_min_backprop.device(d) = input_min_backprop.constant(0.0f);
-    input_max_backprop.device(d) = input_max_backprop.constant(0.0f);
-  }
-};
-
 }  // end of namespace functor
 }  // end of namespace tensorflow
 
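The removed per-channel Impl iterates channels with Eigen's chip(): chip<1>(i) selects slice i of the middle (channel) dimension, and select() builds a {0, 1} in-range mask that multiplies the upstream gradient. A self-contained Eigen sketch of that loop, with made-up values (assumes Eigen's unsupported Tensor module; not the TF functor itself):

#include <iostream>
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  // outer=1, channels=2, inner=3 view of a tensor, as in the removed Impl.
  Eigen::Tensor<float, 3> input(1, 2, 3), gradient(1, 2, 3), backprop(1, 2, 3);
  input.setValues({{{-2.0f, 0.0f, 2.0f}, {-4.0f, 0.0f, 4.0f}}});
  gradient.setConstant(1.0f);
  const float mins[2] = {-1.0f, -3.0f};
  const float maxs[2] = {1.0f, 3.0f};
  for (int i = 0; i < 2; ++i) {
    const auto g = gradient.chip<1>(i);  // channel i of the gradient
    const auto x = input.chip<1>(i);     // channel i of the input
    // {0,1} mask against this channel's [min, max] range.
    const auto in_range = (x >= mins[i] && x <= maxs[i])
                              .select(x.constant(1.0f), x.constant(0.0f));
    backprop.chip<1>(i) = g * in_range;
  }
  for (int j = 0; j < 2; ++j) {
    for (int k = 0; k < 3; ++k) std::cout << backprop(0, j, k) << " ";
    std::cout << "\n";  // channel 0: 0 1 0   channel 1: 0 1 0
  }
  return 0;
}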
@@ -53,37 +53,6 @@ struct QuantizeAndDequantizePerChannelFunctor<GPUDevice, T> {
   }
 };
 
-template <typename T>
-struct QuantizeAndDequantizeOneScaleGradientFunctor<GPUDevice, T> {
-  void operator()(const GPUDevice& d, typename TTypes<T>::ConstFlat gradient,
-                  typename TTypes<T>::ConstFlat input,
-                  typename TTypes<T>::ConstScalar input_min_tensor,
-                  typename TTypes<T>::ConstScalar input_max_tensor,
-                  typename TTypes<T>::Flat input_backprop,
-                  typename TTypes<T>::Scalar input_min_backprop,
-                  typename TTypes<T>::Scalar input_max_backprop) {
-    QuantizeAndDequantizeOneScaleGradientImpl<GPUDevice, T>::Compute(
-        d, gradient, input, input_min_tensor, input_max_tensor, input_backprop,
-        input_min_backprop, input_max_backprop);
-  }
-};
-
-template <typename T>
-struct QuantizeAndDequantizePerChannelGradientFunctor<GPUDevice, T> {
-  void operator()(const GPUDevice& d,
-                  typename TTypes<T, 3>::ConstTensor gradient,
-                  typename TTypes<T, 3>::ConstTensor input,
-                  const Tensor* input_min_tensor,
-                  const Tensor* input_max_tensor,
-                  typename TTypes<T, 3>::Tensor input_backprop,
-                  typename TTypes<T>::Flat input_min_backprop,
-                  typename TTypes<T>::Flat input_max_backprop) {
-    QuantizeAndDequantizePerChannelGradientImpl<GPUDevice, T>::Compute(
-        d, gradient, input, input_min_tensor, input_max_tensor, input_backprop,
-        input_min_backprop, input_max_backprop);
-  }
-};
-
 }  // end namespace functor
 
 // Instantiate the GPU implementation for float and double.
@@ -96,15 +65,6 @@ template struct functor::QuantizeAndDequantizePerChannelFunctor<GPUDevice,
 template struct functor::QuantizeAndDequantizePerChannelFunctor<GPUDevice,
                                                                 double>;
 
-template struct functor::QuantizeAndDequantizeOneScaleGradientFunctor<GPUDevice,
-                                                                      float>;
-template struct functor::QuantizeAndDequantizeOneScaleGradientFunctor<GPUDevice,
-                                                                      double>;
-template struct functor::QuantizeAndDequantizePerChannelGradientFunctor<
-    GPUDevice, float>;
-template struct functor::QuantizeAndDequantizePerChannelGradientFunctor<
-    GPUDevice, double>;
-
 }  // end namespace tensorflow
 
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -362,54 +362,6 @@ TEST_P(ParameterizedQuantizeAndDequantizeTest,
   }
 }
 
-// Convert a 1D tensor with signed 8 bits and round_mode half_up.
-TEST_P(ParameterizedQuantizeAndDequantizeTest, GradientV2_op) {
-  const int axis = GetParam();
-  TF_ASSERT_OK(NodeDefBuilder("qdq_v2_grad_op", "QuantizeAndDequantizeV2Grad")
-                   .Input(FakeInput(DT_FLOAT))
-                   .Input(FakeInput(DT_FLOAT))
-                   .Input(FakeInput(DT_FLOAT))
-                   .Input(FakeInput(DT_FLOAT))
-                   .Attr("axis", axis)
-                   .Finalize(node_def()));
-  TF_ASSERT_OK(InitOp());
-  const std::vector<int64> dims = {2, 3, 4, 5};
-  // Input gradient. (repeating 11 values multiplied by (slice_idx + 1))
-  auto gradients = ScalePerSliceAlongAxis<float>(
-      dims, axis, {1, -2, -3, 4, 5, 6, -7, -8, -9, -10, 11});
-  AddInputFromArray<float>(TensorShape(dims), gradients);
-  // Forward op inputs. (repeating 7 values multiplied by (slice_idx + 1)).
-  auto inputs = ScalePerSliceAlongAxis<float>(
-      dims, axis, {-1, -0.5, 0, 0.3, 0.8, 0.55, 0.6});
-  AddInputFromArray<float>(TensorShape(dims), inputs);
-  const int num_slices = (axis == -1) ? 1 : dims[axis];
-  const TensorShape range_shape =
-      (axis == -1) ? TensorShape({}) : TensorShape({num_slices});
-  std::vector<float> input_min_values(num_slices), input_max_values(num_slices);
-  for (int i = 0; i < num_slices; ++i) {
-    input_max_values[i] = 0.8f + i * 0.4f;
-    input_min_values[i] = -input_max_values[i];
-  }
-  AddInputFromArray<float>(range_shape, input_min_values);
-  AddInputFromArray<float>(range_shape, input_max_values);
-  std::vector<float> expected_vals(inputs.size());
-  int minor_size = 1;
-  for (int i = axis + 1; i < dims.size(); ++i) {
-    minor_size *= dims[i];
-  }
-  for (int i = 0; i < inputs.size(); ++i) {
-    int slice_idx = (i / minor_size) % num_slices;
-    expected_vals[i] = ((inputs[i] >= input_min_values[slice_idx]) &&
-                        (inputs[i] <= input_max_values[slice_idx]))
-                           ? gradients[i]
-                           : 0;
-  }
-  TF_ASSERT_OK(RunOpKernel());
-  Tensor expected(allocator(), DT_FLOAT, TensorShape(dims));
-  test::FillValues<float>(&expected, expected_vals);
-  test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5);
-}
-
 // Instantiate parameterized tests for axis = -1, 1, 3.
 INSTANTIATE_TEST_SUITE_P(All, ParameterizedQuantizeAndDequantizeTest,
                          ::testing::Values(-1, 1, 3));
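The deleted test derives its expected values with flat-index arithmetic: minor_size is the product of the dimensions after axis, so a row-major flat index i falls in channel slice (i / minor_size) % num_slices. A standalone sketch of that mapping for the test's {2, 3, 4, 5} shape (illustrative only):

#include <cstdio>
#include <vector>

int main() {
  const std::vector<int> dims = {2, 3, 4, 5};
  const int axis = 1;
  const int num_slices = dims[axis];
  int minor_size = 1;
  for (size_t i = axis + 1; i < dims.size(); ++i) minor_size *= dims[i];
  // For dims {2, 3, 4, 5} and axis 1: minor_size = 20, num_slices = 3.
  for (int i : {0, 19, 20, 59, 60}) {
    // slice cycles with period minor_size, wrapping modulo num_slices:
    // 0 -> 0, 19 -> 0, 20 -> 1, 59 -> 2, 60 -> 0
    std::printf("flat index %d -> slice %d\n", i,
                (i / minor_size) % num_slices);
  }
  return 0;
}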
@@ -2800,38 +2800,6 @@ REGISTER_OP("QuantizeAndDequantizeV2")
       return Status::OK();
     });
 
-REGISTER_OP("QuantizeAndDequantizeV2Grad")
-    .Input("gradients: T")
-    .Input("input: T")
-    .Input("input_min: T")
-    .Input("input_max: T")
-    .Output("input_backprop: T")
-    .Output("input_min_backprop: T")
-    .Output("input_max_backprop: T")
-    .Attr("T: {bfloat16, half, float, double}")
-    .Attr("axis: int = -1")
-    .SetShapeFn([](InferenceContext* c) {
-      int axis;
-      TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
-      const int minmax_rank = (axis == -1) ? 0 : 1;
-      ShapeHandle minmax;
-      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), minmax_rank, &minmax));
-      TF_RETURN_IF_ERROR(c->Merge(c->input(3), minmax, &minmax));
-      if (axis != -1) {
-        ShapeHandle input;
-        TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
-        DimensionHandle depth;
-        TF_RETURN_IF_ERROR(
-            c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
-      }
-      ShapeHandle inputs;
-      TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &inputs));
-      c->set_output(0, inputs);
-      c->set_output(1, minmax);
-      c->set_output(2, minmax);
-      return Status::OK();
-    });
-
 REGISTER_OP("QuantizeAndDequantizeV3")
     .Input("input: T")
     .Input("input_min: T")
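The removed SetShapeFn encodes a simple contract: when axis == -1 the range is per-tensor and input_min/input_max must be scalars; otherwise they must be rank-1 with length equal to the input's axis dimension, and the three outputs mirror the input and min/max shapes. A distilled standalone check of that rule (hypothetical helper, not the InferenceContext API):

#include <cstdio>

// Returns true iff the min/max shape is compatible with the chosen axis.
bool MinMaxShapeOk(int axis, int minmax_rank, long long minmax_dim0,
                   long long input_axis_dim) {
  if (axis == -1) return minmax_rank == 0;  // per-tensor range: scalar
  // per-channel range: vector whose length matches the axis dimension
  return minmax_rank == 1 && minmax_dim0 == input_axis_dim;
}

int main() {
  std::printf("%d\n", MinMaxShapeOk(-1, 0, 0, 0));  // 1: scalar range ok
  std::printf("%d\n", MinMaxShapeOk(1, 1, 3, 3));   // 1: depth matches
  std::printf("%d\n", MinMaxShapeOk(1, 1, 4, 3));   // 0: depth mismatch
  return 0;
}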
@@ -224,7 +224,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
       {"PreventGradient"},
       {"Qr"},
       {"QuantizeAndDequantize"},
-      {"QuantizeAndDequantizeV2Grad", 1, {3}},
+      {"QuantizeAndDequantizeV2"},
       {"QuantizeAndDequantizeV3"},
       {"QueueClose"},
       {"QueueDequeue"},
@@ -410,7 +410,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
 
 absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
     const tensorflow::string &op_name) {
-  static std::array<OpIndexInfo, 470> a = {{
+  static std::array<OpIndexInfo, 469> a = {{
       {"Abs"},
       {"AccumulateNV2"},
       {"Acos"},
@@ -652,7 +652,6 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
       {"Prod"},
       {"QuantizeAndDequantize"},
       {"QuantizeAndDequantizeV2"},
-      {"QuantizeAndDequantizeV2Grad"},
       {"QuantizeAndDequantizeV3"},
       {"QueueClose"},
       {"QueueEnqueue"},
@@ -959,6 +959,11 @@ def _QuantizeAndDequantizeGrad(_, grad):
   return grad
 
 
+@ops.RegisterGradient("QuantizeAndDequantizeV2")
+def _QuantizeAndDequantizeV2Grad(_, grad):
+  return [grad, None, None]
+
+
 @ops.RegisterGradient("QuantizeAndDequantizeV3")
 def _QuantizeAndDequantizeV3Grad(_, grad):
   # Only propagate the gradient for the unquantized input.
@@ -3551,23 +3551,6 @@ def _FakeQuantWithMinMaxVarsPerChannelGradient(op, grad):
       narrow_range=op.get_attr("narrow_range"))
 
 
-@ops.RegisterGradient("QuantizeAndDequantizeV2")
-def _QuantizeAndDequantizeV2Grad(op, grad):
-  """Gradient for QuantizeAndDequantizeV2 op."""
-  return quantize_and_dequantize_v2_grad(
-      grad,
-      op.inputs[0],
-      op.inputs[1],
-      op.inputs[2],
-      axis=op.get_attr("axis"))
-
-
-@ops.RegisterGradient("QuantizeAndDequantizeV2Grad")
-def _QuantizeAndDequantizeV2GradGrad(op, grad):
-  """Gradient for QuantizeAndDequantizeV2Grad op."""
-  return _QuantizeAndDequantizeV2Grad(op, grad)
-
-
 @tf_export("required_space_to_batch_paddings")
 def required_space_to_batch_paddings(input_shape,
                                      block_shape,
@@ -2784,10 +2784,6 @@ tf_module {
     name: "QuantizeAndDequantizeV2"
     argspec: "args=[\'input\', \'input_min\', \'input_max\', \'signed_input\', \'num_bits\', \'range_given\', \'round_mode\', \'narrow_range\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'8\', \'False\', \'HALF_TO_EVEN\', \'False\', \'-1\', \'None\'], "
   }
-  member_method {
-    name: "QuantizeAndDequantizeV2Grad"
-    argspec: "args=[\'gradients\', \'input\', \'input_min\', \'input_max\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
-  }
   member_method {
     name: "QuantizeAndDequantizeV3"
     argspec: "args=[\'input\', \'input_min\', \'input_max\', \'num_bits\', \'signed_input\', \'range_given\', \'narrow_range\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'False\', \'-1\', \'None\'], "
@@ -2784,10 +2784,6 @@ tf_module {
     name: "QuantizeAndDequantizeV2"
     argspec: "args=[\'input\', \'input_min\', \'input_max\', \'signed_input\', \'num_bits\', \'range_given\', \'round_mode\', \'narrow_range\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'8\', \'False\', \'HALF_TO_EVEN\', \'False\', \'-1\', \'None\'], "
   }
-  member_method {
-    name: "QuantizeAndDequantizeV2Grad"
-    argspec: "args=[\'gradients\', \'input\', \'input_min\', \'input_max\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
-  }
   member_method {
     name: "QuantizeAndDequantizeV3"
     argspec: "args=[\'input\', \'input_min\', \'input_max\', \'num_bits\', \'signed_input\', \'range_given\', \'narrow_range\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'False\', \'-1\', \'None\'], "