Support Dequantize to bfloat16.
Introduce DequantizeV2, which allows the user to specify the output dtype (float or bfloat16). PiperOrigin-RevId: 289699810 Change-Id: Idb12a52b6b9c18d015278b5c9aa4fd347a109b60
This commit is contained in:
parent
29e4767dbe
commit
cff8012de1
tensorflow
compiler/tf2xla/kernels
core
api_def/base_api
kernels
ops
python/ops
tools/api/golden
@ -55,6 +55,7 @@ class DequantizeOp : public XlaOpKernel {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis));
|
||||
// This kernel accepts only axis == -1; any other value is rejected here with
// an InvalidArgument error that reports the offending value.
OP_REQUIRES(ctx, axis == -1,
|
||||
errors::InvalidArgument("axis must be -1, is ", axis));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
|
||||
}
|
||||
|
||||
~DequantizeOp() override = default;
|
||||
@ -86,7 +87,6 @@ class DequantizeOp : public XlaOpKernel {
|
||||
xla::XlaOp input = ctx->Input(0);
|
||||
xla::XlaOp output;
|
||||
|
||||
// TODO(ylc): Support bfloat16.
|
||||
output = xla::ConvertElementType(input, xla::F32);
|
||||
|
||||
auto scale = ScalarLike(output, scale_factor);
|
||||
@ -94,8 +94,14 @@ class DequantizeOp : public XlaOpKernel {
|
||||
output = xla::Add(xla::Mul(xla::Add(output, halfrange), scale),
|
||||
ScalarLike(output, min_range));
|
||||
|
||||
// Narrow the dequantized result to bfloat16 when requested. The conversion
// must be applied to `output` (the value computed above), not to the raw
// quantized `input`, otherwise the dequantization is discarded.
if (dtype_ == DT_BFLOAT16) {
|
||||
output = xla::ConvertElementType(output, xla::BF16);
|
||||
}
|
||||
ctx->SetOutput(0, output);
|
||||
}
|
||||
|
||||
private:
|
||||
DataType dtype_;
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("Dequantize").TypeConstraint("T", kQuantizedType),
|
||||
|
@ -12,7 +12,14 @@ END
|
||||
The maximum scalar value possibly produced for the input.
|
||||
END
|
||||
}
|
||||
summary: "Dequantize the \'input\' tensor into a float Tensor."
|
||||
attr {
|
||||
name: "dtype"
|
||||
description: <<END
|
||||
Type of the output tensor. Currently Dequantize supports float and bfloat16.
|
||||
If 'dtype' is 'bfloat16', it only supports 'MIN_COMBINED' mode.
|
||||
END
|
||||
}
|
||||
summary: "Dequantize the \'input\' tensor into a float or bfloat16 Tensor."
|
||||
description: <<END
|
||||
[min_range, max_range] are scalar floats that specify the range for
|
||||
the output. The 'mode' attribute controls exactly which calculations are
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/meta_support.h"
|
||||
#include "tensorflow/core/kernels/quantization_utils.h"
|
||||
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace {
|
||||
@ -37,18 +38,44 @@ namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
|
||||
template <typename Device, typename T>
|
||||
template <typename T>
|
||||
T Cast(float v) {
|
||||
return v;
|
||||
}
|
||||
|
||||
template <>
|
||||
bfloat16 Cast<bfloat16>(float v) {
|
||||
return bfloat16(v);
|
||||
}
|
||||
|
||||
template <typename Device, typename T, typename S>
|
||||
class DequantizeOp : public OpKernel {
|
||||
public:
|
||||
explicit DequantizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
string mode_string;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_string));
|
||||
OP_REQUIRES(ctx,
|
||||
(mode_string == "MIN_COMBINED" || mode_string == "MIN_FIRST" ||
|
||||
mode_string == "SCALED"),
|
||||
errors::InvalidArgument("Mode string must be 'MIN_COMBINED',"
|
||||
" 'MIN_FIRST', or 'SCALED', is '" +
|
||||
mode_string + "'"));
|
||||
OP_REQUIRES(
|
||||
ctx,
|
||||
(ctx->output_type(0) == DT_FLOAT || ctx->output_type(0) == DT_BFLOAT16),
|
||||
errors::InvalidArgument("Output type must be bfloat16 or float,"
|
||||
" is '" +
|
||||
DataTypeString(ctx->output_type(0)) + "'"));
|
||||
|
||||
if (ctx->output_type(0) == DT_FLOAT) {
|
||||
OP_REQUIRES(ctx,
|
||||
(mode_string == "MIN_COMBINED" ||
|
||||
mode_string == "MIN_FIRST" || mode_string == "SCALED"),
|
||||
errors::InvalidArgument("Mode string must be 'MIN_COMBINED',"
|
||||
" 'MIN_FIRST', or 'SCALED', is '" +
|
||||
mode_string + "'"));
|
||||
} else {
|
||||
OP_REQUIRES(
|
||||
ctx, (mode_string == "MIN_COMBINED"),
|
||||
errors::InvalidArgument("When output type is bfloat16, Mode"
|
||||
" string must be 'MIN_COMBINED', is '" +
|
||||
mode_string + "'"));
|
||||
}
|
||||
|
||||
if (mode_string == "MIN_COMBINED") {
|
||||
mode_ = QUANTIZE_MODE_MIN_COMBINED;
|
||||
} else if (mode_string == "MIN_FIRST") {
|
||||
@ -71,34 +98,40 @@ class DequantizeOp : public OpKernel {
|
||||
}
|
||||
|
||||
Tensor* output = nullptr;
|
||||
Tensor float_output = tensorflow::Tensor(DT_FLOAT, input.shape());
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
|
||||
if (num_slices == 1) {
|
||||
const float min_range = input_min_tensor.flat<float>()(0);
|
||||
const float max_range = input_max_tensor.flat<float>()(0);
|
||||
DequantizeTensor(ctx, input, min_range, max_range, output);
|
||||
return;
|
||||
}
|
||||
DequantizeTensor(ctx, input, min_range, max_range, &float_output);
|
||||
} else {
|
||||
OP_REQUIRES(ctx, mode_ != QUANTIZE_MODE_MIN_FIRST,
|
||||
errors::Unimplemented("MIN_FIRST mode is not implemented for "
|
||||
"Dequantize with axis != -1."));
|
||||
|
||||
OP_REQUIRES(ctx, mode_ != QUANTIZE_MODE_MIN_FIRST,
|
||||
errors::Unimplemented("MIN_FIRST mode is not implemented for "
|
||||
"Dequantize with axis != -1."));
|
||||
|
||||
int64 pre_dim = 1, post_dim = 1;
|
||||
for (int i = 0; i < axis_; ++i) {
|
||||
pre_dim *= output->dim_size(i);
|
||||
int64 pre_dim = 1, post_dim = 1;
|
||||
for (int i = 0; i < axis_; ++i) {
|
||||
pre_dim *= float_output.dim_size(i);
|
||||
}
|
||||
for (int i = axis_ + 1; i < float_output.dims(); ++i) {
|
||||
post_dim *= float_output.dim_size(i);
|
||||
}
|
||||
auto input_tensor = input.template bit_casted_shaped<T, 3>(
|
||||
{pre_dim, num_slices, post_dim});
|
||||
auto output_tensor =
|
||||
float_output.flat_inner_outer_dims<float, 3>(axis_ - 1);
|
||||
auto min_ranges = input_min_tensor.vec<float>();
|
||||
auto max_ranges = input_max_tensor.vec<float>();
|
||||
for (int i = 0; i < num_slices; ++i) {
|
||||
DequantizeSlice(ctx->eigen_device<Device>(), ctx,
|
||||
input_tensor.template chip<1>(i), min_ranges(i),
|
||||
max_ranges(i), output_tensor.template chip<1>(i));
|
||||
}
|
||||
}
|
||||
for (int i = axis_ + 1; i < output->dims(); ++i) {
|
||||
post_dim *= output->dim_size(i);
|
||||
}
|
||||
auto input_tensor =
|
||||
input.template bit_casted_shaped<T, 3>({pre_dim, num_slices, post_dim});
|
||||
auto output_tensor = output->flat_inner_outer_dims<float, 3>(axis_ - 1);
|
||||
auto min_ranges = input_min_tensor.vec<float>();
|
||||
auto max_ranges = input_max_tensor.vec<float>();
|
||||
for (int i = 0; i < num_slices; ++i) {
|
||||
DequantizeSlice(ctx->eigen_device<Device>(), ctx,
|
||||
input_tensor.template chip<1>(i), min_ranges(i),
|
||||
max_ranges(i), output_tensor.template chip<1>(i));
|
||||
S* out_ptr = output->flat<S>().data();
|
||||
float* in_ptr = float_output.flat<float>().data();
|
||||
for (int64 i = 0; i < float_output.NumElements(); ++i) {
|
||||
out_ptr[i] = static_cast<S>(in_ptr[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -188,21 +221,55 @@ class DequantizeOp : public OpKernel {
|
||||
bool narrow_range_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("Dequantize").Device(DEVICE_CPU).TypeConstraint<quint8>("T"),
|
||||
DequantizeOp<CPUDevice, quint8>);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("Dequantize").Device(DEVICE_CPU).TypeConstraint<qint8>("T"),
|
||||
DequantizeOp<CPUDevice, qint8>);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("Dequantize").Device(DEVICE_CPU).TypeConstraint<quint16>("T"),
|
||||
DequantizeOp<CPUDevice, quint16>);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("Dequantize").Device(DEVICE_CPU).TypeConstraint<qint16>("T"),
|
||||
DequantizeOp<CPUDevice, qint16>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("Dequantize").Device(DEVICE_CPU).TypeConstraint<qint32>("T"),
|
||||
DequantizeOp<CPUDevice, qint32>);
|
||||
REGISTER_KERNEL_BUILDER(Name("Dequantize")
|
||||
.Device(DEVICE_CPU)
|
||||
.TypeConstraint<quint8>("T")
|
||||
.TypeConstraint<float>("dtype"),
|
||||
DequantizeOp<CPUDevice, quint8, float>);
|
||||
REGISTER_KERNEL_BUILDER(Name("Dequantize")
|
||||
.Device(DEVICE_CPU)
|
||||
.TypeConstraint<qint8>("T")
|
||||
.TypeConstraint<float>("dtype"),
|
||||
DequantizeOp<CPUDevice, qint8, float>);
|
||||
REGISTER_KERNEL_BUILDER(Name("Dequantize")
|
||||
.Device(DEVICE_CPU)
|
||||
.TypeConstraint<quint16>("T")
|
||||
.TypeConstraint<float>("dtype"),
|
||||
DequantizeOp<CPUDevice, quint16, float>);
|
||||
REGISTER_KERNEL_BUILDER(Name("Dequantize")
|
||||
.Device(DEVICE_CPU)
|
||||
.TypeConstraint<qint16>("T")
|
||||
.TypeConstraint<float>("dtype"),
|
||||
DequantizeOp<CPUDevice, qint16, float>);
|
||||
REGISTER_KERNEL_BUILDER(Name("Dequantize")
|
||||
.Device(DEVICE_CPU)
|
||||
.TypeConstraint<qint32>("T")
|
||||
.TypeConstraint<float>("dtype"),
|
||||
DequantizeOp<CPUDevice, qint32, float>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("Dequantize")
|
||||
.Device(DEVICE_CPU)
|
||||
.TypeConstraint<quint8>("T")
|
||||
.TypeConstraint<bfloat16>("dtype"),
|
||||
DequantizeOp<CPUDevice, quint8, bfloat16>);
|
||||
REGISTER_KERNEL_BUILDER(Name("Dequantize")
|
||||
.Device(DEVICE_CPU)
|
||||
.TypeConstraint<qint8>("T")
|
||||
.TypeConstraint<bfloat16>("dtype"),
|
||||
DequantizeOp<CPUDevice, qint8, bfloat16>);
|
||||
REGISTER_KERNEL_BUILDER(Name("Dequantize")
|
||||
.Device(DEVICE_CPU)
|
||||
.TypeConstraint<quint16>("T")
|
||||
.TypeConstraint<bfloat16>("dtype"),
|
||||
DequantizeOp<CPUDevice, quint16, bfloat16>);
|
||||
REGISTER_KERNEL_BUILDER(Name("Dequantize")
|
||||
.Device(DEVICE_CPU)
|
||||
.TypeConstraint<qint16>("T")
|
||||
.TypeConstraint<bfloat16>("dtype"),
|
||||
DequantizeOp<CPUDevice, qint16, bfloat16>);
|
||||
REGISTER_KERNEL_BUILDER(Name("Dequantize")
|
||||
.Device(DEVICE_CPU)
|
||||
.TypeConstraint<qint32>("T")
|
||||
.TypeConstraint<bfloat16>("dtype"),
|
||||
DequantizeOp<CPUDevice, qint32, bfloat16>);
|
||||
} // namespace tensorflow
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
@ -61,8 +62,9 @@ class DequantizeOpTest : public OpsTestBase {
|
||||
// Compares dequantize min vs the same using eigen. This tests that a change
|
||||
// to not use eigen gives equivalent results to using eigen.
|
||||
template <typename T>
|
||||
void RunDequantizeMinCombinedTest(float min_range, float max_range) {
|
||||
TF_ASSERT_OK(NodeDefBuilder("dequantize_op", "Dequantize")
|
||||
void RunDequantizeMinCombinedTest(float min_range, float max_range,
|
||||
const string& op_name) {
|
||||
TF_ASSERT_OK(NodeDefBuilder("dequantize_op", op_name)
|
||||
.Input(FakeInput(DataTypeToEnum<T>::v()))
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
@ -87,6 +89,40 @@ class DequantizeOpTest : public OpsTestBase {
|
||||
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
|
||||
}
|
||||
|
||||
// Runs Dequantize in MIN_COMBINED mode with dtype = bfloat16 and compares the
|
||||
// result against the Eigen float reference cast to bfloat16.
|
||||
template <typename T>
|
||||
void RunDequantizeBfloat16MinCombinedTest(float min_range, float max_range) {
|
||||
TF_ASSERT_OK(NodeDefBuilder("dequantize_op_bfloat16", "Dequantize")
|
||||
.Input(FakeInput(DataTypeToEnum<T>::v()))
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Attr("T", DataTypeToEnum<T>::v())
|
||||
.Attr("mode", "MIN_COMBINED")
|
||||
.Attr("dtype", DT_BFLOAT16)
|
||||
.Finalize(node_def()));
|
||||
TF_ASSERT_OK(InitOp());
|
||||
|
||||
std::vector<T> input;
|
||||
for (int64 i = std::numeric_limits<T>::min();
|
||||
i < std::numeric_limits<T>::max(); ++i) {
|
||||
input.push_back(static_cast<T>(i));
|
||||
}
|
||||
TensorShape shape({static_cast<int64>(input.size())});
|
||||
AddInputFromArray<T>(shape, input);
|
||||
AddInputFromArray<float>(TensorShape({}), {min_range});
|
||||
AddInputFromArray<float>(TensorShape({}), {max_range});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
Tensor expected_float32(allocator(), DT_FLOAT, shape);
|
||||
ComputeDequantizeMinCombinedUsingEigen<T>(GetInput(0), min_range, max_range,
|
||||
&expected_float32);
|
||||
Tensor expected(allocator(), DT_BFLOAT16, shape);
|
||||
expected.flat<bfloat16>() = expected_float32.flat<float>().cast<bfloat16>();
|
||||
|
||||
test::ExpectTensorEqual<bfloat16>(expected, *GetOutput(0));
|
||||
}
|
||||
|
||||
// Creates a tensor with the specified dims, using values chosen from data,
|
||||
// multiplied by (1 + index) along the axis dimension.
|
||||
template <typename T>
|
||||
@ -151,16 +187,29 @@ struct ParameterizedDequantizeOpTest
|
||||
public ::testing::WithParamInterface<int> {};
|
||||
|
||||
TEST_F(DequantizeOpTest, DequantizeMinCombinedQuint8) {
|
||||
RunDequantizeMinCombinedTest<quint8>(0, 255.0f);
|
||||
RunDequantizeMinCombinedTest<quint8>(0, 255.0f, "Dequantize");
|
||||
}
|
||||
TEST_F(DequantizeOpTest, DequantizeMinCombinedQint8) {
|
||||
RunDequantizeMinCombinedTest<qint8>(0, 255.0f);
|
||||
RunDequantizeMinCombinedTest<qint8>(0, 255.0f, "Dequantize");
|
||||
}
|
||||
TEST_F(DequantizeOpTest, DequantizeMinCombinedQint16) {
|
||||
RunDequantizeMinCombinedTest<qint16>(0, 255.0f);
|
||||
RunDequantizeMinCombinedTest<qint16>(0, 255.0f, "Dequantize");
|
||||
}
|
||||
TEST_F(DequantizeOpTest, DequantizeMinCombinedQuint16) {
|
||||
RunDequantizeMinCombinedTest<quint16>(0, 255.0f);
|
||||
RunDequantizeMinCombinedTest<quint16>(0, 255.0f, "Dequantize");
|
||||
}
|
||||
|
||||
TEST_F(DequantizeOpTest, DequantizeBfloat16MinCombinedQuint8) {
|
||||
RunDequantizeBfloat16MinCombinedTest<quint8>(0, 255.0f);
|
||||
}
|
||||
TEST_F(DequantizeOpTest, DequantizeBfloat16MinCombinedQint8) {
|
||||
RunDequantizeBfloat16MinCombinedTest<qint8>(0, 255.0f);
|
||||
}
|
||||
TEST_F(DequantizeOpTest, DequantizeBfloat16MinCombinedQint16) {
|
||||
RunDequantizeBfloat16MinCombinedTest<qint16>(0, 255.0f);
|
||||
}
|
||||
TEST_F(DequantizeOpTest, DequantizeBfloat16MinCombinedQuint16) {
|
||||
RunDequantizeBfloat16MinCombinedTest<quint16>(0, 255.0f);
|
||||
}
|
||||
|
||||
TEST_F(DequantizeOpTest, DequantizeScaledQuint8Zero) {
|
||||
@ -202,8 +251,10 @@ static void BM_DequantizeMinCombinedCpu(int iters) {
|
||||
auto root = Scope::NewRootScope().ExitOnError();
|
||||
const int64 num_values = 1500 * 250;
|
||||
std::vector<T> inputs;
|
||||
|
||||
inputs.reserve(num_values);
|
||||
for (int i = 0; i < num_values; ++i) inputs.push_back(i);
|
||||
|
||||
ops::Dequantize(root, test::AsTensor<T>(inputs), test::AsScalar<float>(-1.5f),
|
||||
test::AsScalar<float>(20.5f),
|
||||
ops::Dequantize::Attrs().Mode("MIN_COMBINED"));
|
||||
@ -237,5 +288,47 @@ BENCHMARK(BM_DequantizeMinCombinedCpuQint16);
|
||||
BENCHMARK(BM_DequantizeMinCombinedCpuQuint8);
|
||||
BENCHMARK(BM_DequantizeMinCombinedCpuQint8);
|
||||
|
||||
template <typename T>
|
||||
static void BM_DequantizeBfloat16MinCombinedCpu(int iters) {
|
||||
auto root = Scope::NewRootScope().ExitOnError();
|
||||
const int64 num_values = 1500 * 250;
|
||||
std::vector<T> inputs;
|
||||
|
||||
inputs.reserve(num_values);
|
||||
for (int i = 0; i < num_values; ++i) inputs.push_back(i);
|
||||
|
||||
ops::Dequantize(root, test::AsTensor<T>(inputs), test::AsScalar<float>(-1.5f),
|
||||
test::AsScalar<float>(20.5f),
|
||||
ops::Dequantize::Attrs().Dtype(DT_BFLOAT16));
|
||||
TF_CHECK_OK(root.status());
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
TF_CHECK_OK(root.ToGraph(g));
|
||||
|
||||
test::Benchmark("cpu", g).Run(iters);
|
||||
testing::BytesProcessed(iters * num_values * (sizeof(bfloat16) + sizeof(T)));
|
||||
testing::ItemsProcessed(iters);
|
||||
}
|
||||
|
||||
static void BM_DequantizeBfloat16MinCombinedCpuQuint16(int iters) {
|
||||
BM_DequantizeBfloat16MinCombinedCpu<quint16>(iters);
|
||||
}
|
||||
|
||||
static void BM_DequantizeBfloat16MinCombinedCpuQint16(int iters) {
|
||||
BM_DequantizeBfloat16MinCombinedCpu<qint16>(iters);
|
||||
}
|
||||
|
||||
static void BM_DequantizeBfloat16MinCombinedCpuQuint8(int iters) {
|
||||
BM_DequantizeBfloat16MinCombinedCpu<quint8>(iters);
|
||||
}
|
||||
|
||||
static void BM_DequantizeBfloat16MinCombinedCpuQint8(int iters) {
|
||||
BM_DequantizeBfloat16MinCombinedCpu<qint8>(iters);
|
||||
}
|
||||
|
||||
BENCHMARK(BM_DequantizeBfloat16MinCombinedCpuQuint16);
|
||||
BENCHMARK(BM_DequantizeBfloat16MinCombinedCpuQint16);
|
||||
BENCHMARK(BM_DequantizeBfloat16MinCombinedCpuQuint8);
|
||||
BENCHMARK(BM_DequantizeBfloat16MinCombinedCpuQint8);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -2871,11 +2871,12 @@ REGISTER_OP("Dequantize")
|
||||
.Input("input: T")
|
||||
.Input("min_range: float")
|
||||
.Input("max_range: float")
|
||||
.Output("output: float")
|
||||
.Output("output: dtype")
|
||||
.Attr("T: quantizedtype")
|
||||
.Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'MIN_COMBINED'")
|
||||
.Attr("narrow_range: bool = false")
|
||||
.Attr("axis: int = -1")
|
||||
.Attr("dtype: {bfloat16, float} = DT_FLOAT")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
int axis = -1;
|
||||
Status s = c->GetAttr("axis", &axis);
|
||||
|
@ -248,3 +248,76 @@ op {
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "Dequantize"
|
||||
input_arg {
|
||||
name: "input"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "min_range"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
input_arg {
|
||||
name: "max_range"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
type_attr: "dtype"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_QINT8
|
||||
type: DT_QUINT8
|
||||
type: DT_QINT32
|
||||
type: DT_QINT16
|
||||
type: DT_QUINT16
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "mode"
|
||||
type: "string"
|
||||
default_value {
|
||||
s: "MIN_COMBINED"
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
s: "MIN_COMBINED"
|
||||
s: "MIN_FIRST"
|
||||
s: "SCALED"
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "narrow_range"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "axis"
|
||||
type: "int"
|
||||
default_value {
|
||||
i: -1
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "dtype"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_BFLOAT16
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -4982,7 +4982,8 @@ def dequantize( # pylint: disable=missing-docstring
|
||||
mode="MIN_COMBINED",
|
||||
name=None,
|
||||
axis=None,
|
||||
narrow_range=False):
|
||||
narrow_range=False,
|
||||
dtype=dtypes.float32):
|
||||
if axis is None:
|
||||
axis = -1
|
||||
elif axis < 0:
|
||||
@ -4992,10 +4993,17 @@ def dequantize( # pylint: disable=missing-docstring
|
||||
|
||||
if axis >= 0 or narrow_range:
|
||||
return gen_array_ops.dequantize(
|
||||
input, min_range, max_range, mode=mode, name=name,
|
||||
narrow_range=narrow_range, axis=axis)
|
||||
input,
|
||||
min_range,
|
||||
max_range,
|
||||
mode=mode,
|
||||
name=name,
|
||||
narrow_range=narrow_range,
|
||||
axis=axis,
|
||||
dtype=dtype)
|
||||
return gen_array_ops.dequantize(
|
||||
input, min_range, max_range, mode=mode, name=name)
|
||||
input, min_range, max_range, mode=mode, name=name, dtype=dtype)
|
||||
|
||||
|
||||
dequantize.__doc__ = gen_array_ops.dequantize.__doc__
|
||||
|
||||
|
@ -1110,7 +1110,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "dequantize"
|
||||
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\', \'narrow_range\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\', \'False\'], "
|
||||
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\', \'narrow_range\', \'dtype\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\', \'False\', \"<dtype: \'float32\'>\"], "
|
||||
}
|
||||
member_method {
|
||||
name: "deserialize_many_sparse"
|
||||
|
@ -2,7 +2,7 @@ path: "tensorflow.quantization"
|
||||
tf_module {
|
||||
member_method {
|
||||
name: "dequantize"
|
||||
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\', \'narrow_range\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\', \'False\'], "
|
||||
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\', \'narrow_range\', \'dtype\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\', \'False\', \"<dtype: \'float32\'>\"], "
|
||||
}
|
||||
member_method {
|
||||
name: "fake_quant_with_min_max_args"
|
||||
|
@ -1082,7 +1082,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "Dequantize"
|
||||
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'narrow_range\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'False\', \'-1\', \'None\'], "
|
||||
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'narrow_range\', \'axis\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'False\', \'-1\', \"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DeserializeIterator"
|
||||
|
@ -2,7 +2,7 @@ path: "tensorflow.quantization"
|
||||
tf_module {
|
||||
member_method {
|
||||
name: "dequantize"
|
||||
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\', \'narrow_range\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\', \'False\'], "
|
||||
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\', \'narrow_range\', \'dtype\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\', \'False\', \"<dtype: \'float32\'>\"], "
|
||||
}
|
||||
member_method {
|
||||
name: "fake_quant_with_min_max_args"
|
||||
|
@ -1082,7 +1082,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "Dequantize"
|
||||
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'narrow_range\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'False\', \'-1\', \'None\'], "
|
||||
argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'narrow_range\', \'axis\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'False\', \'-1\', \"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DeserializeIterator"
|
||||
|
Loading…
Reference in New Issue
Block a user