[ExpandDimsOp] Micro-optimizations for tf.expand_dims()
.
1. Avoid calling `ctx->allocate_output()` with a dummy value, and instead call `ctx->set_output()` on the reshaped tensor.
2. Compute the expanded shape by writing directly into an `InlinedVector` instead of copying the original shape to an `std::vector`, then using `emplace()` to insert the new value and shift the old ones along.
3. Avoid calling `OpKernelContext::input()` repeatedly.
4. Avoid using `Tensor::flat<Tdim>` to access the axis; instead use `DMAHelper::base`, which skips the shape calculations and CHECK statements.

PiperOrigin-RevId: 308634055
Change-Id: I3eb86940943324d98542764506c1e39dcf2b9fa3
This commit is contained in:
parent
f761369203
commit
350027541e
@ -1281,7 +1281,7 @@ tf_kernel_library(
|
||||
tf_kernel_library(
|
||||
name = "shape_ops",
|
||||
prefix = "shape_ops",
|
||||
deps = ARRAY_DEPS,
|
||||
deps = ARRAY_DEPS + ["//tensorflow/core/common_runtime:dma_helper"],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
@ -2280,6 +2280,25 @@ tf_kernel_library(
|
||||
],
|
||||
)
|
||||
|
||||
# Test/benchmark target built from shape_ops_test.cc; it links the
# ":shape_ops" kernel library declared in this package.
# NOTE(review): the single_threaded_executor dep is presumably needed because
# the source requests the "SINGLE_THREADED_EXECUTOR" executor by name -- confirm.
tf_cc_test(
|
||||
name = "shape_ops_test",
|
||||
size = "small",
|
||||
srcs = ["shape_ops_test.cc"],
|
||||
deps = [
|
||||
":ops_testutil",
|
||||
":ops_util",
|
||||
":shape_ops",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/kernels/data:single_threaded_executor",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "slice_op_test",
|
||||
size = "small",
|
||||
|
@ -20,6 +20,8 @@ limitations under the License.
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
@ -138,41 +140,43 @@ class ExpandDimsOp : public OpKernel {
|
||||
explicit ExpandDimsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT,
|
||||
const Tensor& input_t = ctx->input(0);
|
||||
OP_REQUIRES(ctx, input_t.dtype() != DT_VARIANT,
|
||||
errors::InvalidArgument("ExpandDims on Variant not supported"));
|
||||
|
||||
const Tensor& dim_t = ctx->input(1);
|
||||
OP_REQUIRES(
|
||||
ctx, (ctx->input(1).NumElements() == 1),
|
||||
ctx, (dim_t.NumElements() == 1),
|
||||
errors::InvalidArgument("'dim' must be a tensor with a single value"));
|
||||
Tdim dim = ctx->input(1).flat<Tdim>()(0);
|
||||
OP_REQUIRES(
|
||||
ctx, (dim >= -1 - ctx->input(0).dims() && dim <= ctx->input(0).dims()),
|
||||
errors::InvalidArgument("Tried to expand dim index ", dim,
|
||||
" for tensor with ", ctx->input(0).dims(),
|
||||
" dimensions."));
|
||||
|
||||
auto existing_dims = ctx->input(0).shape().dim_sizes();
|
||||
// Safe - # elements in tensor dims bounded.
|
||||
const int existing_dims_size = static_cast<int>(existing_dims.size());
|
||||
std::vector<int64> new_shape(existing_dims_size);
|
||||
for (size_t i = 0; i < new_shape.size(); ++i) {
|
||||
new_shape[i] = existing_dims[i];
|
||||
}
|
||||
DCHECK_EQ(dim_t.dtype(), DataTypeToEnum<Tdim>::v());
|
||||
Tdim dim = *static_cast<const Tdim*>(DMAHelper::base(&dim_t));
|
||||
const TensorShape& input_shape = input_t.shape();
|
||||
int input_dims = input_shape.dims();
|
||||
OP_REQUIRES(ctx, dim >= -1 - input_dims && dim <= input_dims,
|
||||
errors::InvalidArgument("Tried to expand dim index ", dim,
|
||||
" for tensor with ", input_dims,
|
||||
" dimensions."));
|
||||
|
||||
// We emulate numpy's interpretation of the dim axis when
|
||||
// -1 - input.dims() <= dim <= input.dims().
|
||||
if (dim < 0) {
|
||||
dim += existing_dims.size() + 1;
|
||||
// Clamp to the end if needed.
|
||||
dim = std::min<Tdim>(dim + input_dims + 1, input_dims);
|
||||
}
|
||||
|
||||
// Clamp to the end if needed.
|
||||
dim = std::min<Tdim>(dim, existing_dims_size);
|
||||
new_shape.emplace(new_shape.begin() + dim, 1);
|
||||
const TensorShape output_shape(new_shape);
|
||||
// Compute new shape with an additional dimension.
|
||||
absl::InlinedVector<int64, 8> output_shape_vec(input_dims + 1);
|
||||
for (int64 i = 0; i < dim; ++i) {
|
||||
output_shape_vec[i] = input_shape.dim_size(i);
|
||||
}
|
||||
output_shape_vec[dim] = 1;
|
||||
for (int64 i = dim + 1; i < input_dims + 1; ++i) {
|
||||
output_shape_vec[i] = input_shape.dim_size(i - 1);
|
||||
}
|
||||
TensorShape output_shape(output_shape_vec);
|
||||
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {0}, &output));
|
||||
if (!output->CopyFrom(ctx->input(0), output_shape)) {
|
||||
Tensor output_t;
|
||||
if (!output_t.CopyFrom(input_t, output_shape)) {
|
||||
// This should never happen, since the sizes of the input and output
|
||||
// should always be the same (we only expand the dimension with 1).
|
||||
ctx->SetStatus(
|
||||
@ -180,6 +184,7 @@ class ExpandDimsOp : public OpKernel {
|
||||
ctx->input(0).shape().DebugString(),
|
||||
" and output shape ", output_shape.DebugString()));
|
||||
}
|
||||
ctx->set_output(0, std::move(output_t));
|
||||
}
|
||||
|
||||
bool IsExpensive() override { return false; }
|
||||
|
67
tensorflow/core/kernels/shape_ops_test.cc
Normal file
67
tensorflow/core/kernels/shape_ops_test.cc
Normal file
@ -0,0 +1,67 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/graph/testlib.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Benchmark for the "ExpandDims" op: builds a graph containing a single
// ExpandDims node fed by two constant inputs and runs it `iters` times.
static void BM_ExpandDims(int iters) {
|
||||
// Timing is stopped here and restarted just before the run below, presumably
// so that graph construction is not measured.
testing::StopTiming();
|
||||
// NOTE(review): `g` is allocated with `new` and never deleted in this
// function; presumably test::Benchmark takes ownership -- confirm.
Graph* g = new Graph(OpRegistry::Global());
|
||||
|
||||
// Data input: an int32 tensor of shape [1, 1, 1, 1] whose only element is 10.
Tensor input(DT_INT32, TensorShape({1, 1, 1, 1}));
|
||||
input.flat<int32>()(0) = 10;
|
||||
|
||||
// Axis input: an int32 tensor with an empty shape, holding the value 2.
Tensor axis(DT_INT32, TensorShape({}));
|
||||
axis.flat<int32>()(0) = 2;
|
||||
|
||||
// `node` is filled in by Finalize() below; TF_CHECK_OK wraps the returned
// status.
Node* node;
|
||||
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "ExpandDims")
|
||||
.Input(test::graph::Constant(g, input))
|
||||
.Input(test::graph::Constant(g, axis))
|
||||
.Attr("T", DT_INT32)
|
||||
.Attr("Tdim", DT_INT32)
|
||||
.Finalize(g, &node));
|
||||
FixupSourceAndSinkEdges(g);
|
||||
|
||||
testing::StartTiming();
|
||||
// Runs on "cpu"; the last argument names the executor to use.
// NOTE(review): the three nullptr arguments are left at whatever
// test::Benchmark treats as "not provided" -- confirm against its declaration.
test::Benchmark("cpu", g, nullptr, nullptr, nullptr,
|
||||
"SINGLE_THREADED_EXECUTOR")
|
||||
.Run(iters);
|
||||
|
||||
// NOTE(review): UseRealTime() is called after the run has completed; confirm
// that this ordering is intended.
testing::UseRealTime();
|
||||
}
|
||||
|
||||
BENCHMARK(BM_ExpandDims);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user