add the cost estimator
This commit is contained in:
parent
2002d5e283
commit
d7757cbe27
@ -27,7 +27,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/kernels/image_resizer_state.h"
|
#include "tensorflow/core/kernels/image_resizer_state.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/util/work_sharder.h"
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -166,7 +165,14 @@ struct ResizeNearestNeighbor<CPUDevice, T, half_pixel_centers, align_corners> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
Eigen::Index N = batch_size * out_height * out_width;
|
Eigen::Index N = batch_size * out_height * out_width;
|
||||||
d.parallelFor(N, Eigen::TensorOpCost(0, 0, 1000.0), ParallelResize);
|
const int input_bytes =
|
||||||
|
batch_size * in_height * in_width * channels * sizeof(T);
|
||||||
|
const int output_bytes = N * channels * sizeof(T);
|
||||||
|
const int compute_cycles = (Eigen::TensorOpCost::ModCost<T>() * 2 +
|
||||||
|
Eigen::TensorOpCost::DivCost<T>() * 5) *
|
||||||
|
N;
|
||||||
|
const Eigen::TensorOpCost cost(input_bytes, output_bytes, compute_cycles);
|
||||||
|
d.parallelFor(N, cost, ParallelResize);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -16,7 +16,6 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_CORE_KERNELS_RESIZE_NEAREST_NEIGHBOR_OP_H_
|
#ifndef TENSORFLOW_CORE_KERNELS_RESIZE_NEAREST_NEIGHBOR_OP_H_
|
||||||
#define TENSORFLOW_CORE_KERNELS_RESIZE_NEAREST_NEIGHBOR_OP_H_
|
#define TENSORFLOW_CORE_KERNELS_RESIZE_NEAREST_NEIGHBOR_OP_H_
|
||||||
|
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user