Improve the performance of resize_bilinear.
The previous implementation of resize_bilinear operation is a straight forward, naive implementation. Unfortunately, it leaves a significant amount of room for improvement. This change improves the performance of the operation by avoiding unnecessary recomputation (for where the operation was CPU bound), and avoids redundant memory references (for where the operation was memory bound). In my benchmark (loosely based on inception image pipeline), this implementation change improves speed between 1.5-2x. Additionally, to ensure correctness, I've preserved the old implementation of resize_bilinear, and added a number of tests that initalize a random image, and ensure that the outputs are identical. Change: 144726607
This commit is contained in:
parent
d83ae93213
commit
e9602d2752
@ -1669,6 +1669,7 @@ tf_cc_tests(
|
||||
":ops_util",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
|
@ -90,6 +90,18 @@ struct ImageResizerState {
|
||||
errors::InvalidArgument("input image must be of non-zero size"));
|
||||
height_scale = CalculateResizeScale(in_height, out_height, align_corners_);
|
||||
width_scale = CalculateResizeScale(in_width, out_width, align_corners_);
|
||||
|
||||
// Guard against overflows
|
||||
OP_REQUIRES(context,
|
||||
ceilf((out_height - 1) * height_scale) <=
|
||||
static_cast<float>(std::numeric_limits<int64>::max()),
|
||||
errors::InvalidArgument(
|
||||
"input image height scale would cause an overflow"));
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
ceilf((out_width - 1) * width_scale) <= static_cast<float>(INT_MAX),
|
||||
errors::InvalidArgument(
|
||||
"input image width scale would cause an overflow"));
|
||||
}
|
||||
|
||||
// Calculates all the required variables, and allocates the output.
|
||||
|
@ -55,23 +55,219 @@ class ResizeBilinearOp : public OpKernel {
|
||||
typename TTypes<float, 4>::Tensor output_data =
|
||||
st.output->tensor<float, 4>();
|
||||
|
||||
functor::ResizeBilinear<Device, T>()(context->eigen_device<Device>(),
|
||||
image_data, st.height_scale,
|
||||
st.width_scale, output_data);
|
||||
functor::ResizeBilinear<Device, T>()(
|
||||
context, context->eigen_device<Device>(), image_data, st.height_scale,
|
||||
st.width_scale, output_data);
|
||||
}
|
||||
|
||||
private:
|
||||
bool align_corners_;
|
||||
};
|
||||
|
||||
namespace {
|
||||
// Compute the interpolation indices only once.
|
||||
struct CachedInterpolation {
|
||||
int64 lower; // Lower source index used in the interpolation
|
||||
int64 upper; // Upper source index used in the interpolation
|
||||
// 1-D linear iterpolation scale (see:
|
||||
// https://en.wikipedia.org/wiki/Bilinear_interpolation)
|
||||
float lerp;
|
||||
// How many consecutive points use the same lower & upper indices
|
||||
int consecutive;
|
||||
};
|
||||
|
||||
enum ImageScalePattern { SCALE_UP, SIMILAR, SCALE_DOWN };
|
||||
|
||||
inline ImageScalePattern compute_image_scale_pattern(const int64 out_height,
|
||||
const int64 out_width,
|
||||
const int64 in_height,
|
||||
const int64 in_width) {
|
||||
if (in_height * 2 < out_height || in_width * 2 < out_width) {
|
||||
return SCALE_UP;
|
||||
} else if (out_height * 2 < in_height || out_width * 2 < in_width) {
|
||||
return SCALE_DOWN;
|
||||
} else {
|
||||
return SIMILAR;
|
||||
}
|
||||
}
|
||||
|
||||
inline int compute_scratch_size(const int64 out_height, const int64 out_width,
|
||||
const int64 in_height, const int64 in_width,
|
||||
const int channels,
|
||||
const ImageScalePattern scale_pattern) {
|
||||
// Allocate a CachedInterpolation for each y, and each x in the out-height,
|
||||
// plus 2 extra to avoid extra branches in the
|
||||
// CachedInterpolation.consecutive computation.
|
||||
const int cached_computation_size =
|
||||
sizeof(CachedInterpolation) * (out_height + out_width + 2);
|
||||
if (scale_pattern == SCALE_DOWN) {
|
||||
return cached_computation_size;
|
||||
} else {
|
||||
// In order to avoid paying the cost of data type conversion multiple times,
|
||||
// we must allocate a temporary image as well.
|
||||
const int tmp_image_size = sizeof(float) * in_height * in_width * channels;
|
||||
// We batch up all memory allocations into a single malloc call for
|
||||
// performance reasons.
|
||||
return cached_computation_size + tmp_image_size;
|
||||
}
|
||||
}
|
||||
|
||||
inline void compute_interpolation_weights(const ImageScalePattern scale_pattern,
|
||||
const int64 out_size,
|
||||
const int64 in_size,
|
||||
const float scale,
|
||||
CachedInterpolation* interpolation) {
|
||||
interpolation[out_size].lower = 0;
|
||||
interpolation[out_size].upper = 0;
|
||||
interpolation[out_size].consecutive = 0;
|
||||
for (int64 i = out_size - 1; i >= 0; --i) {
|
||||
const float in = i * scale;
|
||||
interpolation[i].lower = static_cast<int64>(in);
|
||||
interpolation[i].upper = std::min(interpolation[i].lower + 1, in_size - 1);
|
||||
interpolation[i].lerp = in - interpolation[i].lower;
|
||||
interpolation[i].consecutive =
|
||||
interpolation[i + 1].lower == interpolation[i].lower &&
|
||||
interpolation[i + 1].upper == interpolation[i].upper
|
||||
? interpolation[i + 1].consecutive + 1
|
||||
: 1;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct Converter {
|
||||
static inline const float* convert_image_to_float(
|
||||
typename TTypes<T, 4>::ConstTensor images, const int batch_index,
|
||||
const int64 in_height, const int64 in_width, const int channels,
|
||||
std::vector<float>* converted_image_v) {
|
||||
converted_image_v->resize(in_height * in_width * channels);
|
||||
float* converted_image = converted_image_v->data();
|
||||
for (int64 y = 0; y < in_height; ++y) {
|
||||
for (int64 x = 0; x < in_width; ++x) {
|
||||
for (int c = 0; c < channels; ++c) {
|
||||
converted_image[y * in_width * channels + x * channels + c] =
|
||||
static_cast<float>(images(batch_index, y, x, c));
|
||||
}
|
||||
}
|
||||
}
|
||||
return converted_image;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Converter<float> {
|
||||
static inline const float* convert_image_to_float(
|
||||
typename TTypes<float, 4>::ConstTensor images, const int b,
|
||||
const int64 in_height, const int64 in_width, const int channels,
|
||||
std::vector<float>* converted_image_v) {
|
||||
return images.data() + (b * in_height * in_width * channels);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Computes the bilinear interpolation from the appropriate 4 float points
|
||||
* and the linear interpolation weights.
|
||||
*/
|
||||
inline float compute_lerp(const float top_left, const float top_right,
|
||||
const float bottom_left, const float bottom_right,
|
||||
const float x_lerp, const float y_lerp) {
|
||||
const float top = top_left + (top_right - top_left) * x_lerp;
|
||||
const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
|
||||
return top + (bottom - top) * y_lerp;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void scale_down_image(typename TTypes<T, 4>::ConstTensor images,
|
||||
const int batch_size, const int64 out_height,
|
||||
const int64 out_width, const int channels,
|
||||
const std::vector<CachedInterpolation>& xs,
|
||||
const std::vector<CachedInterpolation>& ys,
|
||||
typename TTypes<float, 4>::Tensor output) {
|
||||
// Do not eagerly convert all input data points, as we ignore most.
|
||||
for (int b = 0; b < batch_size; ++b) {
|
||||
// Compute the interpolation
|
||||
for (int64 y = 0; y < out_height; ++y) {
|
||||
for (int64 x = 0; x < out_width; ++x) {
|
||||
for (int c = 0; c < channels; ++c) {
|
||||
const float top_left(images(b, ys[y].lower, xs[x].lower, c));
|
||||
const float top_right(images(b, ys[y].lower, xs[x].upper, c));
|
||||
const float bottom_left(images(b, ys[y].upper, xs[x].lower, c));
|
||||
const float bottom_right(images(b, ys[y].upper, xs[x].upper, c));
|
||||
output(b, y, x, c) =
|
||||
compute_lerp(top_left, top_right, bottom_left, bottom_right,
|
||||
xs[x].lerp, ys[y].lerp);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void scale_up_image(const float* input_image, const int batch_index,
|
||||
const int64 out_height, const int64 out_width,
|
||||
const int channels, const int64 in_height,
|
||||
const int64 in_width,
|
||||
const std::vector<CachedInterpolation>& xs,
|
||||
const std::vector<CachedInterpolation>& ys,
|
||||
typename TTypes<float, 4>::Tensor output) {
|
||||
for (int64 y = 0; y < out_height; y += ys[y].consecutive) {
|
||||
const int64 in_y_lower = ys[y].lower * in_width * channels;
|
||||
const int64 in_y_upper = ys[y].upper * in_width * channels;
|
||||
for (int64 x = 0; x < out_width; x += xs[x].consecutive) {
|
||||
const int64 in_x_lower = xs[x].lower * channels;
|
||||
const int64 in_x_upper = xs[x].upper * channels;
|
||||
for (int c = 0; c < channels; ++c) {
|
||||
const float top_left = input_image[in_y_lower + in_x_lower + c];
|
||||
const float top_right = input_image[in_y_lower + in_x_upper + c];
|
||||
const float bottom_left = input_image[in_y_upper + in_x_lower + c];
|
||||
const float bottom_right = input_image[in_y_upper + in_x_upper + c];
|
||||
for (int64 y_inner = y; y_inner < y + ys[y].consecutive; ++y_inner) {
|
||||
for (int64 x_inner = x; x_inner < x + xs[x].consecutive; ++x_inner) {
|
||||
output(batch_index, y_inner, x_inner, c) =
|
||||
compute_lerp(top_left, top_right, bottom_left, bottom_right,
|
||||
xs[x_inner].lerp, ys[y_inner].lerp);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void scale_similar_image(const float* input_image, const int b,
|
||||
const int64 out_height, const int64 out_width,
|
||||
const int channels, const int64 in_height,
|
||||
const int64 in_width,
|
||||
const std::vector<CachedInterpolation>& xs,
|
||||
const std::vector<CachedInterpolation>& ys,
|
||||
typename TTypes<float, 4>::Tensor output) {
|
||||
// Compute the interpolation
|
||||
for (int64 y = 0; y < out_height; ++y) {
|
||||
const int64 in_y_lower = ys[y].lower * in_width * channels;
|
||||
const int64 in_y_upper = ys[y].upper * in_width * channels;
|
||||
// Similar-sized images do not have a set of inner loops.
|
||||
for (int64 x = 0; x < out_width; ++x) {
|
||||
const int64 in_x_lower = xs[x].lower * channels;
|
||||
const int64 in_x_upper = xs[x].upper * channels;
|
||||
for (int c = 0; c < channels; ++c) {
|
||||
const float top_left = input_image[in_y_lower + in_x_lower + c];
|
||||
const float top_right = input_image[in_y_lower + in_x_upper + c];
|
||||
const float bottom_left = input_image[in_y_upper + in_x_lower + c];
|
||||
const float bottom_right = input_image[in_y_upper + in_x_upper + c];
|
||||
output(b, y, x, c) = compute_lerp(top_left, top_right, bottom_left,
|
||||
bottom_right, xs[x].lerp, ys[y].lerp);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Partial specialization of ResizeBilinear functor for a CPUDevice.
|
||||
namespace functor {
|
||||
template <typename T>
|
||||
struct ResizeBilinear<CPUDevice, T> {
|
||||
void operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor images,
|
||||
void operator()(OpKernelContext* context, const CPUDevice& d,
|
||||
typename TTypes<T, 4>::ConstTensor images,
|
||||
const float height_scale, const float width_scale,
|
||||
typename TTypes<float, 4>::Tensor output) {
|
||||
const int batch = images.dimension(0);
|
||||
const int batch_size = images.dimension(0);
|
||||
const int64 in_height = images.dimension(1);
|
||||
const int64 in_width = images.dimension(2);
|
||||
const int channels = images.dimension(3);
|
||||
@ -79,31 +275,41 @@ struct ResizeBilinear<CPUDevice, T> {
|
||||
const int64 out_height = output.dimension(1);
|
||||
const int64 out_width = output.dimension(2);
|
||||
|
||||
for (int b = 0; b < batch; ++b) {
|
||||
for (int y = 0; y < out_height; ++y) {
|
||||
const float in_y = y * height_scale;
|
||||
const int64 top_y_index = static_cast<int64>(floorf(in_y));
|
||||
const int64 bottom_y_index =
|
||||
std::min(static_cast<int64>(ceilf(in_y)), in_height - 1);
|
||||
const float y_lerp = in_y - top_y_index;
|
||||
for (int x = 0; x < out_width; ++x) {
|
||||
const float in_x = x * width_scale;
|
||||
const int64 left_x_index = static_cast<int64>(floorf(in_x));
|
||||
const int64 right_x_index =
|
||||
std::min(static_cast<int64>(ceilf(in_x)), in_width - 1);
|
||||
const float x_lerp = in_x - left_x_index;
|
||||
for (int c = 0; c < channels; ++c) {
|
||||
const float top_left(images(b, top_y_index, left_x_index, c));
|
||||
const float top_right(images(b, top_y_index, right_x_index, c));
|
||||
const float bottom_left(images(b, bottom_y_index, left_x_index, c));
|
||||
const float bottom_right(
|
||||
images(b, bottom_y_index, right_x_index, c));
|
||||
const float top = top_left + (top_right - top_left) * x_lerp;
|
||||
const float bottom =
|
||||
bottom_left + (bottom_right - bottom_left) * x_lerp;
|
||||
output(b, y, x, c) = top + (bottom - top) * y_lerp;
|
||||
}
|
||||
}
|
||||
// Handle no-op resizes efficiently.
|
||||
if (out_height == in_height && out_width == in_width) {
|
||||
output = images.template cast<float>();
|
||||
return;
|
||||
}
|
||||
|
||||
const ImageScalePattern scale_pattern =
|
||||
compute_image_scale_pattern(out_height, out_width, in_height, in_width);
|
||||
std::vector<CachedInterpolation> ys(out_height + 1);
|
||||
std::vector<CachedInterpolation> xs(out_width + 1);
|
||||
std::vector<float> converted_image_v;
|
||||
|
||||
// Compute the cached interpolation weights on the x and y dimensions.
|
||||
compute_interpolation_weights(scale_pattern, out_height, in_height,
|
||||
height_scale, ys.data());
|
||||
compute_interpolation_weights(scale_pattern, out_width, in_width,
|
||||
width_scale, xs.data());
|
||||
|
||||
if (scale_pattern == SCALE_UP) {
|
||||
for (int b = 0; b < batch_size; ++b) {
|
||||
const float* converted_image = Converter<T>::convert_image_to_float(
|
||||
images, b, in_height, in_width, channels, &converted_image_v);
|
||||
scale_up_image(converted_image, b, out_height, out_width, channels,
|
||||
in_height, in_width, xs, ys, output);
|
||||
}
|
||||
} else if (scale_pattern == SCALE_DOWN) {
|
||||
// Do not eagerly convert all input data points, as we ignore most.
|
||||
scale_down_image<T>(images, batch_size, out_height, out_width, channels,
|
||||
xs, ys, output);
|
||||
} else {
|
||||
for (int b = 0; b < batch_size; ++b) {
|
||||
const float* converted_image = Converter<T>::convert_image_to_float(
|
||||
images, b, in_height, in_width, channels, &converted_image_v);
|
||||
scale_similar_image(converted_image, b, out_height, out_width, channels,
|
||||
in_height, in_width, xs, ys, output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -149,7 +149,8 @@ namespace functor {
|
||||
// Partial specialization of ResizeBilinear functor for a GPUDevice.
|
||||
template <typename T>
|
||||
struct ResizeBilinear<GPUDevice, T> {
|
||||
void operator()(const GPUDevice& d, typename TTypes<T, 4>::ConstTensor images,
|
||||
void operator()(OpKernelContext* context, const GPUDevice& d,
|
||||
typename TTypes<T, 4>::ConstTensor images,
|
||||
const float height_scale, const float width_scale,
|
||||
typename TTypes<float, 4>::Tensor output) {
|
||||
const int batch = images.dimension(0);
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/random/random.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -39,6 +40,74 @@ class ResizeBilinearOpTest : public OpsTestBase {
|
||||
.Finalize(node_def()));
|
||||
TF_EXPECT_OK(InitOp());
|
||||
}
|
||||
|
||||
const Tensor* AddRandomImageInput(const TensorShape& shape) {
|
||||
CHECK_GT(input_types_.size(), inputs_.size())
|
||||
<< "Adding more inputs than types; perhaps you need to call MakeOp";
|
||||
CHECK_EQ(shape.dims(), 4) << "All images must have 4 dimensions.";
|
||||
bool is_ref = IsRefType(input_types_[inputs_.size()]);
|
||||
Tensor* input = new Tensor(device_->GetAllocator(AllocatorAttributes()),
|
||||
DataTypeToEnum<float>::v(), shape);
|
||||
input->flat<float>().setRandom();
|
||||
tensors_.push_back(input);
|
||||
if (is_ref) {
|
||||
CHECK_EQ(RemoveRefType(input_types_[inputs_.size()]),
|
||||
DataTypeToEnum<float>::v());
|
||||
inputs_.push_back({&lock_for_refs_, input});
|
||||
} else {
|
||||
CHECK_EQ(input_types_[inputs_.size()], DataTypeToEnum<float>::v());
|
||||
inputs_.push_back({nullptr, input});
|
||||
}
|
||||
return input;
|
||||
}
|
||||
|
||||
// This is the straight forward unoptimized implementation of resize bilinear
|
||||
// We use this to confirm that the optimized version is exactly identical.
|
||||
void ResizeBilinearBaseline(TTypes<float, 4>::ConstTensor images,
|
||||
TTypes<float, 4>::Tensor output) {
|
||||
const int batch = images.dimension(0);
|
||||
const int64 in_height = images.dimension(1);
|
||||
const int64 in_width = images.dimension(2);
|
||||
const int channels = images.dimension(3);
|
||||
|
||||
ASSERT_EQ(batch, output.dimension(0));
|
||||
ASSERT_EQ(channels, output.dimension(3));
|
||||
|
||||
const int64 out_height = output.dimension(1);
|
||||
const int64 out_width = output.dimension(2);
|
||||
|
||||
const float height_scale = in_height / static_cast<float>(out_height);
|
||||
const float width_scale = in_width / static_cast<float>(out_width);
|
||||
|
||||
for (int b = 0; b < batch; ++b) {
|
||||
for (int64 y = 0; y < out_height; ++y) {
|
||||
const float in_y = y * height_scale;
|
||||
const int64 top_y_index = static_cast<int64>(floorf(in_y));
|
||||
const int64 bottom_y_index =
|
||||
std::min(static_cast<int64>(ceilf(in_y)), in_height - 1);
|
||||
const float y_lerp = in_y - top_y_index;
|
||||
for (int64 x = 0; x < out_width; ++x) {
|
||||
const float in_x = x * width_scale;
|
||||
const int64 left_x_index = static_cast<int64>(floorf(in_x));
|
||||
const int64 right_x_index =
|
||||
std::min(static_cast<int64>(ceilf(in_x)), in_width - 1);
|
||||
const float x_lerp = in_x - left_x_index;
|
||||
for (int c = 0; c < channels; ++c) {
|
||||
const float top_left = images(b, top_y_index, left_x_index, c);
|
||||
const float top_right = images(b, top_y_index, right_x_index, c);
|
||||
const float bottom_left =
|
||||
images(b, bottom_y_index, left_x_index, c);
|
||||
const float bottom_right =
|
||||
images(b, bottom_y_index, right_x_index, c);
|
||||
const float top = top_left + (top_right - top_left) * x_lerp;
|
||||
const float bottom =
|
||||
bottom_left + (bottom_right - bottom_left) * x_lerp;
|
||||
output(b, y, x, c) = top + (bottom - top) * y_lerp;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class ResizeBilinearOpAlignCornersTest : public OpsTestBase {
|
||||
@ -68,6 +137,23 @@ TEST_F(ResizeBilinearOpTest, TestBilinear2x2To1x1) {
|
||||
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
|
||||
}
|
||||
|
||||
TEST_F(ResizeBilinearOpTest, TestBilinearRandom2x2To1x1) {
|
||||
const Tensor* input = AddRandomImageInput(TensorShape({1, 2, 2, 1}));
|
||||
AddInputFromArray<int32>(TensorShape({2}), {1, 1});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// When scaling down, we have to arbitrarily pick a pixel from the
|
||||
// original input. In this case, we choose the top/left most pixel.
|
||||
Tensor* output = GetOutput(0);
|
||||
std::unique_ptr<Tensor> expected(
|
||||
new Tensor(device_->GetAllocator(AllocatorAttributes()),
|
||||
DataTypeToEnum<float>::v(), TensorShape({1, 1, 1, 1})));
|
||||
ResizeBilinearBaseline(input->tensor<float, 4>(),
|
||||
expected->tensor<float, 4>());
|
||||
EXPECT_EQ(input->flat<float>()(0), output->flat<float>()(0));
|
||||
test::ExpectTensorEqual<float>(*expected.get(), *output);
|
||||
}
|
||||
|
||||
TEST_F(ResizeBilinearOpAlignCornersTest, TestBilinearAlignCorners2x2To1x1) {
|
||||
// Input:
|
||||
// 1, 2
|
||||
@ -302,6 +388,62 @@ TEST_F(ResizeBilinearOpTest, TestBilinear2x2To4x4) {
|
||||
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
|
||||
}
|
||||
|
||||
TEST_F(ResizeBilinearOpTest, TestBilinearRandom183x299To299x299) {
|
||||
const TensorShape shape({1, 183, 299, 1});
|
||||
const Tensor* input = AddRandomImageInput(shape);
|
||||
AddInputFromArray<int32>(TensorShape({2}), {299, 299});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::unique_ptr<Tensor> expected(
|
||||
new Tensor(device_->GetAllocator(AllocatorAttributes()),
|
||||
DataTypeToEnum<float>::v(), TensorShape({1, 299, 299, 1})));
|
||||
ResizeBilinearBaseline(input->tensor<float, 4>(),
|
||||
expected->tensor<float, 4>());
|
||||
test::ExpectTensorEqual<float>(*expected, *GetOutput(0));
|
||||
}
|
||||
|
||||
TEST_F(ResizeBilinearOpTest, TestBilinearRandom141x186To299x299) {
|
||||
const TensorShape shape({1, 141, 186, 1});
|
||||
const Tensor* input = AddRandomImageInput(shape);
|
||||
AddInputFromArray<int32>(TensorShape({2}), {299, 299});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::unique_ptr<Tensor> expected(
|
||||
new Tensor(device_->GetAllocator(AllocatorAttributes()),
|
||||
DataTypeToEnum<float>::v(), TensorShape({1, 299, 299, 1})));
|
||||
ResizeBilinearBaseline(input->tensor<float, 4>(),
|
||||
expected->tensor<float, 4>());
|
||||
test::ExpectTensorEqual<float>(*expected, *GetOutput(0));
|
||||
}
|
||||
|
||||
TEST_F(ResizeBilinearOpTest, TestBilinearRandom749x603To299x299) {
|
||||
const TensorShape shape({1, 749, 603, 1});
|
||||
const Tensor* input = AddRandomImageInput(shape);
|
||||
AddInputFromArray<int32>(TensorShape({2}), {299, 299});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::unique_ptr<Tensor> expected(
|
||||
new Tensor(device_->GetAllocator(AllocatorAttributes()),
|
||||
DataTypeToEnum<float>::v(), TensorShape({1, 299, 299, 1})));
|
||||
ResizeBilinearBaseline(input->tensor<float, 4>(),
|
||||
expected->tensor<float, 4>());
|
||||
test::ExpectTensorEqual<float>(*expected, *GetOutput(0));
|
||||
}
|
||||
|
||||
TEST_F(ResizeBilinearOpTest, TestBilinearRandom299x299To299x299) {
|
||||
const TensorShape shape({1, 299, 299, 1});
|
||||
const Tensor* input = AddRandomImageInput(shape);
|
||||
AddInputFromArray<int32>(TensorShape({2}), {299, 299});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
std::unique_ptr<Tensor> expected(
|
||||
new Tensor(device_->GetAllocator(AllocatorAttributes()),
|
||||
DataTypeToEnum<float>::v(), TensorShape({1, 299, 299, 1})));
|
||||
ResizeBilinearBaseline(input->tensor<float, 4>(),
|
||||
expected->tensor<float, 4>());
|
||||
test::ExpectTensorEqual<float>(*expected, *GetOutput(0));
|
||||
}
|
||||
|
||||
TEST_F(ResizeBilinearOpTest, TestInvalidOutputSize) {
|
||||
AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
|
||||
AddInputFromArray<int32>(TensorShape({2}), {0, 0});
|
||||
|
Loading…
x
Reference in New Issue
Block a user