Added a fused instance normalization op for quantized input.
- Eigen- and NEON-intrinsics-based implementations.
- Added a test suitable for both regression runs and standalone runs on Android.
- Made it possible to compile the C++ API on the Android platform, so it is easy to write tests.

Change: 141326710
commit a79a7a2135
parent c8cab5483f
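As context, here is a minimal sketch (not part of the commit) of how the new op can be driven through the now Android-compilable C++ ClientSession API. It mirrors the pattern of the test added below; the function name, tensor shape, input range, and epsilon are purely illustrative.

// Sketch only: exercises QuantizedInstanceNorm end to end via the C++ API.
#include <vector>

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

void RunQuantizedInstanceNormOnce() {  // hypothetical helper, not in the commit
  Scope root = Scope::NewRootScope();

  // quint8 input placeholder; x_min/x_max describe the real range of the codes.
  auto x = ops::Placeholder(root, DT_QUINT8);
  auto norm = ops::QuantizedInstanceNorm(
      root, x, /*x_min=*/0.0f, /*x_max=*/1.0f,
      ops::QuantizedInstanceNorm::Attrs().VarianceEpsilon(1e-5f));

  // A small NHWC tensor filled with a constant quantized value.
  Tensor input(DT_QUINT8, {1, 4, 4, 32});
  auto flat = input.flat<quint8>();
  flat = flat.setConstant(128);

  ClientSession session(root);
  std::vector<Tensor> outputs;
  TF_CHECK_OK(session.Run({{x, input}},
                          {norm.y, norm.y_min, norm.y_max}, &outputs));
  LOG(INFO) << "y_min=" << outputs[1].flat<float>()(0)
            << " y_max=" << outputs[2].flat<float>()(0);
}

}  // namespace tensorflow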
@@ -10,9 +10,13 @@ licenses(["notice"]) # Apache 2.0
 exports_files(["LICENSE"])
 
-load("//tensorflow:tensorflow.bzl", "tf_cc_test")
-load("//tensorflow:tensorflow.bzl", "tf_copts")
-load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrappers_cc")
+load(
+    "//tensorflow:tensorflow.bzl",
+    "tf_cc_test",
+    "tf_copts",
+    "tf_gen_op_wrappers_cc",
+    "cc_library_with_android_deps",
+)
 
 cc_library(
     name = "gradients",
@@ -104,10 +108,11 @@ cc_library(
     ],
 )
 
-cc_library(
+cc_library_with_android_deps(
     name = "ops",
     srcs = ["framework/ops.cc"],
     hdrs = ["framework/ops.h"],
+    android_deps = ["//tensorflow/core:android_tensorflow_lib"],
     deps = [
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
@@ -117,12 +122,15 @@ cc_library(
     ],
 )
 
-cc_library(
+cc_library_with_android_deps(
     name = "scope",
     srcs = ["framework/scope.cc"],
     hdrs = ["framework/scope.h"],
-    deps = [
+    android_deps = ["//tensorflow/core:android_tensorflow_lib"],
+    common_deps = [
         ":ops",
+    ],
+    deps = [
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
@@ -143,13 +151,16 @@ tf_cc_test(
     ],
 )
 
-cc_library(
+cc_library_with_android_deps(
     name = "client_session",
     srcs = ["client/client_session.cc"],
     hdrs = ["client/client_session.h"],
-    deps = [
+    android_deps = ["//tensorflow/core:android_tensorflow_lib"],
+    common_deps = [
         ":ops",
         ":scope",
+    ],
+    deps = [
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
@@ -173,13 +184,18 @@ tf_cc_test(
     ],
 )
 
-cc_library(
+cc_library_with_android_deps(
     name = "const_op",
     srcs = ["ops/const_op.cc"],
     hdrs = ["ops/const_op.h"],
-    deps = [
+    android_deps = [
+        "//tensorflow/core:android_tensorflow_lib",
+    ],
+    common_deps = [
         ":ops",
         ":scope",
+    ],
+    deps = [
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
     ],
@@ -354,13 +370,16 @@ tf_gen_op_wrappers_cc(
     visibility = ["//tensorflow:internal"],
 )
 
-cc_library(
+cc_library_with_android_deps(
     name = "cc_op_gen_main",
     srcs = [
         "framework/cc_op_gen.cc",
         "framework/cc_op_gen.h",
         "framework/cc_op_gen_main.cc",
     ],
+    android_deps = [
+        "//tensorflow/core:android_tensorflow_lib",
+    ],
     copts = tf_copts(),
     deps = [
         "//tensorflow/core:framework",
@@ -172,6 +172,7 @@ tensorflow/core/kernels/quantized_batch_norm_op.cc
 tensorflow/core/kernels/quantized_bias_add_op.cc
 tensorflow/core/kernels/quantized_concat_op.cc
 tensorflow/core/kernels/quantized_conv_ops.cc
+tensorflow/core/kernels/quantized_instance_norm.cc
 tensorflow/core/kernels/quantized_matmul_op.cc
 tensorflow/core/kernels/quantized_pooling_ops.cc
 tensorflow/core/kernels/quantized_reshape_op.cc
@@ -3569,6 +3569,7 @@ filegroup(
         "quantized_bias_add_op.cc",
         "quantized_concat_op.cc",
         "quantized_conv_ops.cc",
+        "quantized_instance_norm.cc",
         "quantized_matmul_op.cc",
         "quantized_pooling_ops.cc",
         "quantized_reshape_op.cc",
@@ -3672,6 +3673,7 @@ tf_kernel_library(
         "quantized_bias_add_op.cc",
         "quantized_concat_op.cc",
         "quantized_conv_ops.cc",
+        "quantized_instance_norm.cc",
         "quantized_matmul_op.cc",
         "quantized_pooling_ops.cc",
         "quantized_reshape_op.cc",
@@ -3944,6 +3946,51 @@ tf_cc_test(
     ],
 )
 
+# Android-only test for quantized instance norm.
+cc_binary(
+    name = "quantized_instance_norm_test_android_only",
+    testonly = 1,
+    srcs = ["quantized_instance_norm_test.cc"],
+    linkopts = select({
+        "//tensorflow:android": [
+            "-pie",
+        ],
+        "//conditions:default": [],
+    }),
+    linkstatic = 1,
+    tags = [
+        "manual",
+        "notap",
+    ],
+    deps = [
+        ":android_tensorflow_kernels",
+        "//tensorflow/cc:cc_ops",
+        "//tensorflow/cc:client_session",
+        "//tensorflow/core:android_tensorflow_test_lib",
+    ],
+)
+
+tf_cc_test(
+    name = "quantized_instance_norm_test",
+    size = "small",
+    srcs = ["quantized_instance_norm_test.cc"],
+    deps = [
+        ":ops_testutil",
+        ":ops_util",
+        ":quantized_ops",
+        "//tensorflow/cc:cc_ops",
+        "//tensorflow/cc:client_session",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:direct_session",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
 # -----------------------------------------------------------------------------
 # Google-internal targets. These must be at the end for syncrepo.
 
tensorflow/core/kernels/quantized_instance_norm.cc (new file, 409 lines)
@@ -0,0 +1,409 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#define USE_NEON
#include <arm_neon.h>
#endif

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"

#include "tensorflow/core/kernels/quantization_utils.h"

#ifdef USE_NEON
namespace {

// Single pass mean and variance.
// Shape of `input` is [rows x cols], shape of both `mean` and `variance`
// is [cols].
// Note, `mean` and `variance` are in units of the quantized input values
// (they are not scaled back to the real-valued range).
// The following is a straightforward implementation of the parallel algorithm
// described in
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
void ColMeanAndVariance(const uint8_t* input, const uint32_t rows,
                        const uint32_t cols, float* mean, float* variance) {
  // The implementation operates on 16 columns at a time.
  // Assumes cols % 16 == 0.
  for (uint32_t col_offset = 0; col_offset < cols; col_offset += 16) {
    // Vector registers to track the running sum across the rows. Since there
    // are 16 columns, we have 4 32x4 registers.
    uint32x4_t sum[4] = {0};

    float nA = 0.0f;
    // Running average and the second moment.
    float32x4_t xA[4] = {0.0f};
    float32x4_t M2A[4] = {0.0f};

    const uint8_t* inp_ptr = input + col_offset;
    // Go over the rows in chunks of 256. This is so that we can use 16 bit adds
    // to do the accumulation.
    for (uint32_t row = 0; row < rows; row += 256) {
      // Running sum and sum of squares for the 256 rows.
      uint32x4_t sub_sum[4] = {0};
      uint32x4_t sub_sq_sum[4] = {0};
      const uint32_t limit = std::min(rows, row + 256);
      const float nB = limit - row;
      for (uint32_t subrow = row; subrow < limit; ++subrow) {
        const uint8x16_t v = vld1q_u8(inp_ptr);
        inp_ptr += cols;

        const uint8x8_t v_high = vget_high_u8(v);
        const uint8x8_t v_low = vget_low_u8(v);

        const uint16x8_t v_high_u16 = vmovl_u8(v_high);
        const uint16x8_t v_low_u16 = vmovl_u8(v_low);

        const uint16x4_t v_high_high = vget_high_u16(v_high_u16);
        const uint16x4_t v_high_low = vget_low_u16(v_high_u16);
        const uint16x4_t v_low_high = vget_high_u16(v_low_u16);
        const uint16x4_t v_low_low = vget_low_u16(v_low_u16);

        sub_sum[0] = vaddw_u16(sub_sum[0], v_high_high);
        sub_sum[1] = vaddw_u16(sub_sum[1], v_high_low);
        sub_sum[2] = vaddw_u16(sub_sum[2], v_low_high);
        sub_sum[3] = vaddw_u16(sub_sum[3], v_low_low);

        sub_sq_sum[0] = vmlal_u16(sub_sq_sum[0], v_high_high, v_high_high);
        sub_sq_sum[1] = vmlal_u16(sub_sq_sum[1], v_high_low, v_high_low);
        sub_sq_sum[2] = vmlal_u16(sub_sq_sum[2], v_low_high, v_low_high);
        sub_sq_sum[3] = vmlal_u16(sub_sq_sum[3], v_low_low, v_low_low);
      }

      // Update the full running sum and moment from the ones for 256 rows.
      for (int i = 0; i < 4; ++i) {
        sum[i] = vaddq_u32(sum[i], sub_sum[i]);
        const float nX = nA + nB;
        // xB is the average of up to 256 elements.
        const float32x4_t xB =
            vmulq_n_f32(vcvtq_f32_u32(sub_sum[i]), 1.0f / nB);

        // delta = xB - xA
        const float32x4_t delta = vsubq_f32(xB, xA[i]);
        // xA = (nA * xA + nB * xB) / (nA + nB)
        xA[i] = vmulq_n_f32(
            vaddq_f32(vmulq_n_f32(xA[i], nA), vmulq_n_f32(xB, nB)), 1.0f / nX);

        const float32x4_t sub_sum_f32 = vcvtq_f32_u32(sub_sum[i]);
        const float32x4_t sub_sum_sq = vmulq_f32(sub_sum_f32, sub_sum_f32);

        // M2B = sum(xB^2) - sum(xB)^2/nB
        const float32x4_t M2B = vsubq_f32(vcvtq_f32_u32(sub_sq_sum[i]),
                                          vmulq_n_f32(sub_sum_sq, 1.0f / nB));
        const float32x4_t last_term =
            vmulq_n_f32(vmulq_f32(delta, delta), nA * nB / nX);
        // M2A = oldM2A + M2B + delta^2 * nA*nB/nX
        M2A[i] = vaddq_f32(vaddq_f32(M2A[i], M2B), last_term);
      }
      nA += limit;
    }

    // Write the final mean and variance for the 16 columns.
    const float inv_rows = 1.0f / static_cast<float>(rows);
    vst1q_f32(mean + col_offset, vmulq_n_f32(vcvtq_f32_u32(sum[3]), inv_rows));
    vst1q_f32(mean + col_offset + 4,
              vmulq_n_f32(vcvtq_f32_u32(sum[2]), inv_rows));
    vst1q_f32(mean + col_offset + 8,
              vmulq_n_f32(vcvtq_f32_u32(sum[1]), inv_rows));
    vst1q_f32(mean + col_offset + 12,
              vmulq_n_f32(vcvtq_f32_u32(sum[0]), inv_rows));

    vst1q_f32(variance + col_offset, vmulq_n_f32(M2A[3], inv_rows));
    vst1q_f32(variance + col_offset + 4, vmulq_n_f32(M2A[2], inv_rows));
    vst1q_f32(variance + col_offset + 8, vmulq_n_f32(M2A[1], inv_rows));
    vst1q_f32(variance + col_offset + 12, vmulq_n_f32(M2A[0], inv_rows));
  }
}
// Compute min and max of (input - mean) / sqrt(variance + epsilon).
// This is done in a separate pass so that the normalized value can be
// temporarily computed in floating point precision and not stored anywhere.
void MinAndMax(const uint8_t* input, const uint32_t rows, const uint32_t cols,
               const float* mean_ptr, const float* variance_ptr,
               float variance_epsilon, float* minimum, float* maximum) {
  float v_maximum = std::numeric_limits<float>::min();
  float v_minimum = std::numeric_limits<float>::max();
  const float32x4_t eps = vdupq_n_f32(variance_epsilon);

  for (uint32_t col_offset = 0; col_offset < cols; col_offset += 16) {
    const float32x4_t mean[4] = {vld1q_f32(mean_ptr + col_offset),
                                 vld1q_f32(mean_ptr + col_offset + 4),
                                 vld1q_f32(mean_ptr + col_offset + 8),
                                 vld1q_f32(mean_ptr + col_offset + 12)};
    const float32x4_t variance[4] = {vld1q_f32(variance_ptr + col_offset),
                                     vld1q_f32(variance_ptr + col_offset + 4),
                                     vld1q_f32(variance_ptr + col_offset + 8),
                                     vld1q_f32(variance_ptr + col_offset + 12)};
    const float32x4_t inv_stddev[4] = {
        vrsqrteq_f32(vaddq_f32(variance[0], eps)),
        vrsqrteq_f32(vaddq_f32(variance[1], eps)),
        vrsqrteq_f32(vaddq_f32(variance[2], eps)),
        vrsqrteq_f32(vaddq_f32(variance[3], eps))};

    const uint8_t* inp_ptr = input + col_offset;
    for (uint32_t row = 0; row < rows; ++row) {
      const uint8x16_t v = vld1q_u8(inp_ptr);
      inp_ptr += cols;

      const uint16x8_t v_high = vmovl_u8(vget_high_u8(v));
      const uint16x8_t v_low = vmovl_u8(vget_low_u8(v));

      const float32x4_t v_float[4] = {
          vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_high))),
          vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_high))),
          vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_low))),
          vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_low)))};

      for (int i = 0; i < 4; ++i) {
        const float32x4_t normed =
            vmulq_f32(vsubq_f32(v_float[i], mean[i]), inv_stddev[i]);
        const float32x2_t high = vget_high_f32(normed);
        const float32x2_t low = vget_low_f32(normed);
        float32x2_t tmp_max = vpmax_f32(low, high);
        tmp_max = vpmax_f32(tmp_max, tmp_max);
        v_maximum = std::max(v_maximum, vget_lane_f32(tmp_max, 0));
        float32x2_t tmp_min = vpmin_f32(low, high);
        tmp_min = vpmin_f32(tmp_min, tmp_min);
        v_minimum = std::min(v_minimum, vget_lane_f32(tmp_min, 0));
      }
    }
  }
  *minimum = v_minimum;
  *maximum = v_maximum;
}

// Compute (input - mean) / sqrt(variance + epsilon) in floating point, quantize
// it in the range (minimum, maximum) and store the result as quint8.
void InstanceNorm(const uint8_t* input, const uint32_t rows,
                  const uint32_t cols, const float* mean_ptr,
                  const float* variance_ptr, float variance_epsilon,
                  float minimum, float maximum, uint8_t* output) {
  const float32x4_t eps = vdupq_n_f32(variance_epsilon);
  const float32x4_t out_min = vdupq_n_f32(minimum);
  const float out_scale = 255.0f / (maximum - minimum);

  for (uint32_t col_offset = 0; col_offset < cols; col_offset += 16) {
    const float32x4_t mean[4] = {vld1q_f32(mean_ptr + col_offset + 12),
                                 vld1q_f32(mean_ptr + col_offset + 8),
                                 vld1q_f32(mean_ptr + col_offset + 4),
                                 vld1q_f32(mean_ptr + col_offset)};
    const float32x4_t variance[4] = {vld1q_f32(variance_ptr + col_offset + 12),
                                     vld1q_f32(variance_ptr + col_offset + 8),
                                     vld1q_f32(variance_ptr + col_offset + 4),
                                     vld1q_f32(variance_ptr + col_offset)};
    const float32x4_t inv_stddev[4] = {
        vrsqrteq_f32(vaddq_f32(variance[0], eps)),
        vrsqrteq_f32(vaddq_f32(variance[1], eps)),
        vrsqrteq_f32(vaddq_f32(variance[2], eps)),
        vrsqrteq_f32(vaddq_f32(variance[3], eps))};
    const uint8_t* inp_ptr = input + col_offset;
    uint8_t* out_ptr = output + col_offset;
    for (uint32_t row = 0; row < rows; ++row) {
      const uint8x16_t v = vld1q_u8(inp_ptr);
      inp_ptr += cols;
      const uint16x8_t v_high = vmovl_u8(vget_high_u8(v));
      const uint16x8_t v_low = vmovl_u8(vget_low_u8(v));

      const float32x4_t v_float[4] = {
          vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_high))),
          vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_high))),
          vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_low))),
          vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_low)))};

      uint16x4_t normed_uint16[4];
      for (int i = 0; i < 4; ++i) {
        const float32x4_t normed =
            vmulq_f32(vsubq_f32(v_float[i], mean[i]), inv_stddev[i]);
        const int32x4_t normed_int32 =
            vcvtq_s32_f32(vmulq_n_f32(vsubq_f32(normed, out_min), out_scale));
        normed_uint16[i] = vqmovun_s32(normed_int32);
      }
      vst1_u8(out_ptr,
              vqmovn_u16(vcombine_u16(normed_uint16[3], normed_uint16[2])));
      vst1_u8(out_ptr + 8,
              vqmovn_u16(vcombine_u16(normed_uint16[1], normed_uint16[0])));
      out_ptr += cols;
    }
  }
}

}  // end namespace
#endif  // USE_NEON

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

class QuantizedInstanceNorm : public OpKernel {
 public:
  explicit QuantizedInstanceNorm(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("variance_epsilon", &variance_epsilon_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("min_separation", &min_separation_));
    OP_REQUIRES_OK(
        context, context->GetAttr("output_range_given", &output_range_given_));
    if (output_range_given_) {
      OP_REQUIRES_OK(context, context->GetAttr("given_y_min", &given_y_min_));
      OP_REQUIRES_OK(context, context->GetAttr("given_y_max", &given_y_max_));
      OP_REQUIRES(context, given_y_min_ < given_y_max_,
                  errors::InvalidArgument(
                      "given_y_min must be less than given_y_max : ",
                      given_y_min_, " >= ", given_y_max_));
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);

    float input_min = context->input(1).flat<float>()(0);
    float input_max = context->input(2).flat<float>()(0);
    float input_scale = (input_max - input_min) / 255.0f;

    OP_REQUIRES(
        context, input_min < input_max,
        errors::InvalidArgument("input_min must be less than input_max : ",
                                input_min, " >= ", input_max));

    auto input_tensor = input.tensor<quint8, 4>();
    auto N = input_tensor.dimension(0);
    auto H = input_tensor.dimension(1);
    auto W = input_tensor.dimension(2);
    auto C = input_tensor.dimension(3);

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));

    Tensor* output_min = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &output_min));
    Tensor* output_max = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_max));

    typedef TTypes<float>::Tensor::Index Index;

#if defined(EIGEN_HAS_INDEX_LIST)
    const Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<2>>
        reduction_indices;
    Eigen::IndexList<Eigen::type2index<1>, Index, Index, Eigen::type2index<1>>
        broadcast_spec;
    broadcast_spec.set(1, H);
    broadcast_spec.set(2, W);
    Eigen::IndexList<Index, Eigen::type2index<1>, Eigen::type2index<1>, Index>
        expand_spec;
    expand_spec.set(0, N);
    expand_spec.set(3, C);
#else
    const Eigen::array<Index, 2> reduction_indices{1, 2};
    const Eigen::array<Index, 4> broadcast_spec{1, H, W, 1};
    const Eigen::array<Index, 4> expand_spec{N, 1, 1, C};
#endif

    Eigen::Tensor<float, 2, Eigen::RowMajor> float_mean(N, C);
    Eigen::Tensor<float, 2, Eigen::RowMajor> float_variance(N, C);

#ifdef USE_NEON
    if (N == 1 && (C % 16 == 0)) {
      VLOG(2) << "Calling optimized";
      ColMeanAndVariance(reinterpret_cast<const uint8_t*>(input_tensor.data()),
                         H * W, C, float_mean.data(), float_variance.data());

      float minimum = given_y_min_, maximum = given_y_max_;
      if (!output_range_given_) {
        MinAndMax(reinterpret_cast<const uint8_t*>(input_tensor.data()), H * W,
                  C, float_mean.data(), float_variance.data(),
                  variance_epsilon_, &minimum, &maximum);
      }

      if (maximum - minimum < min_separation_) {
        maximum = minimum + min_separation_;
      }

      InstanceNorm(reinterpret_cast<const uint8_t*>(input_tensor.data()), H * W,
                   C, float_mean.data(), float_variance.data(),
                   variance_epsilon_, minimum, maximum,
                   reinterpret_cast<uint8_t*>(output->flat<quint8>().data()));
      output_min->scalar<float>()() = minimum;
      output_max->scalar<float>()() = maximum;
    } else  // NOLINT(readability/braces)
#endif
    {
      VLOG(2) << "Calling unoptimized";
      float_mean = input_tensor.cast<float>().reduce(
          reduction_indices, Eigen::internal::MeanReducer<float>());

      float_variance =
          (input_scale *
           ((input_tensor.cast<float>() -
             float_mean.reshape(expand_spec).broadcast(broadcast_spec))))
              .square()
              .reduce(reduction_indices, Eigen::internal::MeanReducer<float>());

      Eigen::Tensor<float, 4, Eigen::RowMajor> instance_normed =
          input_scale *
          (input_tensor.cast<float>() -
           float_mean.reshape(expand_spec).broadcast(broadcast_spec)) *
          (float_variance + variance_epsilon_)
              .rsqrt()
              .reshape(expand_spec)
              .broadcast(broadcast_spec);

      Eigen::Tensor<float, 0, Eigen::RowMajor> normed_min;
      Eigen::Tensor<float, 0, Eigen::RowMajor> normed_max;

      if (!output_range_given_) {
        normed_min = instance_normed.minimum();
        normed_max = instance_normed.maximum();
      } else {
        normed_min() = given_y_min_;
        normed_max() = given_y_max_;
      }

      if (normed_max() - normed_min() < min_separation_) {
        normed_max() = normed_min() + min_separation_;
      }

      FloatToQuantizedStruct<quint8> output_f2q(normed_min(), normed_max());
      auto instance_normed_quantized =
          QUANTIZE_WITH_EIGEN(instance_normed, output_f2q, quint8);

      output->tensor<quint8, 4>().device(
          context->template eigen_device<CPUDevice>()) =
          instance_normed_quantized;
      output_min->flat<float>()(0) = normed_min();
      output_max->flat<float>()(0) = normed_max();
    }
  }

 private:
  float variance_epsilon_;
  float min_separation_;
  bool output_range_given_;
  float given_y_min_;
  float given_y_max_;
};

REGISTER_KERNEL_BUILDER(Name("QuantizedInstanceNorm")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("T"),
                        QuantizedInstanceNorm);

}  // namespace tensorflow
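A note on the NEON path above (added for this write-up, not part of the commit): ColMeanAndVariance accumulates each block of up to 256 rows in 16/32-bit integer sums and then folds it into the running column statistics using the standard parallel-variance merge that the comment's Wikipedia link describes. With chunk A holding the $n_A$ rows seen so far (mean $\bar{x}_A$, second moment $M_{2,A}$) and chunk B the $n_B$ new rows:

\[
n_X = n_A + n_B,\qquad \delta = \bar{x}_B - \bar{x}_A,\qquad
\bar{x}_X = \frac{n_A\,\bar{x}_A + n_B\,\bar{x}_B}{n_X},\qquad
M_{2,X} = M_{2,A} + M_{2,B} + \delta^2\,\frac{n_A\,n_B}{n_X},
\]

where $M_{2,B} = \sum x_B^2 - (\sum x_B)^2 / n_B$ is formed from the per-chunk sums and squared sums. The values written out at the end are the column mean $\sum x / \mathrm{rows}$ and the column variance $M_2 / \mathrm{rows}$.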
tensorflow/core/kernels/quantized_instance_norm_test.cc (new file, 202 lines)
@@ -0,0 +1,202 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include <vector>

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"

namespace tensorflow {

void ReferenceImpl(const quint8* inp, float inp_min, float inp_max,
                   const TensorShape& shape, float var_eps, float* out) {
  int N = shape.dim_size(0);
  int H = shape.dim_size(1);
  int W = shape.dim_size(2);
  int C = shape.dim_size(3);

  int total = N * H * W * C;
  float inp_scale = (inp_max - inp_min) / 255.0f;
  std::unique_ptr<float[]> dequantized(new float[total]);

  for (int i = 0; i < total; ++i) {
    dequantized[i] = inp_min + inp_scale * static_cast<float>(inp[i]);
  }

  std::unique_ptr<float[]> inp_mean(new float[N * C]);
  std::unique_ptr<float[]> inp_var(new float[N * C]);

  float img_size = static_cast<float>(H) * static_cast<float>(W);

  // Compute mean
  for (int n = 0; n < N; ++n) {
    for (int c = 0; c < C; ++c) {
      float sum = 0.0;
      for (int i = 0; i < H * W; ++i) {
        sum += dequantized[n * H * W * C + i * C + c];
      }
      inp_mean[n * C + c] = sum / img_size;
    }
  }

  // Compute var
  for (int n = 0; n < N; ++n) {
    for (int c = 0; c < C; ++c) {
      float sum = 0.0;
      for (int i = 0; i < H * W; ++i) {
        float tmp =
            dequantized[n * H * W * C + i * C + c] - inp_mean[n * C + c];
        sum += tmp * tmp;
      }
      inp_var[n * C + c] = sum / img_size;
    }
  }

  for (int n = 0; n < N; ++n) {
    for (int c = 0; c < C; ++c) {
      for (int i = 0; i < H * W; ++i) {
        out[n * H * W * C + i * C + c] =
            (dequantized[n * H * W * C + i * C + c] - inp_mean[n * C + c]) /
            std::sqrt(inp_var[n * C + c] + var_eps);
      }
    }
  }
}

using namespace ops;  // NOLINT(build/namespaces)

namespace {

void Expect(const Tensor& input, float x_min, float x_max,
            bool output_range_given, float give_y_min, float given_y_max) {
  Scope root = Scope::NewRootScope();

  auto input_ph = Placeholder(root, DT_QUINT8);

  const float variance_eps = 1e-5;
  auto instance_norm = QuantizedInstanceNorm(
      root, input_ph, x_min, x_max,
      QuantizedInstanceNorm::Attrs().VarianceEpsilon(variance_eps));

  Status s = root.status();
  EXPECT_TRUE(s.ok());

  ClientSession session(root);
  std::vector<Tensor> outputs;

  s = session.Run({{input_ph, input}},
                  {instance_norm.y, instance_norm.y_min, instance_norm.y_max},
                  &outputs);

  EXPECT_TRUE(s.ok());
  Tensor expected(DT_FLOAT, input.shape());

  ReferenceImpl(input.flat<quint8>().data(), x_min, x_max, input.shape(),
                variance_eps, expected.flat<float>().data());

  auto out = outputs[0].flat<quint8>();

  float out_min = outputs[1].flat<float>()(0);
  float out_max = outputs[2].flat<float>()(0);
  float out_scale = (out_max - out_min) / 255.0f;

  Eigen::Tensor<float, 0, Eigen::RowMajor> max_diff =
      (expected.flat<float>() - (out_min + out_scale * out.cast<float>()))
          .abs()
          .maximum();
  EXPECT_LE(max_diff(), 0.1);
  LOG(INFO) << "max diff " << max_diff();
}

}  // end namespace

void TestBasic() {
  Tensor input_tensor(DT_QUINT8, {1, 4, 4, 32});
  auto input = input_tensor.flat<quint8>();
  // Random input
  input = input.random(Eigen::internal::UniformRandomGenerator<quint8>());

  Expect(input_tensor, 0.0f, 1.0f, false, 0.0f, 0.0f);
}

void TestZeroInput() {
  Tensor input_tensor(DT_QUINT8, {1, 4, 4, 32});
  auto input = input_tensor.flat<quint8>();
  // Zero input, but input min > 0. Tests that output min and max should be
  // properly separated.
  input = input.setConstant(0);

  Expect(input_tensor, 2.0f, 3.0f, false, 0.0f, 0.0f);
}

void TestMaxInput() {
  Tensor input_tensor(DT_QUINT8, {1, 1, 2, 16});
  auto input = input_tensor.flat<quint8>();
  // Inputs are all FLT_MAX / (number of inputs).
  input = input.setConstant(255);

  Expect(input_tensor, 0.0f,
         std::numeric_limits<float>::max() / static_cast<float>(2 * 16), false,
         0.0f, 0.0f);
}

void TestOutputRangeGiven() {
  Tensor input_tensor(DT_QUINT8, {1, 4, 4, 32});
  auto input = input_tensor.flat<quint8>();
  input = input.random(Eigen::internal::UniformRandomGenerator<quint8>());

  Expect(input_tensor, -10.0f, 10.0f, true, -1.0f, 1.0f);
}

void TestClamp() {
  Tensor input_tensor(DT_QUINT8, {1, 4, 4, 32});
  auto input = input_tensor.flat<quint8>();
  input = input.random(Eigen::internal::UniformRandomGenerator<quint8>());

  // Tests that negative outputs are clamped at 0.0, as the output range is
  // given to be (0.0, 1.0).
  Expect(input_tensor, -10.0f, 10.0f, true, 0.0f, 1.0f);
}

#if !defined(__ANDROID__)

#define RUN_TEST(t) \
  TEST(QuantizedInstanceNormTest, t) { t(); }

RUN_TEST(TestBasic);
RUN_TEST(TestZeroInput);
RUN_TEST(TestMaxInput);
RUN_TEST(TestOutputRangeGiven);
RUN_TEST(TestClamp);

#undef RUN_TEST

#endif  // __ANDROID__

}  // end namespace tensorflow

#if defined(__ANDROID__)
int main(int argc, char** argv) {
  tensorflow::TestBasic();
  tensorflow::TestZeroInput();
  tensorflow::TestMaxInput();
  tensorflow::TestOutputRangeGiven();
  tensorflow::TestClamp();
  return 0;
}
#endif  // __ANDROID__
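To restate what ReferenceImpl and the max_diff check above verify (a summary added here, not text from the commit): the test dequantizes the quint8 input, normalizes it per image and per channel in float, and requires the op's dequantized output to agree within an absolute tolerance of 0.1:

\[
x_{\mathrm{real}} = x_{\min} + \frac{x_{\max} - x_{\min}}{255}\, x_q,\qquad
y_{\mathrm{ref}} = \frac{x_{\mathrm{real}} - \mu_{n,c}}{\sqrt{\sigma^2_{n,c} + \varepsilon}},
\]
\[
\max \left|\, y_{\mathrm{ref}} - \Bigl(y_{\min} + \frac{y_{\max} - y_{\min}}{255}\, y_q\Bigr) \right| \le 0.1,
\]

where $\mu_{n,c}$ and $\sigma^2_{n,c}$ are the mean and variance over the H x W spatial positions of image $n$, channel $c$.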
@@ -4430,6 +4430,51 @@ output_min: This value is copied from input_min.
 output_max: This value is copied from input_max.
 )Doc");
 
+REGISTER_OP("QuantizedInstanceNorm")
+    .Input("x: T")
+    .Input("x_min: float")
+    .Input("x_max: float")
+    .Output("y: T")
+    .Output("y_min: float")
+    .Output("y_max: float")
+    .Attr("T: quantizedtype")
+    .Attr("output_range_given: bool = false")
+    .Attr("given_y_min: float = 0")
+    .Attr("given_y_max: float = 0")
+    .Attr("variance_epsilon: float = 1e-5")
+    .Attr("min_separation: float = 1e-3")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle unused;
+      // x should be a rank 4 tensor.
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &unused));
+      // Assert x_min and x_max are scalars (rank 0).
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+      // y has the same shape as x.
+      TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
+      // y_min and y_max are scalars.
+      c->set_output(1, c->Scalar());
+      c->set_output(2, c->Scalar());
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Quantized Instance normalization.
+
+x: A 4D input Tensor.
+x_min: The value represented by the lowest quantized input.
+x_max: The value represented by the highest quantized input.
+y: A 4D Tensor.
+y_min: The value represented by the lowest quantized output.
+y_max: The value represented by the highest quantized output.
+output_range_given: If True, `given_y_min` and `given_y_max`
+  are used as the output range. Otherwise,
+  the implementation computes the output range.
+given_y_min: Output in `y_min` if `output_range_given` is True.
+given_y_max: Output in `y_max` if `output_range_given` is True.
+variance_epsilon: A small float number to avoid dividing by 0.
+min_separation: Minimum value of `y_max - y_min`.
+)doc");
+
 namespace {
 
 Status ScatterNdShape(InferenceContext* c) {
@@ -222,12 +222,14 @@ def tf_gen_op_wrappers_cc(name,
   native.cc_library(name=name,
                     srcs=subsrcs,
                     hdrs=subhdrs,
-                    deps=deps + [
+                    deps=deps + if_not_android([
                         "//tensorflow/core:core_cpu",
                         "//tensorflow/core:framework",
                         "//tensorflow/core:lib",
                         "//tensorflow/core:protos_all_cc",
-                    ],
+                    ]) + if_android([
+                        "//tensorflow/core:android_tensorflow_lib",
+                    ]),
                     copts=tf_copts(),
                     alwayslink=1,
                     visibility=visibility)
@@ -911,3 +913,8 @@ def tf_version_info_genrule():
       local = 1,
       tools = ["//tensorflow/tools/git:gen_git_source.py"],
   )
+
+def cc_library_with_android_deps(deps, android_deps=[],
+                                 common_deps=[], **kwargs):
+  deps = if_not_android(deps) + if_android(android_deps) + common_deps
+  native.cc_library(deps=deps, **kwargs)