Support Dequantize to bfloat16.
Introduce DequantizeV2 which allows user to specify the output dtype{float|bfloat16}.
PiperOrigin-RevId: 289688216
Change-Id: I6550ae555e8895a759f36ffc0da8bc496fa7554a
This commit is contained in:
parent
8a2a86318b
commit
83df634c7e
@ -55,7 +55,6 @@ class DequantizeOp : public XlaOpKernel {
|
|||||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis));
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis));
|
||||||
OP_REQUIRES(ctx, axis == -1,
|
OP_REQUIRES(ctx, axis == -1,
|
||||||
errors::InvalidArgument("axis must be -1' is ", axis));
|
errors::InvalidArgument("axis must be -1' is ", axis));
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
~DequantizeOp() override = default;
|
~DequantizeOp() override = default;
|
||||||
@ -87,6 +86,7 @@ class DequantizeOp : public XlaOpKernel {
|
|||||||
xla::XlaOp input = ctx->Input(0);
|
xla::XlaOp input = ctx->Input(0);
|
||||||
xla::XlaOp output;
|
xla::XlaOp output;
|
||||||
|
|
||||||
|
// TODO(ylc): Support bfloat16.
|
||||||
output = xla::ConvertElementType(input, xla::F32);
|
output = xla::ConvertElementType(input, xla::F32);
|
||||||
|
|
||||||
auto scale = ScalarLike(output, scale_factor);
|
auto scale = ScalarLike(output, scale_factor);
|
||||||
@ -94,14 +94,8 @@ class DequantizeOp : public XlaOpKernel {
|
|||||||
output = xla::Add(xla::Mul(xla::Add(output, halfrange), scale),
|
output = xla::Add(xla::Mul(xla::Add(output, halfrange), scale),
|
||||||
ScalarLike(output, min_range));
|
ScalarLike(output, min_range));
|
||||||
|
|
||||||
if (dtype_ == DT_BFLOAT16) {
|
|
||||||
output = xla::ConvertElementType(input, xla::BF16);
|
|
||||||
}
|
|
||||||
ctx->SetOutput(0, output);
|
ctx->SetOutput(0, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
|
||||||
DataType dtype_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_XLA_OP(Name("Dequantize").TypeConstraint("T", kQuantizedType),
|
REGISTER_XLA_OP(Name("Dequantize").TypeConstraint("T", kQuantizedType),
|
||||||
|
|||||||
@ -12,14 +12,7 @@ END
|
|||||||
The maximum scalar value possibly produced for the input.
|
The maximum scalar value possibly produced for the input.
|
||||||
END
|
END
|
||||||
}
|
}
|
||||||
attr {
|
summary: "Dequantize the \'input\' tensor into a float Tensor."
|
||||||
name: "dtype"
|
|
||||||
description: <<END
|
|
||||||
Type of the output tensor. Currently Dequantize supports float and bfloat16.
|
|
||||||
If 'dtype' is 'bfloat16', it only supports 'MIN_COMBINED' mode.
|
|
||||||
END
|
|
||||||
}
|
|
||||||
summary: "Dequantize the \'input\' tensor into a float or bfloat16 Tensor."
|
|
||||||
description: <<END
|
description: <<END
|
||||||
[min_range, max_range] are scalar floats that specify the range for
|
[min_range, max_range] are scalar floats that specify the range for
|
||||||
the output. The 'mode' attribute controls exactly which calculations are
|
the output. The 'mode' attribute controls exactly which calculations are
|
||||||
|
|||||||
@ -23,7 +23,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/kernels/meta_support.h"
|
#include "tensorflow/core/kernels/meta_support.h"
|
||||||
#include "tensorflow/core/kernels/quantization_utils.h"
|
#include "tensorflow/core/kernels/quantization_utils.h"
|
||||||
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
|
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -38,44 +37,18 @@ namespace tensorflow {
|
|||||||
|
|
||||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
|
|
||||||
template <typename T>
|
template <typename Device, typename T>
|
||||||
T Cast(float v) {
|
|
||||||
return v;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
bfloat16 Cast<bfloat16>(float v) {
|
|
||||||
return bfloat16(v);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Device, typename T, typename S>
|
|
||||||
class DequantizeOp : public OpKernel {
|
class DequantizeOp : public OpKernel {
|
||||||
public:
|
public:
|
||||||
explicit DequantizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
explicit DequantizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||||
string mode_string;
|
string mode_string;
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_string));
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_string));
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(ctx,
|
||||||
ctx,
|
(mode_string == "MIN_COMBINED" || mode_string == "MIN_FIRST" ||
|
||||||
(ctx->output_type(0) == DT_FLOAT || ctx->output_type(0) == DT_BFLOAT16),
|
mode_string == "SCALED"),
|
||||||
errors::InvalidArgument("Output type must be bfloat16 or float,"
|
errors::InvalidArgument("Mode string must be 'MIN_COMBINED',"
|
||||||
" is '" +
|
" 'MIN_FIRST', or 'SCALED', is '" +
|
||||||
DataTypeString(ctx->output_type(0)) + "'"));
|
mode_string + "'"));
|
||||||
|
|
||||||
if (ctx->output_type(0) == DT_FLOAT) {
|
|
||||||
OP_REQUIRES(ctx,
|
|
||||||
(mode_string == "MIN_COMBINED" ||
|
|
||||||
mode_string == "MIN_FIRST" || mode_string == "SCALED"),
|
|
||||||
errors::InvalidArgument("Mode string must be 'MIN_COMBINED',"
|
|
||||||
" 'MIN_FIRST', or 'SCALED', is '" +
|
|
||||||
mode_string + "'"));
|
|
||||||
} else {
|
|
||||||
OP_REQUIRES(
|
|
||||||
ctx, (mode_string == "MIN_COMBINED"),
|
|
||||||
errors::InvalidArgument("When output type is bfloat16, Mode"
|
|
||||||
" string must be 'MIN_COMBINED', is '" +
|
|
||||||
mode_string + "'"));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (mode_string == "MIN_COMBINED") {
|
if (mode_string == "MIN_COMBINED") {
|
||||||
mode_ = QUANTIZE_MODE_MIN_COMBINED;
|
mode_ = QUANTIZE_MODE_MIN_COMBINED;
|
||||||
} else if (mode_string == "MIN_FIRST") {
|
} else if (mode_string == "MIN_FIRST") {
|
||||||
@ -98,40 +71,34 @@ class DequantizeOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Tensor* output = nullptr;
|
Tensor* output = nullptr;
|
||||||
Tensor float_output = tensorflow::Tensor(DT_FLOAT, input.shape());
|
|
||||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
|
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
|
||||||
if (num_slices == 1) {
|
if (num_slices == 1) {
|
||||||
const float min_range = input_min_tensor.flat<float>()(0);
|
const float min_range = input_min_tensor.flat<float>()(0);
|
||||||
const float max_range = input_max_tensor.flat<float>()(0);
|
const float max_range = input_max_tensor.flat<float>()(0);
|
||||||
DequantizeTensor(ctx, input, min_range, max_range, &float_output);
|
DequantizeTensor(ctx, input, min_range, max_range, output);
|
||||||
} else {
|
return;
|
||||||
OP_REQUIRES(ctx, mode_ != QUANTIZE_MODE_MIN_FIRST,
|
|
||||||
errors::Unimplemented("MIN_FIRST mode is not implemented for "
|
|
||||||
"Dequantize with axis != -1."));
|
|
||||||
|
|
||||||
int64 pre_dim = 1, post_dim = 1;
|
|
||||||
for (int i = 0; i < axis_; ++i) {
|
|
||||||
pre_dim *= float_output.dim_size(i);
|
|
||||||
}
|
|
||||||
for (int i = axis_ + 1; i < float_output.dims(); ++i) {
|
|
||||||
post_dim *= float_output.dim_size(i);
|
|
||||||
}
|
|
||||||
auto input_tensor = input.template bit_casted_shaped<T, 3>(
|
|
||||||
{pre_dim, num_slices, post_dim});
|
|
||||||
auto output_tensor =
|
|
||||||
float_output.flat_inner_outer_dims<float, 3>(axis_ - 1);
|
|
||||||
auto min_ranges = input_min_tensor.vec<float>();
|
|
||||||
auto max_ranges = input_max_tensor.vec<float>();
|
|
||||||
for (int i = 0; i < num_slices; ++i) {
|
|
||||||
DequantizeSlice(ctx->eigen_device<Device>(), ctx,
|
|
||||||
input_tensor.template chip<1>(i), min_ranges(i),
|
|
||||||
max_ranges(i), output_tensor.template chip<1>(i));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
S* out_ptr = output->flat<S>().data();
|
|
||||||
float* in_ptr = float_output.flat<float>().data();
|
OP_REQUIRES(ctx, mode_ != QUANTIZE_MODE_MIN_FIRST,
|
||||||
for (int64 i = 0; i < float_output.NumElements(); ++i) {
|
errors::Unimplemented("MIN_FIRST mode is not implemented for "
|
||||||
out_ptr[i] = static_cast<S>(in_ptr[i]);
|
"Dequantize with axis != -1."));
|
||||||
|
|
||||||
|
int64 pre_dim = 1, post_dim = 1;
|
||||||
|
for (int i = 0; i < axis_; ++i) {
|
||||||
|
pre_dim *= output->dim_size(i);
|
||||||
|
}
|
||||||
|
for (int i = axis_ + 1; i < output->dims(); ++i) {
|
||||||
|
post_dim *= output->dim_size(i);
|
||||||
|
}
|
||||||
|
auto input_tensor =
|
||||||
|
input.template bit_casted_shaped<T, 3>({pre_dim, num_slices, post_dim});
|
||||||
|
auto output_tensor = output->flat_inner_outer_dims<float, 3>(axis_ - 1);
|
||||||
|
auto min_ranges = input_min_tensor.vec<float>();
|
||||||
|
auto max_ranges = input_max_tensor.vec<float>();
|
||||||
|
for (int i = 0; i < num_slices; ++i) {
|
||||||
|
DequantizeSlice(ctx->eigen_device<Device>(), ctx,
|
||||||
|
input_tensor.template chip<1>(i), min_ranges(i),
|
||||||
|
max_ranges(i), output_tensor.template chip<1>(i));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -221,55 +188,21 @@ class DequantizeOp : public OpKernel {
|
|||||||
bool narrow_range_;
|
bool narrow_range_;
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("Dequantize")
|
REGISTER_KERNEL_BUILDER(
|
||||||
.Device(DEVICE_CPU)
|
Name("Dequantize").Device(DEVICE_CPU).TypeConstraint<quint8>("T"),
|
||||||
.TypeConstraint<quint8>("T")
|
DequantizeOp<CPUDevice, quint8>);
|
||||||
.TypeConstraint<float>("dtype"),
|
REGISTER_KERNEL_BUILDER(
|
||||||
DequantizeOp<CPUDevice, quint8, float>);
|
Name("Dequantize").Device(DEVICE_CPU).TypeConstraint<qint8>("T"),
|
||||||
REGISTER_KERNEL_BUILDER(Name("Dequantize")
|
DequantizeOp<CPUDevice, qint8>);
|
||||||
.Device(DEVICE_CPU)
|
REGISTER_KERNEL_BUILDER(
|
||||||
.TypeConstraint<qint8>("T")
|
Name("Dequantize").Device(DEVICE_CPU).TypeConstraint<quint16>("T"),
|
||||||
.TypeConstraint<float>("dtype"),
|
DequantizeOp<CPUDevice, quint16>);
|
||||||
DequantizeOp<CPUDevice, qint8, float>);
|
REGISTER_KERNEL_BUILDER(
|
||||||
REGISTER_KERNEL_BUILDER(Name("Dequantize")
|
Name("Dequantize").Device(DEVICE_CPU).TypeConstraint<qint16>("T"),
|
||||||
.Device(DEVICE_CPU)
|
DequantizeOp<CPUDevice, qint16>);
|
||||||
.TypeConstraint<quint16>("T")
|
|
||||||
.TypeConstraint<float>("dtype"),
|
REGISTER_KERNEL_BUILDER(
|
||||||
DequantizeOp<CPUDevice, quint16, float>);
|
Name("Dequantize").Device(DEVICE_CPU).TypeConstraint<qint32>("T"),
|
||||||
REGISTER_KERNEL_BUILDER(Name("Dequantize")
|
DequantizeOp<CPUDevice, qint32>);
|
||||||
.Device(DEVICE_CPU)
|
|
||||||
.TypeConstraint<qint16>("T")
|
|
||||||
.TypeConstraint<float>("dtype"),
|
|
||||||
DequantizeOp<CPUDevice, qint16, float>);
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("Dequantize")
|
|
||||||
.Device(DEVICE_CPU)
|
|
||||||
.TypeConstraint<qint32>("T")
|
|
||||||
.TypeConstraint<float>("dtype"),
|
|
||||||
DequantizeOp<CPUDevice, qint32, float>);
|
|
||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("Dequantize")
|
|
||||||
.Device(DEVICE_CPU)
|
|
||||||
.TypeConstraint<quint8>("T")
|
|
||||||
.TypeConstraint<bfloat16>("dtype"),
|
|
||||||
DequantizeOp<CPUDevice, quint8, bfloat16>);
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("Dequantize")
|
|
||||||
.Device(DEVICE_CPU)
|
|
||||||
.TypeConstraint<qint8>("T")
|
|
||||||
.TypeConstraint<bfloat16>("dtype"),
|
|
||||||
DequantizeOp<CPUDevice, qint8, bfloat16>);
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("Dequantize")
|
|
||||||
.Device(DEVICE_CPU)
|
|
||||||
.TypeConstraint<quint16>("T")
|
|
||||||
.TypeConstraint<bfloat16>("dtype"),
|
|
||||||
DequantizeOp<CPUDevice, quint16, bfloat16>);
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("Dequantize")
|
|
||||||
.Device(DEVICE_CPU)
|
|
||||||
.TypeConstraint<qint16>("T")
|
|
||||||
.TypeConstraint<bfloat16>("dtype"),
|
|
||||||
DequantizeOp<CPUDevice, qint16, bfloat16>);
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("Dequantize")
|
|
||||||
.Device(DEVICE_CPU)
|
|
||||||
.TypeConstraint<qint32>("T")
|
|
||||||
.TypeConstraint<bfloat16>("dtype"),
|
|
||||||
DequantizeOp<CPUDevice, qint32, bfloat16>);
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|||||||
@ -28,7 +28,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/framework/types.pb.h"
|
|
||||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
#include "tensorflow/core/platform/test_benchmark.h"
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
@ -62,9 +61,8 @@ class DequantizeOpTest : public OpsTestBase {
|
|||||||
// Compares dequantize min vs the same using eigen. This tests that a change
|
// Compares dequantize min vs the same using eigen. This tests that a change
|
||||||
// to not use eigen gives equivalent results to using eigen.
|
// to not use eigen gives equivalent results to using eigen.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void RunDequantizeMinCombinedTest(float min_range, float max_range,
|
void RunDequantizeMinCombinedTest(float min_range, float max_range) {
|
||||||
const string& op_name) {
|
TF_ASSERT_OK(NodeDefBuilder("dequantize_op", "Dequantize")
|
||||||
TF_ASSERT_OK(NodeDefBuilder("dequantize_op", op_name)
|
|
||||||
.Input(FakeInput(DataTypeToEnum<T>::v()))
|
.Input(FakeInput(DataTypeToEnum<T>::v()))
|
||||||
.Input(FakeInput(DT_FLOAT))
|
.Input(FakeInput(DT_FLOAT))
|
||||||
.Input(FakeInput(DT_FLOAT))
|
.Input(FakeInput(DT_FLOAT))
|
||||||
@ -89,40 +87,6 @@ class DequantizeOpTest : public OpsTestBase {
|
|||||||
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
|
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compares dequantize min vs the same using eigen. This tests that a change
|
|
||||||
// to not use eigen gives equivalent results to using eigen.
|
|
||||||
template <typename T>
|
|
||||||
void RunDequantizeBfloat16MinCombinedTest(float min_range, float max_range) {
|
|
||||||
TF_ASSERT_OK(NodeDefBuilder("dequantize_op_bfloat16", "Dequantize")
|
|
||||||
.Input(FakeInput(DataTypeToEnum<T>::v()))
|
|
||||||
.Input(FakeInput(DT_FLOAT))
|
|
||||||
.Input(FakeInput(DT_FLOAT))
|
|
||||||
.Attr("T", DataTypeToEnum<T>::v())
|
|
||||||
.Attr("mode", "MIN_COMBINED")
|
|
||||||
.Attr("dtype", DT_BFLOAT16)
|
|
||||||
.Finalize(node_def()));
|
|
||||||
TF_ASSERT_OK(InitOp());
|
|
||||||
|
|
||||||
std::vector<T> input;
|
|
||||||
for (int64 i = std::numeric_limits<T>::min();
|
|
||||||
i < std::numeric_limits<T>::max(); ++i) {
|
|
||||||
input.push_back(static_cast<T>(i));
|
|
||||||
}
|
|
||||||
TensorShape shape({static_cast<int64>(input.size())});
|
|
||||||
AddInputFromArray<T>(shape, input);
|
|
||||||
AddInputFromArray<float>(TensorShape({}), {min_range});
|
|
||||||
AddInputFromArray<float>(TensorShape({}), {max_range});
|
|
||||||
TF_ASSERT_OK(RunOpKernel());
|
|
||||||
|
|
||||||
Tensor expected_float32(allocator(), DT_FLOAT, shape);
|
|
||||||
ComputeDequantizeMinCombinedUsingEigen<T>(GetInput(0), min_range, max_range,
|
|
||||||
&expected_float32);
|
|
||||||
Tensor expected(allocator(), DT_BFLOAT16, shape);
|
|
||||||
expected.flat<bfloat16>() = expected_float32.flat<float>().cast<bfloat16>();
|
|
||||||
|
|
||||||
test::ExpectTensorEqual<bfloat16>(expected, *GetOutput(0));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Creates a tensor with the specified dims, using values chosen from data,
|
// Creates a tensor with the specified dims, using values chosen from data,
|
||||||
// multiplied by (1 + index) along the axis dimension.
|
// multiplied by (1 + index) along the axis dimension.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -187,29 +151,16 @@ struct ParameterizedDequantizeOpTest
|
|||||||
public ::testing::WithParamInterface<int> {};
|
public ::testing::WithParamInterface<int> {};
|
||||||
|
|
||||||
TEST_F(DequantizeOpTest, DequantizeMinCombinedQuint8) {
|
TEST_F(DequantizeOpTest, DequantizeMinCombinedQuint8) {
|
||||||
RunDequantizeMinCombinedTest<quint8>(0, 255.0f, "Dequantize");
|
RunDequantizeMinCombinedTest<quint8>(0, 255.0f);
|
||||||
}
|
}
|
||||||
TEST_F(DequantizeOpTest, DequantizeMinCombinedQint8) {
|
TEST_F(DequantizeOpTest, DequantizeMinCombinedQint8) {
|
||||||
RunDequantizeMinCombinedTest<qint8>(0, 255.0f, "Dequantize");
|
RunDequantizeMinCombinedTest<qint8>(0, 255.0f);
|
||||||
}
|
}
|
||||||
TEST_F(DequantizeOpTest, DequantizeMinCombinedQint16) {
|
TEST_F(DequantizeOpTest, DequantizeMinCombinedQint16) {
|
||||||
RunDequantizeMinCombinedTest<qint16>(0, 255.0f, "Dequantize");
|
RunDequantizeMinCombinedTest<qint16>(0, 255.0f);
|
||||||
}
|
}
|
||||||
TEST_F(DequantizeOpTest, DequantizeMinCombinedQuint16) {
|
TEST_F(DequantizeOpTest, DequantizeMinCombinedQuint16) {
|
||||||
RunDequantizeMinCombinedTest<quint16>(0, 255.0f, "Dequantize");
|
RunDequantizeMinCombinedTest<quint16>(0, 255.0f);
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(DequantizeOpTest, DequantizeBfloat16MinCombinedQuint8) {
|
|
||||||
RunDequantizeBfloat16MinCombinedTest<quint8>(0, 255.0f);
|
|
||||||
}
|
|
||||||
TEST_F(DequantizeOpTest, DequantizeBfloat16MinCombinedQint8) {
|
|
||||||
RunDequantizeBfloat16MinCombinedTest<qint8>(0, 255.0f);
|
|
||||||
}
|
|
||||||
TEST_F(DequantizeOpTest, DequantizeBfloat16MinCombinedQint16) {
|
|
||||||
RunDequantizeBfloat16MinCombinedTest<qint16>(0, 255.0f);
|
|
||||||
}
|
|
||||||
TEST_F(DequantizeOpTest, DequantizeBfloat16MinCombinedQuint16) {
|
|
||||||
RunDequantizeBfloat16MinCombinedTest<quint16>(0, 255.0f);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DequantizeOpTest, DequantizeScaledQuint8Zero) {
|
TEST_F(DequantizeOpTest, DequantizeScaledQuint8Zero) {
|
||||||
@ -251,10 +202,8 @@ static void BM_DequantizeMinCombinedCpu(int iters) {
|
|||||||
auto root = Scope::NewRootScope().ExitOnError();
|
auto root = Scope::NewRootScope().ExitOnError();
|
||||||
const int64 num_values = 1500 * 250;
|
const int64 num_values = 1500 * 250;
|
||||||
std::vector<T> inputs;
|
std::vector<T> inputs;
|
||||||
|
|
||||||
inputs.reserve(num_values);
|
inputs.reserve(num_values);
|
||||||
for (int i = 0; i < num_values; ++i) inputs.push_back(i);
|
for (int i = 0; i < num_values; ++i) inputs.push_back(i);
|
||||||
|
|
||||||
ops::Dequantize(root, test::AsTensor<T>(inputs), test::AsScalar<float>(-1.5f),
|
ops::Dequantize(root, test::AsTensor<T>(inputs), test::AsScalar<float>(-1.5f),
|
||||||
test::AsScalar<float>(20.5f),
|
test::AsScalar<float>(20.5f),
|
||||||
ops::Dequantize::Attrs().Mode("MIN_COMBINED"));
|
ops::Dequantize::Attrs().Mode("MIN_COMBINED"));
|
||||||
@ -288,47 +237,5 @@ BENCHMARK(BM_DequantizeMinCombinedCpuQint16);
|
|||||||
BENCHMARK(BM_DequantizeMinCombinedCpuQuint8);
|
BENCHMARK(BM_DequantizeMinCombinedCpuQuint8);
|
||||||
BENCHMARK(BM_DequantizeMinCombinedCpuQint8);
|
BENCHMARK(BM_DequantizeMinCombinedCpuQint8);
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
static void BM_DequantizeBfloat16MinCombinedCpu(int iters) {
|
|
||||||
auto root = Scope::NewRootScope().ExitOnError();
|
|
||||||
const int64 num_values = 1500 * 250;
|
|
||||||
std::vector<T> inputs;
|
|
||||||
|
|
||||||
inputs.reserve(num_values);
|
|
||||||
for (int i = 0; i < num_values; ++i) inputs.push_back(i);
|
|
||||||
|
|
||||||
ops::Dequantize(root, test::AsTensor<T>(inputs), test::AsScalar<float>(-1.5f),
|
|
||||||
test::AsScalar<float>(20.5f),
|
|
||||||
ops::Dequantize::Attrs().Dtype(DT_BFLOAT16));
|
|
||||||
TF_CHECK_OK(root.status());
|
|
||||||
Graph* g = new Graph(OpRegistry::Global());
|
|
||||||
TF_CHECK_OK(root.ToGraph(g));
|
|
||||||
|
|
||||||
test::Benchmark("cpu", g).Run(iters);
|
|
||||||
testing::BytesProcessed(iters * num_values * (sizeof(bfloat16) + sizeof(T)));
|
|
||||||
testing::ItemsProcessed(iters);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void BM_DequantizeBfloat16MinCombinedCpuQuint16(int iters) {
|
|
||||||
BM_DequantizeBfloat16MinCombinedCpu<quint16>(iters);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void BM_DequantizeBfloat16MinCombinedCpuQint16(int iters) {
|
|
||||||
BM_DequantizeBfloat16MinCombinedCpu<qint16>(iters);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void BM_DequantizeBfloat16MinCombinedCpuQuint8(int iters) {
|
|
||||||
BM_DequantizeBfloat16MinCombinedCpu<quint8>(iters);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void BM_DequantizeBfloat16MinCombinedCpuQint8(int iters) {
|
|
||||||
BM_DequantizeBfloat16MinCombinedCpu<qint8>(iters);
|
|
||||||
}
|
|
||||||
|
|
||||||
BENCHMARK(BM_DequantizeBfloat16MinCombinedCpuQuint16);
|
|
||||||
BENCHMARK(BM_DequantizeBfloat16MinCombinedCpuQint16);
|
|
||||||
BENCHMARK(BM_DequantizeBfloat16MinCombinedCpuQuint8);
|
|
||||||
BENCHMARK(BM_DequantizeBfloat16MinCombinedCpuQint8);
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|||||||
@ -2871,12 +2871,11 @@ REGISTER_OP("Dequantize")
|
|||||||
.Input("input: T")
|
.Input("input: T")
|
||||||
.Input("min_range: float")
|
.Input("min_range: float")
|
||||||
.Input("max_range: float")
|
.Input("max_range: float")
|
||||||
.Output("output: dtype")
|
.Output("output: float")
|
||||||
.Attr("T: quantizedtype")
|
.Attr("T: quantizedtype")
|
||||||
.Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'MIN_COMBINED'")
|
.Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'MIN_COMBINED'")
|
||||||
.Attr("narrow_range: bool = false")
|
.Attr("narrow_range: bool = false")
|
||||||
.Attr("axis: int = -1")
|
.Attr("axis: int = -1")
|
||||||
.Attr("dtype: {bfloat16, float} = DT_FLOAT")
|
|
||||||
.SetShapeFn([](InferenceContext* c) {
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
int axis = -1;
|
int axis = -1;
|
||||||
Status s = c->GetAttr("axis", &axis);
|
Status s = c->GetAttr("axis", &axis);
|
||||||
|
|||||||
@ -248,76 +248,3 @@ op {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
op {
|
|
||||||
name: "Dequantize"
|
|
||||||
input_arg {
|
|
||||||
name: "input"
|
|
||||||
type_attr: "T"
|
|
||||||
}
|
|
||||||
input_arg {
|
|
||||||
name: "min_range"
|
|
||||||
type: DT_FLOAT
|
|
||||||
}
|
|
||||||
input_arg {
|
|
||||||
name: "max_range"
|
|
||||||
type: DT_FLOAT
|
|
||||||
}
|
|
||||||
output_arg {
|
|
||||||
name: "output"
|
|
||||||
type_attr: "dtype"
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "T"
|
|
||||||
type: "type"
|
|
||||||
allowed_values {
|
|
||||||
list {
|
|
||||||
type: DT_QINT8
|
|
||||||
type: DT_QUINT8
|
|
||||||
type: DT_QINT32
|
|
||||||
type: DT_QINT16
|
|
||||||
type: DT_QUINT16
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "mode"
|
|
||||||
type: "string"
|
|
||||||
default_value {
|
|
||||||
s: "MIN_COMBINED"
|
|
||||||
}
|
|
||||||
allowed_values {
|
|
||||||
list {
|
|
||||||
s: "MIN_COMBINED"
|
|
||||||
s: "MIN_FIRST"
|
|
||||||
s: "SCALED"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "narrow_range"
|
|
||||||
type: "bool"
|
|
||||||
default_value {
|
|
||||||
b: false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "axis"
|
|
||||||
type: "int"
|
|
||||||
default_value {
|
|
||||||
i: -1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
attr {
|
|
||||||
name: "dtype"
|
|
||||||
type: "type"
|
|
||||||
default_value {
|
|
||||||
type: DT_FLOAT
|
|
||||||
}
|
|
||||||
allowed_values {
|
|
||||||
list {
|
|
||||||
type: DT_BFLOAT16
|
|
||||||
type: DT_FLOAT
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@ -4982,8 +4982,7 @@ def dequantize( # pylint: disable=missing-docstring
|
|||||||
mode="MIN_COMBINED",
|
mode="MIN_COMBINED",
|
||||||
name=None,
|
name=None,
|
||||||
axis=None,
|
axis=None,
|
||||||
narrow_range=False,
|
narrow_range=False):
|
||||||
dtype=dtypes.float32):
|
|
||||||
if axis is None:
|
if axis is None:
|
||||||
axis = -1
|
axis = -1
|
||||||
elif axis < 0:
|
elif axis < 0:
|
||||||
@ -4993,17 +4992,10 @@ def dequantize( # pylint: disable=missing-docstring
|
|||||||
|
|
||||||
if axis >= 0 or narrow_range:
|
if axis >= 0 or narrow_range:
|
||||||
return gen_array_ops.dequantize(
|
return gen_array_ops.dequantize(
|
||||||
input,
|
input, min_range, max_range, mode=mode, name=name,
|
||||||
min_range,
|
narrow_range=narrow_range, axis=axis)
|
||||||
max_range,
|
|
||||||
mode=mode,
|
|
||||||
name=name,
|
|
||||||
narrow_range=narrow_range,
|
|
||||||
axis=axis,
|
|
||||||
dtype=dtype)
|
|
||||||
return gen_array_ops.dequantize(
|
return gen_array_ops.dequantize(
|
||||||
input, min_range, max_range, mode=mode, name=name, dtype=dtype)
|
input, min_range, max_range, mode=mode, name=name)
|
||||||
|
|
||||||
|
|
||||||
dequantize.__doc__ = gen_array_ops.dequantize.__doc__
|
dequantize.__doc__ = gen_array_ops.dequantize.__doc__
|
||||||
|
|
||||||
|
|||||||
@ -1110,7 +1110,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "dequantize"
|
name: "dequantize"
|
||||||
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\', \'narrow_range\', \'dtype\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\', \'False\', \"<dtype: \'float32\'>\"], "
|
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\', \'narrow_range\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\', \'False\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "deserialize_many_sparse"
|
name: "deserialize_many_sparse"
|
||||||
|
|||||||
@ -2,7 +2,7 @@ path: "tensorflow.quantization"
|
|||||||
tf_module {
|
tf_module {
|
||||||
member_method {
|
member_method {
|
||||||
name: "dequantize"
|
name: "dequantize"
|
||||||
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\', \'narrow_range\', \'dtype\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\', \'False\', \"<dtype: \'float32\'>\"], "
|
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\', \'narrow_range\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\', \'False\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "fake_quant_with_min_max_args"
|
name: "fake_quant_with_min_max_args"
|
||||||
|
|||||||
@ -1082,7 +1082,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "Dequantize"
|
name: "Dequantize"
|
||||||
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'narrow_range\', \'axis\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'False\', \'-1\', \"<dtype: \'float32\'>\", \'None\'], "
|
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'narrow_range\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'False\', \'-1\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "DeserializeIterator"
|
name: "DeserializeIterator"
|
||||||
|
|||||||
@ -2,7 +2,7 @@ path: "tensorflow.quantization"
|
|||||||
tf_module {
|
tf_module {
|
||||||
member_method {
|
member_method {
|
||||||
name: "dequantize"
|
name: "dequantize"
|
||||||
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\', \'narrow_range\', \'dtype\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\', \'False\', \"<dtype: \'float32\'>\"], "
|
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\', \'narrow_range\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\', \'False\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "fake_quant_with_min_max_args"
|
name: "fake_quant_with_min_max_args"
|
||||||
|
|||||||
@ -1082,7 +1082,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "Dequantize"
|
name: "Dequantize"
|
||||||
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'narrow_range\', \'axis\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'False\', \'-1\', \"<dtype: \'float32\'>\", \'None\'], "
|
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'narrow_range\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'False\', \'-1\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "DeserializeIterator"
|
name: "DeserializeIterator"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user