Create a V2 Op to stop the gradient when the input is out of range.
PiperOrigin-RevId: 336692325
Change-Id: I36fd3fcfc58a30d5218beca512fbfc7c24b8b5cb
This commit is contained in: parent d5bfcbf8b6, commit 52df91c563
Changed files:
  RELEASE.md
  tensorflow/cc/gradients
  tensorflow/compiler/tests
  tensorflow/core/api_def/base_api
  tensorflow/core/api_def/java_api
  tensorflow/core/api_def/python_api
  tensorflow/core/kernels/quantize_and_dequantize_op.cc
  tensorflow/core/kernels/quantize_and_dequantize_op.h
  tensorflow/core/kernels/quantize_and_dequantize_op_gpu.cu.cc
  tensorflow/core/kernels/quantize_and_dequantize_op_test.cc
  tensorflow/core/ops
  tensorflow/python
  tensorflow/tools/api/golden
RELEASE.md
@@ -47,6 +47,11 @@
 * `tf.data.experimental.service.WorkerServer` now takes a config tuple
   instead of individual arguments. Usages should be updated to
   `tf.data.experimental.service.WorkerServer(worker_config)`.
+* `tf.quantization.quantize_and_dequantize_v2` has been introduced, which
+  updates the gradient definition for quantization so that inputs outside the
+  range produce a gradient of 0. To simulate the V1 behavior of
+  `tf.quantization.quantize_and_dequantize(...)` use
+  `tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...)`.

 ## Known Caveats
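In user-facing terms, the release note above amounts to the following (a sketch, assuming a TensorFlow build that includes this change):

```python
import tensorflow as tf

x = tf.constant([-2.0, -0.5, 0.5, 2.0])

with tf.GradientTape() as tape:
  tape.watch(x)
  y = tf.quantization.quantize_and_dequantize_v2(
      x, input_min=-1.0, input_max=1.0, range_given=True)
# Out-of-range elements now get zero gradient: [0., 1., 1., 0.]
print(tape.gradient(y, x))

with tf.GradientTape() as tape:
  tape.watch(x)
  y = tf.grad_pass_through(
      lambda t: tf.quantization.quantize_and_dequantize_v2(
          t, input_min=-1.0, input_max=1.0, range_given=True))(x)
# V1-style straight-through gradient: [1., 1., 1., 1.]
print(tape.gradient(y, x))
```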
tensorflow/cc/gradients/array_grad.cc
@@ -15,13 +15,12 @@ limitations under the License.

 #include <vector>

+#include "tensorflow/cc/framework/grad_op_registry.h"
+#include "tensorflow/cc/framework/gradients.h"
 #include "tensorflow/cc/ops/array_ops_internal.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/core/lib/strings/strcat.h"
-
-#include "tensorflow/cc/framework/grad_op_registry.h"
-#include "tensorflow/cc/framework/gradients.h"

 namespace tensorflow {
 namespace ops {
 namespace {

@@ -90,15 +89,25 @@ Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op,
 }
 REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad);

-Status QuantizeAndDequantizeV2Grad(const Scope& scope, const Operation& op,
-                                   const std::vector<Output>& grad_inputs,
-                                   std::vector<Output>* grad_outputs) {
-  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
-  grad_outputs->push_back(NoGradient());
-  grad_outputs->push_back(NoGradient());
+Status QuantizeAndDequantizeV4GradHelper(const Scope& scope,
+                                         const Operation& op,
+                                         const std::vector<Output>& grad_inputs,
+                                         std::vector<Output>* grad_outputs) {
+  Input input = Shape(scope, op.input(0));
+  Input input_min = op.input(1);
+  Input input_max = op.input(2);
+  int64 axis;
+  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
+  auto qdq_v4_grad = QuantizeAndDequantizeV4Grad(
+      scope, grad_inputs[0], input, input_min, input_max,
+      QuantizeAndDequantizeV4Grad::Axis(axis));
+  grad_outputs->push_back(qdq_v4_grad.input_backprop);
+  grad_outputs->push_back(qdq_v4_grad.input_min_backprop);
+  grad_outputs->push_back(qdq_v4_grad.input_max_backprop);
   return scope.status();
 }
-REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2", QuantizeAndDequantizeV2Grad);
+REGISTER_GRADIENT_OP("QuantizeAndDequantizeV4",
+                     QuantizeAndDequantizeV4GradHelper);

 Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op,
                                    const std::vector<Output>& grad_inputs,
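For readers more comfortable in Python, here is a rough `tf.custom_gradient` analogue of what `QuantizeAndDequantizeV4GradHelper` wires up (a sketch only; the registered TF gradient dispatches to the `QuantizeAndDequantizeV4Grad` kernel, and the function name below is hypothetical):

```python
import tensorflow as tf

@tf.custom_gradient
def qdq_with_clipped_gradient(x, input_min, input_max):
  # Forward pass: ordinary fake quantization.
  y = tf.quantization.quantize_and_dequantize_v2(
      x, input_min, input_max, range_given=True)

  def grad(dy):
    # Pass the gradient through only where x lies inside
    # [input_min, input_max]; the range endpoints get no gradient,
    # matching the zero input_min/input_max backprops in the kernel.
    in_range = tf.cast((x >= input_min) & (x <= input_max), dy.dtype)
    return dy * in_range, None, None

  return y, grad
```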
tensorflow/compiler/tests/unary_ops_test.py
@@ -542,7 +542,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
     for dtype in self.float_types:

       def quantize_and_dequantize_v2(x):
-        return array_ops.quantize_and_dequantize_v2(
+        return array_ops.quantize_and_dequantize(
             x, -127, 127, signed_input=True, num_bits=8)

       self._assertOpOutputMatchesExpected(

@@ -551,7 +551,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
           expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype))

       def quantize_and_dequantize_v2_round_half_up(x):
-        return array_ops.quantize_and_dequantize_v2(
+        return array_ops.quantize_and_dequantize(
             x,
             -1,
             1.0,

@@ -575,7 +575,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
               dtype=dtype))

       def quantize_and_dequantize_v2_round_half_to_even(x):
-        return array_ops.quantize_and_dequantize_v2(
+        return array_ops.quantize_and_dequantize(
             x,
             -1.0,
             1.0,
tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV4.pbtxt (new file)
@@ -0,0 +1,8 @@
+op {
+  graph_op_name: "QuantizeAndDequantizeV4"
+  summary: "Quantizes then dequantizes a tensor."
+  description: <<END
+This is almost identical to QuantizeAndDequantizeV2, except that it returns a
+gradient of 1 for inputs that are within the quantization range, or 0 otherwise.
+END
+}
tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV4Grad.pbtxt (new file)
@@ -0,0 +1,8 @@
+op {
+  graph_op_name: "QuantizeAndDequantizeV4Grad"
+  summary: "Returns the gradient of `QuantizeAndDequantizeV4`."
+  description: <<END
+Returns a gradient of 1 for inputs that are within the quantization range,
+or 0 otherwise.
+END
+}
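In NumPy terms, the rule the description above states is just a mask (an illustrative sketch, with hypothetical function and array names, not the kernel itself):

```python
import numpy as np

def qdq_v4_grad_one_scale(gradients, inputs, input_min, input_max):
  """Sketch of QuantizeAndDequantizeV4Grad for axis == -1 (one scale)."""
  in_range = (inputs >= input_min) & (inputs <= input_max)
  input_backprop = gradients * in_range.astype(gradients.dtype)
  # The quantization range endpoints receive no gradient.
  input_min_backprop = np.zeros((), dtype=gradients.dtype)
  input_max_backprop = np.zeros((), dtype=gradients.dtype)
  return input_backprop, input_min_backprop, input_max_backprop
```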
tensorflow/core/api_def/java_api (new files)
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "QuantizeAndDequantizeV4Grad"
+}
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "QuantizeAndDequantizeV4Grad"
+}

tensorflow/core/api_def/python_api (new files)
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "QuantizeAndDequantizeV4Grad"
+  visibility: HIDDEN
+}
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "QuantizeAndDequantizeV4Grad"
+  visibility: HIDDEN
+}
tensorflow/core/kernels/quantize_and_dequantize_op.cc
@@ -131,6 +131,75 @@ class QuantizeAndDequantizeV2Op : public OpKernel {
   bool narrow_range_;
 };

+// Implementation of QuantizeAndDequantizeV4GradientOp.
+// When back-propagating the error through a quantized layer, the following
+// paper gives evidence that clipped-ReLU is better than non-clipped:
+// "Deep Learning with Low Precision by Half-wave Gaussian Quantization"
+// http://zpascal.net/cvpr2017/Cai_Deep_Learning_With_CVPR_2017_paper.pdf
+template <typename Device, typename T>
+class QuantizeAndDequantizeV4GradientOp : public OpKernel {
+ public:
+  explicit QuantizeAndDequantizeV4GradientOp(OpKernelConstruction* ctx)
+      : OpKernel::OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor& gradient = ctx->input(0);
+    const Tensor& input = ctx->input(1);
+    Tensor* input_backprop = nullptr;
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_output(0, input.shape(), &input_backprop));
+
+    OP_REQUIRES(
+        ctx, input.IsSameSize(gradient),
+        errors::InvalidArgument("gradient and input must be the same size"));
+    const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_);
+    const Tensor& input_min_tensor = ctx->input(2);
+    const Tensor& input_max_tensor = ctx->input(3);
+    if (axis_ != -1) {
+      OP_REQUIRES(
+          ctx, input_min_tensor.dim_size(0) == depth,
+          errors::InvalidArgument("min has incorrect size, expected ", depth,
+                                  " was ", input_min_tensor.dim_size(0)));
+      OP_REQUIRES(
+          ctx, input_max_tensor.dim_size(0) == depth,
+          errors::InvalidArgument("max has incorrect size, expected ", depth,
+                                  " was ", input_max_tensor.dim_size(0)));
+    }
+
+    TensorShape min_max_shape(input_min_tensor.shape());
+    Tensor* input_min_backprop;
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_output(1, min_max_shape, &input_min_backprop));
+
+    Tensor* input_max_backprop;
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_output(2, min_max_shape, &input_max_backprop));
+
+    if (axis_ == -1) {
+      functor::QuantizeAndDequantizeOneScaleGradientFunctor<Device, T> f;
+      f(ctx->eigen_device<Device>(), gradient.template flat<T>(),
+        input.template flat<T>(), input_min_tensor.scalar<T>(),
+        input_max_tensor.scalar<T>(), input_backprop->template flat<T>(),
+        input_min_backprop->template scalar<T>(),
+        input_max_backprop->template scalar<T>());
+    } else {
+      functor::QuantizeAndDequantizePerChannelGradientFunctor<Device, T> f;
+      f(ctx->eigen_device<Device>(),
+        gradient.template flat_inner_outer_dims<T, 3>(axis_ - 1),
+        input.template flat_inner_outer_dims<T, 3>(axis_ - 1),
+        &input_min_tensor, &input_max_tensor,
+        input_backprop->template flat_inner_outer_dims<T, 3>(axis_ - 1),
+        input_min_backprop->template flat<T>(),
+        input_max_backprop->template flat<T>());
+    }
+  }
+
+ private:
+  int axis_;
+};
+
 // Simulate quantization precision loss in a float tensor by:
 // 1. Quantize the tensor to fixed point numbers, which should match the target
 //    quantization method when it is used in inference.

@@ -295,6 +364,43 @@ struct QuantizeAndDequantizePerChannelFunctor<CPUDevice, T> {
                     input_max_tensor, round_mode, narrow_range, out);
   }
 };

+template <typename T>
+struct QuantizeAndDequantizeOneScaleGradientFunctor<CPUDevice, T> {
+  void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat gradient,
+                  typename TTypes<T>::ConstFlat input,
+                  typename TTypes<T>::ConstScalar input_min_tensor,
+                  typename TTypes<T>::ConstScalar input_max_tensor,
+                  typename TTypes<T>::Flat input_backprop,
+                  typename TTypes<T>::Scalar input_min_backprop,
+                  typename TTypes<T>::Scalar input_max_backprop) {
+    QuantizeAndDequantizeOneScaleGradientImpl<CPUDevice, T>::Compute(
+        d, gradient, input, input_min_tensor, input_max_tensor, input_backprop,
+        input_min_backprop, input_max_backprop);
+  }
+};
+
+template <typename T>
+struct QuantizeAndDequantizePerChannelGradientFunctor<CPUDevice, T> {
+  void operator()(const CPUDevice& d,
+                  typename TTypes<T, 3>::ConstTensor gradient,
+                  typename TTypes<T, 3>::ConstTensor input,
+                  const Tensor* input_min_tensor,
+                  const Tensor* input_max_tensor,
+                  typename TTypes<T, 3>::Tensor input_backprop,
+                  typename TTypes<T>::Flat input_min_backprop,
+                  typename TTypes<T>::Flat input_max_backprop) {
+    QuantizeAndDequantizePerChannelGradientImpl<CPUDevice, T>::Compute(
+        d, gradient, input, input_min_tensor, input_max_tensor, input_backprop,
+        input_min_backprop, input_max_backprop);
+  }
+};
+
+template struct functor::QuantizeAndDequantizeOneScaleGradientFunctor<CPUDevice,
+                                                                      float>;
+template struct functor::QuantizeAndDequantizePerChannelGradientFunctor<
+    CPUDevice, double>;

 }  // namespace functor

 #define REGISTER_CPU_KERNEL(T) \

@@ -306,6 +412,14 @@ struct QuantizeAndDequantizePerChannelFunctor<CPUDevice, T> {
                               .Device(DEVICE_CPU)                              \
                               .TypeConstraint<T>("T"),                         \
                           QuantizeAndDequantizeV3Op<CPUDevice, T>);            \
+  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4")                      \
+                              .Device(DEVICE_CPU)                              \
+                              .TypeConstraint<T>("T"),                         \
+                          QuantizeAndDequantizeV2Op<CPUDevice, T>);            \
+  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4Grad")                  \
+                              .Device(DEVICE_CPU)                              \
+                              .TypeConstraint<T>("T"),                         \
+                          QuantizeAndDequantizeV4GradientOp<CPUDevice, T>);    \
   REGISTER_KERNEL_BUILDER(                                                     \
       Name("QuantizeAndDequantize").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
       QuantizeAndDequantizeOp<CPUDevice, T>);

@@ -329,6 +443,18 @@ TF_CALL_double(REGISTER_CPU_KERNEL);
                               .HostMemory("num_bits")                          \
                               .TypeConstraint<T>("T"),                         \
                           QuantizeAndDequantizeV3Op<GPUDevice, T>);            \
+  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4")                      \
+                              .Device(DEVICE_GPU)                              \
+                              .HostMemory("input_min")                         \
+                              .HostMemory("input_max")                         \
+                              .TypeConstraint<T>("T"),                         \
+                          QuantizeAndDequantizeV2Op<GPUDevice, T>);            \
+  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4Grad")                  \
+                              .Device(DEVICE_GPU)                              \
+                              .HostMemory("input_min")                         \
+                              .HostMemory("input_max")                         \
+                              .TypeConstraint<T>("T"),                         \
+                          QuantizeAndDequantizeV4GradientOp<GPUDevice, T>);    \
   REGISTER_KERNEL_BUILDER(                                                     \
       Name("QuantizeAndDequantize").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
       QuantizeAndDequantizeOp<GPUDevice, T>);
tensorflow/core/kernels/quantize_and_dequantize_op.h
@@ -60,6 +60,28 @@ struct QuantizeAndDequantizePerChannelFunctor {
                   typename TTypes<T, 3>::Tensor output);
 };

+template <typename Device, typename T>
+struct QuantizeAndDequantizeOneScaleGradientFunctor {
+  void operator()(const Device& d, typename TTypes<T>::ConstFlat gradient,
+                  typename TTypes<T>::ConstFlat input,
+                  typename TTypes<T>::ConstScalar input_min,
+                  typename TTypes<T>::ConstScalar input_max,
+                  typename TTypes<T>::Flat input_backprop,
+                  typename TTypes<T>::Scalar input_min_backprop,
+                  typename TTypes<T>::Scalar input_max_backprop);
+};
+
+template <typename Device, typename T>
+struct QuantizeAndDequantizePerChannelGradientFunctor {
+  void operator()(const Device& d, typename TTypes<T, 3>::ConstTensor gradient,
+                  typename TTypes<T, 3>::ConstTensor input,
+                  const Tensor* input_min_tensor,
+                  const Tensor* input_max_tensor,
+                  typename TTypes<T, 3>::Tensor input_backprop,
+                  typename TTypes<T>::Flat input_min_backprop,
+                  typename TTypes<T>::Flat input_max_backprop);
+};
+
 // The implementation below runs on both CPU and GPU.
 template <typename Device, typename T, typename Func,
           typename Vec = typename TTypes<T>::Vec,

@@ -249,6 +271,55 @@ struct QuantizeAndDequantizePerChannelImpl {
   }
 };

+template <typename Device, typename T>
+struct QuantizeAndDequantizeOneScaleGradientImpl {
+  static void Compute(const Device& d, typename TTypes<T>::ConstFlat gradient,
+                      typename TTypes<T>::ConstFlat input,
+                      typename TTypes<T>::ConstScalar input_min,
+                      typename TTypes<T>::ConstScalar input_max,
+                      typename TTypes<T>::Flat input_backprop,
+                      typename TTypes<T>::Scalar input_min_backprop,
+                      typename TTypes<T>::Scalar input_max_backprop) {
+    const T min_val = input_min();
+    const T max_val = input_max();
+    const auto in_range =
+        (input >= min_val && input <= max_val)
+            .select(input.constant(1.0f), input.constant(0.0f));
+    input_backprop.device(d) = gradient * in_range;
+    input_min_backprop.device(d) = input_min_backprop.constant(0.0f);
+    input_max_backprop.device(d) = input_max_backprop.constant(0.0f);
+  }
+};
+
+template <typename Device, typename T>
+struct QuantizeAndDequantizePerChannelGradientImpl {
+  static void Compute(const Device& d,
+                      typename TTypes<T, 3>::ConstTensor gradient,
+                      typename TTypes<T, 3>::ConstTensor input,
+                      const Tensor* input_min_tensor,
+                      const Tensor* input_max_tensor,
+                      typename TTypes<T, 3>::Tensor input_backprop,
+                      typename TTypes<T>::Flat input_min_backprop,
+                      typename TTypes<T>::Flat input_max_backprop) {
+    using Index = typename tensorflow::TTypes<T>::ConstTensor::Index;
+    auto input_min = input_min_tensor->vec<T>();
+    auto input_max = input_max_tensor->vec<T>();
+    int num_channels = input.dimension(1);
+    for (Index i = 0; i < num_channels; ++i) {
+      const auto gradient_chip = gradient.template chip<1>(i);
+      const auto input_chip = input.template chip<1>(i);
+      const T min_val = input_min(i);
+      const T max_val = input_max(i);
+      const auto in_range =
+          (input_chip >= min_val && input_chip <= max_val)
+              .select(input_chip.constant(1.0f), input_chip.constant(0.0f));
+      input_backprop.template chip<1>(i).device(d) = gradient_chip * in_range;
+    }
+    input_min_backprop.device(d) = input_min_backprop.constant(0.0f);
+    input_max_backprop.device(d) = input_max_backprop.constant(0.0f);
+  }
+};
+
 }  // end of namespace functor
 }  // end of namespace tensorflow
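The per-channel implementation above loops over chips along dimension 1 of a rank-3 view. A NumPy sketch of the same computation (illustrative only; names are hypothetical, and the rank-3 layout mirrors `flat_inner_outer_dims<T, 3>(axis - 1)` in the kernel):

```python
import numpy as np

def qdq_v4_grad_per_channel(gradients, inputs, input_mins, input_maxs):
  """Sketch of QuantizeAndDequantizePerChannelGradientImpl."""
  input_backprop = np.empty_like(gradients)
  num_channels = inputs.shape[1]  # channels sit on axis 1 of the rank-3 view
  for i in range(num_channels):
    in_range = ((inputs[:, i, :] >= input_mins[i]) &
                (inputs[:, i, :] <= input_maxs[i]))
    input_backprop[:, i, :] = gradients[:, i, :] * in_range
  # As in the one-scale case, the range endpoints get zero gradient.
  return (input_backprop,
          np.zeros(num_channels, dtype=gradients.dtype),
          np.zeros(num_channels, dtype=gradients.dtype))
```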
tensorflow/core/kernels/quantize_and_dequantize_op_gpu.cu.cc
@@ -53,6 +53,37 @@ struct QuantizeAndDequantizePerChannelFunctor<GPUDevice, T> {
   }
 };

+template <typename T>
+struct QuantizeAndDequantizeOneScaleGradientFunctor<GPUDevice, T> {
+  void operator()(const GPUDevice& d, typename TTypes<T>::ConstFlat gradient,
+                  typename TTypes<T>::ConstFlat input,
+                  typename TTypes<T>::ConstScalar input_min_tensor,
+                  typename TTypes<T>::ConstScalar input_max_tensor,
+                  typename TTypes<T>::Flat input_backprop,
+                  typename TTypes<T>::Scalar input_min_backprop,
+                  typename TTypes<T>::Scalar input_max_backprop) {
+    QuantizeAndDequantizeOneScaleGradientImpl<GPUDevice, T>::Compute(
+        d, gradient, input, input_min_tensor, input_max_tensor, input_backprop,
+        input_min_backprop, input_max_backprop);
+  }
+};
+
+template <typename T>
+struct QuantizeAndDequantizePerChannelGradientFunctor<GPUDevice, T> {
+  void operator()(const GPUDevice& d,
+                  typename TTypes<T, 3>::ConstTensor gradient,
+                  typename TTypes<T, 3>::ConstTensor input,
+                  const Tensor* input_min_tensor,
+                  const Tensor* input_max_tensor,
+                  typename TTypes<T, 3>::Tensor input_backprop,
+                  typename TTypes<T>::Flat input_min_backprop,
+                  typename TTypes<T>::Flat input_max_backprop) {
+    QuantizeAndDequantizePerChannelGradientImpl<GPUDevice, T>::Compute(
+        d, gradient, input, input_min_tensor, input_max_tensor, input_backprop,
+        input_min_backprop, input_max_backprop);
+  }
+};
+
 }  // end namespace functor

 // Instantiate the GPU implementation for float and double.

@@ -65,6 +96,15 @@ template struct functor::QuantizeAndDequantizePerChannelFunctor<GPUDevice,
 template struct functor::QuantizeAndDequantizePerChannelFunctor<GPUDevice,
                                                                 double>;

+template struct functor::QuantizeAndDequantizeOneScaleGradientFunctor<GPUDevice,
+                                                                      float>;
+template struct functor::QuantizeAndDequantizeOneScaleGradientFunctor<GPUDevice,
+                                                                      double>;
+template struct functor::QuantizeAndDequantizePerChannelGradientFunctor<
+    GPUDevice, float>;
+template struct functor::QuantizeAndDequantizePerChannelGradientFunctor<
+    GPUDevice, double>;
+
 }  // end namespace tensorflow

 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
tensorflow/core/kernels/quantize_and_dequantize_op_test.cc
@@ -362,6 +362,54 @@ TEST_P(ParameterizedQuantizeAndDequantizeTest,
   }
 }

+// Verifies the Gradient.
+TEST_P(ParameterizedQuantizeAndDequantizeTest, GradientV4_op) {
+  const int axis = GetParam();
+  TF_ASSERT_OK(NodeDefBuilder("qdq_v4_grad_op", "QuantizeAndDequantizeV4Grad")
+                   .Input(FakeInput(DT_FLOAT))
+                   .Input(FakeInput(DT_FLOAT))
+                   .Input(FakeInput(DT_FLOAT))
+                   .Input(FakeInput(DT_FLOAT))
+                   .Attr("axis", axis)
+                   .Finalize(node_def()));
+  TF_ASSERT_OK(InitOp());
+  const std::vector<int64> dims = {2, 3, 4, 5};
+  // Input gradient. (repeating 11 values multiplied by (slice_idx + 1))
+  auto gradients = ScalePerSliceAlongAxis<float>(
+      dims, axis, {1, -2, -3, 4, 5, 6, -7, -8, -9, -10, 11});
+  AddInputFromArray<float>(TensorShape(dims), gradients);
+  // Forward op inputs. (repeating 7 values multiplied by (slice_idx + 1)).
+  auto inputs = ScalePerSliceAlongAxis<float>(
+      dims, axis, {-1, -0.5, 0, 0.3, 0.8, 0.55, 0.6});
+  AddInputFromArray<float>(TensorShape(dims), inputs);
+  const int num_slices = (axis == -1) ? 1 : dims[axis];
+  const TensorShape range_shape =
+      (axis == -1) ? TensorShape({}) : TensorShape({num_slices});
+  std::vector<float> input_min_values(num_slices), input_max_values(num_slices);
+  for (int i = 0; i < num_slices; ++i) {
+    input_max_values[i] = 0.8f + i * 0.4f;
+    input_min_values[i] = -input_max_values[i];
+  }
+  AddInputFromArray<float>(range_shape, input_min_values);
+  AddInputFromArray<float>(range_shape, input_max_values);
+  std::vector<float> expected_vals(inputs.size());
+  int minor_size = 1;
+  for (int i = axis + 1; i < dims.size(); ++i) {
+    minor_size *= dims[i];
+  }
+  for (int i = 0; i < inputs.size(); ++i) {
+    int slice_idx = (i / minor_size) % num_slices;
+    expected_vals[i] = ((inputs[i] >= input_min_values[slice_idx]) &&
+                        (inputs[i] <= input_max_values[slice_idx]))
+                           ? gradients[i]
+                           : 0;
+  }
+  TF_ASSERT_OK(RunOpKernel());
+  Tensor expected(allocator(), DT_FLOAT, TensorShape(dims));
+  test::FillValues<float>(&expected, expected_vals);
+  test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5);
+}
+
 // Instantiate parameterized tests for axis = -1, 1, 3.
 INSTANTIATE_TEST_SUITE_P(All, ParameterizedQuantizeAndDequantizeTest,
                          ::testing::Values(-1, 1, 3));
tensorflow/core/ops/array_ops.cc
@@ -2808,6 +2808,70 @@ REGISTER_OP("QuantizeAndDequantizeV2")
       return Status::OK();
     });

+REGISTER_OP("QuantizeAndDequantizeV4")
+    .Input("input: T")
+    .Input("input_min: T")
+    .Input("input_max: T")
+    .Attr("signed_input: bool = true")
+    .Attr("num_bits: int = 8")
+    .Attr("range_given: bool = false")
+    .Output("output: T")
+    .Attr("T: {bfloat16, half, float, double}")
+    .Attr(
+        "round_mode: {'HALF_TO_EVEN', 'HALF_UP'} = "
+        "'HALF_TO_EVEN'")
+    .Attr("narrow_range: bool = false")
+    .Attr("axis: int = -1")
+    .SetShapeFn([](InferenceContext* c) {
+      int axis;
+      TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
+      const int minmax_rank = (axis == -1) ? 0 : 1;
+      ShapeHandle minmax;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax));
+      TF_RETURN_IF_ERROR(c->Merge(c->input(2), minmax, &minmax));
+      if (axis != -1) {
+        ShapeHandle input;
+        TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
+        DimensionHandle depth;
+        TF_RETURN_IF_ERROR(
+            c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
+      }
+      c->set_output(0, c->input(0));
+      return Status::OK();
+    });
+
+REGISTER_OP("QuantizeAndDequantizeV4Grad")
+    .Input("gradients: T")
+    .Input("input: T")
+    .Input("input_min: T")
+    .Input("input_max: T")
+    .Output("input_backprop: T")
+    .Output("input_min_backprop: T")
+    .Output("input_max_backprop: T")
+    .Attr("T: {bfloat16, half, float, double}")
+    .Attr("axis: int = -1")
+    .SetShapeFn([](InferenceContext* c) {
+      int axis;
+      TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
+      const int minmax_rank = (axis == -1) ? 0 : 1;
+      ShapeHandle minmax;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), minmax_rank, &minmax));
+      TF_RETURN_IF_ERROR(c->Merge(c->input(3), minmax, &minmax));
+      if (axis != -1) {
+        ShapeHandle input;
+        TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
+        DimensionHandle depth;
+        TF_RETURN_IF_ERROR(
+            c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
+      }
+      ShapeHandle inputs;
+      TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &inputs));
+      c->set_output(0, inputs);
+      c->set_output(1, minmax);
+      c->set_output(2, minmax);
+      return Status::OK();
+    });
+
 REGISTER_OP("QuantizeAndDequantizeV3")
     .Input("input: T")
     .Input("input_min: T")
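A quick shape-level check of the V4Grad op registered above (a sketch; it assumes a TF build with this commit, where `tf.raw_ops` mirrors these REGISTER_OP definitions):

```python
import tensorflow as tf

g = tf.ones([2, 3, 4])                   # upstream gradients
x = tf.random.uniform([2, 3, 4])         # forward-pass input
mins = tf.constant([-1.0, -2.0, -3.0])   # per-channel range, depth 3
maxs = tf.constant([1.0, 2.0, 3.0])

db, dmin, dmax = tf.raw_ops.QuantizeAndDequantizeV4Grad(
    gradients=g, input=x, input_min=mins, input_max=maxs, axis=1)

print(db.shape, dmin.shape, dmax.shape)  # (2, 3, 4) (3,) (3,)
```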
tensorflow/python/eager/pywrap_gradient_exclusions.cc
@@ -50,7 +50,7 @@ auto OpGradientInfoInit(const T &a) {

 absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
     const tensorflow::string &op_name) {
-  static std::array<OpIndexInfo, 357> a = {{
+  static std::array<OpIndexInfo, 358> a = {{
       {"Acosh"},
       {"AllToAll", 1, {0}},
       {"ApproximateEqual"},

@@ -227,6 +227,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
       {"QuantizeAndDequantize"},
       {"QuantizeAndDequantizeV2"},
       {"QuantizeAndDequantizeV3"},
+      {"QuantizeAndDequantizeV4Grad", 1, {3}},
       {"QueueClose"},
       {"QueueDequeue"},
       {"QueueDequeueMany"},

@@ -420,7 +421,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(

 absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
     const tensorflow::string &op_name) {
-  static std::array<OpIndexInfo, 473> a = {{
+  static std::array<OpIndexInfo, 475> a = {{
       {"Abs"},
       {"AccumulateNV2"},
       {"Acos"},

@@ -669,6 +670,8 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
       {"QuantizeAndDequantize"},
       {"QuantizeAndDequantizeV2"},
       {"QuantizeAndDequantizeV3"},
+      {"QuantizeAndDequantizeV4"},
+      {"QuantizeAndDequantizeV4Grad"},
       {"QueueClose"},
       {"QueueEnqueue"},
       {"QueueEnqueueMany"},
tensorflow/python/kernel_tests/array_ops_test.py
@@ -1610,7 +1610,7 @@ class QuantizeAndDequantizeTest(test_util.TensorFlowTestCase):
     expected = self._scale_per_slice(shape, axis, quant_values)
     unused_minmax_value = 0 if axis is None else [0] * shape[axis]
     fake_quantized = self.evaluate(
-        array_ops.quantize_and_dequantize(
+        array_ops.quantize_and_dequantize_v2(
             inputs,
             unused_minmax_value,
             unused_minmax_value,

@@ -1620,7 +1620,7 @@ class QuantizeAndDequantizeTest(test_util.TensorFlowTestCase):
     self.assertAllEqual(fake_quantized, expected)
     if axis is not None:
       fake_quantized = self.evaluate(
-          array_ops.quantize_and_dequantize(
+          array_ops.quantize_and_dequantize_v2(
              inputs,
              unused_minmax_value,
              unused_minmax_value,

@@ -1628,6 +1628,23 @@ class QuantizeAndDequantizeTest(test_util.TensorFlowTestCase):
               axis=(axis - 4)))
       self.assertAllClose(fake_quantized, expected)

+  def testQuantizeDequantizeGrad(self):
+    shape = (2, 2)
+    max_threshold = 0
+    min_threshold = -10
+    input_value = np.random.rand(2, 2) * 40.0 - 20.0
+    input_tensor = constant_op.constant(input_value, shape=shape,
+                                        name="input_tensor")
+    with self.cached_session():
+
+      def f(a):
+        return array_ops.quantize_and_dequantize_v2(
+            a,
+            input_min=min_threshold,
+            input_max=max_threshold,
+            range_given=True)
+
+      output_grad = gradient_checker_v2.compute_gradient(f, [input_tensor])
+      self.assertAllClose(output_grad[0], np.zeros([1, 4, 4]))
+

 @test_util.run_all_in_graph_and_eager_modes
 class SortedSearchTest(test_util.TensorFlowTestCase):
tensorflow/python/ops/array_ops.py
@@ -3758,6 +3758,23 @@ def _FakeQuantWithMinMaxVarsPerChannelGradient(op, grad):
       narrow_range=op.get_attr("narrow_range"))


+@ops.RegisterGradient("QuantizeAndDequantizeV4")
+def _QuantizeAndDequantizeV4Grad(op, grad):
+  """Gradient for QuantizeAndDequantizeV4 op."""
+  return quantize_and_dequantize_v4_grad(
+      grad,
+      op.inputs[0],
+      op.inputs[1],
+      op.inputs[2],
+      axis=op.get_attr("axis"))
+
+
+@ops.RegisterGradient("QuantizeAndDequantizeV4Grad")
+def _QuantizeAndDequantizeV4GradGrad(op, grad):
+  """Gradient for QuantizeAndDequantizeV4Grad op."""
+  return _QuantizeAndDequantizeV4Grad(op, grad)
+
+
 @tf_export("required_space_to_batch_paddings")
 def required_space_to_batch_paddings(input_shape,
                                      block_shape,

@@ -5630,6 +5647,13 @@ dequantize.__doc__ = gen_array_ops.dequantize.__doc__

 @tf_export("quantization.quantize_and_dequantize")
 @dispatch.add_dispatch_support
+@deprecation.deprecated(None,
+                        "This Op has been deprecated. Use " +
+                        "`quantize_and_dequantize_v2` instead. To " +
+                        "simulate the V1 behavior of " +
+                        "tf.quantization.quantize_and_dequantize(...) use " +
+                        "tf.grad_pass_through(" +
+                        "tf.quantization.quantize_and_dequantize_v2)(...).")
 def quantize_and_dequantize(
     input,  # pylint: disable=redefined-builtin
     input_min,

@@ -5688,6 +5712,93 @@ def quantize_and_dequantize(
       name=name)


+@tf_export("quantization.quantize_and_dequantize_v2")
+@dispatch.add_dispatch_support
+def quantize_and_dequantize_v2(
+    input,  # pylint: disable=redefined-builtin
+    input_min,
+    input_max,
+    signed_input=True,
+    num_bits=8,
+    range_given=False,
+    round_mode="HALF_TO_EVEN",
+    name=None,
+    narrow_range=False,
+    axis=None):
+  """Quantizes then dequantizes a tensor.
+
+  Updates the gradient definition for quantization so that inputs outside the
+  range produce a gradient of 0. To simulate the V1 behavior of
+  tf.quantization.quantize_and_dequantize(...) use
+  tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...).
+
+  Example usage:
+
+  ```python
+  def getQuantizeOp(input):
+    input_tensor = tf.placeholder(tf.float32, shape=[4, 4])
+    net = tf.quantization.quantize_and_dequantize(input,
+                                                  input_min=min_threshold,
+                                                  input_max=max_threshold,
+                                                  range_given=True)
+
+  To simulate v1 behavior:
+
+  def testDecomposeQuantizeDequantize(self):
+    def f(input_tensor):
+      return tf.quantization.quantize_and_dequantize_v2(input_tensor,
+                                                        input_min=-10.0,
+                                                        input_max=5.0,
+                                                        range_given=True)
+    input_tensor = tf.placeholder(tf.float32, shape=[4, 4])
+    net = tf.grad_pass_through(f)(input_tensor)
+  ```
+
+  Args:
+    input: A `Tensor` to quantize and dequantize.
+    input_min: If range_given=True, the minimum input value that needs to be
+      represented in the quantized representation. If axis is specified, this
+      should be a vector of minimum values for each slice along axis.
+    input_max: If range_given=True, the maximum input value that needs to be
+      represented in the quantized representation. If axis is specified, this
+      should be a vector of maximum values for each slice along axis.
+    signed_input: True if the quantization is signed, False if it is unsigned.
+    num_bits: The bitwidth of the quantization.
+    range_given: If true use `input_min` and `input_max` for the range of the
+      input, otherwise determine min and max from the input `Tensor`.
+    round_mode: Rounding mode when rounding from float values to quantized
+      ones; one of ['HALF_TO_EVEN', 'HALF_UP'].
+    name: Optional name for the operation.
+    narrow_range: If true, then the absolute value of the quantized minimum
+      value is the same as the quantized maximum value, instead of 1 greater.
+      i.e. for 8 bit quantization, the minimum value is -127 instead of -128.
+    axis: Integer. If specified, refers to a dimension of the input tensor,
+      such that quantization will be per slice along that dimension.
+
+  Returns:
+    A `Tensor`. Each element is the result of quantizing and dequantizing the
+    corresponding element of `input`.
+  """
+  if axis is None:
+    axis = -1
+  elif axis < 0:
+    if input.shape.ndims is None:
+      raise ValueError("input should have known rank to use negative axis.")
+    axis %= input.shape.ndims
+
+  return gen_array_ops.quantize_and_dequantize_v4(
+      input,
+      input_min=input_min,
+      input_max=input_max,
+      signed_input=signed_input,
+      num_bits=num_bits,
+      range_given=range_given,
+      round_mode=round_mode,
+      narrow_range=narrow_range,
+      axis=axis,
+      name=name)
+
+
 @tf_export("searchsorted")
 @dispatch.add_dispatch_support
 def searchsorted(sorted_sequence,

@@ -6175,7 +6286,7 @@ def _with_nonzero_rank(data):
 @dispatch.add_dispatch_support
 def repeat(input, repeats, axis=None, name=None):  # pylint: disable=redefined-builtin
   """Repeat elements of `input`.

   See also `tf.concat`, `tf.stack`, `tf.tile`.

   Args:
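Example use of the new wrapper with per-slice ranges, exercising the axis normalization above (a sketch; the values are illustrative):

```python
import tensorflow as tf

x = tf.constant([[-1.5, 0.2], [0.4, 2.5]])
# One (min, max) pair per slice along axis 0.
y = tf.quantization.quantize_and_dequantize_v2(
    x,
    input_min=[-1.0, -2.0],
    input_max=[1.0, 2.0],
    range_given=True,
    axis=0)
```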
tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -1824,6 +1824,10 @@ tf_module {
   member_method {
     name: "quantize"
     argspec: "args=[\'input\', \'min_range\', \'max_range\', \'T\', \'mode\', \'round_mode\', \'name\', \'narrow_range\', \'axis\', \'ensure_minimum_range\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'HALF_AWAY_FROM_ZERO\', \'None\', \'False\', \'None\', \'0.01\'], "
   }
+  member_method {
+    name: "quantize_and_dequantize_v4"
+    argspec: "args=[\'input\', \'input_min\', \'input_max\', \'signed_input\', \'num_bits\', \'range_given\', \'round_mode\', \'narrow_range\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'8\', \'False\', \'HALF_TO_EVEN\', \'False\', \'-1\', \'None\'], "
+  }
   member_method {
     name: "quantize_v2"
     argspec: "args=[\'input\', \'min_range\', \'max_range\', \'T\', \'mode\', \'name\', \'round_mode\', \'narrow_range\', \'axis\', \'ensure_minimum_range\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'HALF_AWAY_FROM_ZERO\', \'False\', \'None\', \'0.01\'], "
   }
tensorflow/tools/api/golden/v1/tensorflow.quantization.pbtxt
@@ -36,6 +36,10 @@ tf_module {
   member_method {
     name: "quantize_and_dequantize"
     argspec: "args=[\'input\', \'input_min\', \'input_max\', \'signed_input\', \'num_bits\', \'range_given\', \'round_mode\', \'name\', \'narrow_range\', \'axis\'], varargs=None, keywords=None, defaults=[\'True\', \'8\', \'False\', \'HALF_TO_EVEN\', \'None\', \'False\', \'None\'], "
   }
+  member_method {
+    name: "quantize_and_dequantize_v2"
+    argspec: "args=[\'input\', \'input_min\', \'input_max\', \'signed_input\', \'num_bits\', \'range_given\', \'round_mode\', \'name\', \'narrow_range\', \'axis\'], varargs=None, keywords=None, defaults=[\'True\', \'8\', \'False\', \'HALF_TO_EVEN\', \'None\', \'False\', \'None\'], "
+  }
   member_method {
     name: "quantized_concat"
     argspec: "args=[\'concat_dim\', \'values\', \'input_mins\', \'input_maxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -2940,6 +2940,14 @@ tf_module {
   member_method {
     name: "QuantizeAndDequantizeV3"
    argspec: "args=[\'input\', \'input_min\', \'input_max\', \'num_bits\', \'signed_input\', \'range_given\', \'narrow_range\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'False\', \'-1\', \'None\'], "
   }
+  member_method {
+    name: "QuantizeAndDequantizeV4"
+    argspec: "args=[\'input\', \'input_min\', \'input_max\', \'signed_input\', \'num_bits\', \'range_given\', \'round_mode\', \'narrow_range\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'8\', \'False\', \'HALF_TO_EVEN\', \'False\', \'-1\', \'None\'], "
+  }
+  member_method {
+    name: "QuantizeAndDequantizeV4Grad"
+    argspec: "args=[\'gradients\', \'input\', \'input_min\', \'input_max\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
+  }
   member_method {
     name: "QuantizeDownAndShrinkRange"
     argspec: "args=[\'input\', \'input_min\', \'input_max\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -880,6 +880,10 @@ tf_module {
   member_method {
     name: "py_function"
     argspec: "args=[\'func\', \'inp\', \'Tout\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "quantize_and_dequantize_v4"
+    argspec: "args=[\'input\', \'input_min\', \'input_max\', \'signed_input\', \'num_bits\', \'range_given\', \'round_mode\', \'narrow_range\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'8\', \'False\', \'HALF_TO_EVEN\', \'False\', \'-1\', \'None\'], "
+  }
   member_method {
     name: "range"
     argspec: "args=[\'start\', \'limit\', \'delta\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'range\'], "
   }
tensorflow/tools/api/golden/v2/tensorflow.quantization.pbtxt
@@ -36,6 +36,10 @@ tf_module {
   member_method {
     name: "quantize_and_dequantize"
     argspec: "args=[\'input\', \'input_min\', \'input_max\', \'signed_input\', \'num_bits\', \'range_given\', \'round_mode\', \'name\', \'narrow_range\', \'axis\'], varargs=None, keywords=None, defaults=[\'True\', \'8\', \'False\', \'HALF_TO_EVEN\', \'None\', \'False\', \'None\'], "
   }
+  member_method {
+    name: "quantize_and_dequantize_v2"
+    argspec: "args=[\'input\', \'input_min\', \'input_max\', \'signed_input\', \'num_bits\', \'range_given\', \'round_mode\', \'name\', \'narrow_range\', \'axis\'], varargs=None, keywords=None, defaults=[\'True\', \'8\', \'False\', \'HALF_TO_EVEN\', \'None\', \'False\', \'None\'], "
+  }
   member_method {
     name: "quantized_concat"
     argspec: "args=[\'concat_dim\', \'values\', \'input_mins\', \'input_maxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -2940,6 +2940,14 @@ tf_module {
   member_method {
     name: "QuantizeAndDequantizeV3"
     argspec: "args=[\'input\', \'input_min\', \'input_max\', \'num_bits\', \'signed_input\', \'range_given\', \'narrow_range\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'False\', \'-1\', \'None\'], "
   }
+  member_method {
+    name: "QuantizeAndDequantizeV4"
+    argspec: "args=[\'input\', \'input_min\', \'input_max\', \'signed_input\', \'num_bits\', \'range_given\', \'round_mode\', \'narrow_range\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'8\', \'False\', \'HALF_TO_EVEN\', \'False\', \'-1\', \'None\'], "
+  }
+  member_method {
+    name: "QuantizeAndDequantizeV4Grad"
+    argspec: "args=[\'gradients\', \'input\', \'input_min\', \'input_max\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
+  }
   member_method {
     name: "QuantizeDownAndShrinkRange"
     argspec: "args=[\'input\', \'input_min\', \'input_max\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }