This change is a micro-optimization that speeds up execution of `ArgOp::Compute()` in every `tf.function` invocation or `DirectSession::Run()` call.

Previously, we would copy-construct a Tensor in the implementation of `GetArg()`, and then move that into `OpKernelContext::set_output()`. In effect, this involves two refcount operations on the underlying buffer, and three copies of the `TensorShape`. By instead outputting a pointer to the `const Tensor` in the frame, we avoid one of the refcount operations, and two of the `TensorShape` copies.

One consequence of this change is that it becomes harder to create a `Tensor` on the fly inside `GetArg()`. We were using that ability in two places:

1. In `DirectSession::RunCallable()`, when one of the arguments has type `DT_RESOURCE` and is converted into a tensor (part of the tfdbg functionality, rarely used via this API). We fix that by performing the conversion eagerly in `RunCallable()`, in the rare case it is necessary.

2. In the `MapDefunOp` implementation, when one of the arguments is to be sliced out of the tensor we're mapping over, the slice is created in `GetArg()`. We fix this by adding a mutable vector of slices to the specialized `CallFrame` implementation, storing the created tensor there, and returning a pointer to it. (Since `MapDefunOp` is only used in a graph-rewrite context, a better fix would be to add explicit `tf.slice()` ops to the graph instead of relying on the call frame to do this work, because those might be possible to optimize further with Grappler.)

PiperOrigin-RevId: 305983898
Change-Id: I0834777c27cd97204e8e3df052a08faf0dcf68f9
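To make the refcount/shape-copy argument concrete, here is a minimal, self-contained sketch of the interface shape described above. The types are simplified stand-ins, not the real TensorFlow `Tensor`, `Status`, or `CallFrameInterface`, and the names `GetArgByCopy` and `AddArg` are invented for illustration; the pointer-returning `GetArg(int, const Tensor**)` form is the one the `_Arg` kernels, including the XLA kernel below, now consume.

```cpp
// Hypothetical sketch with stand-in types (not the TensorFlow sources).
#include <memory>
#include <vector>

struct Status { bool ok = true; };

// Stand-in for tensorflow::Tensor: the shared_ptr models the refcounted
// TensorBuffer, the vector models the TensorShape dimensions.
struct Tensor {
  std::shared_ptr<int> buffer;
  std::vector<long long> shape;
};

class CallFrame {
 public:
  void AddArg(Tensor t) { args_.push_back(std::move(t)); }

  // Old shape of the interface: copy-construct the argument into *val.
  // The copy bumps the buffer refcount and copies the shape before the
  // caller has done anything with the value.
  Status GetArgByCopy(int index, Tensor* val) const {
    *val = args_[index];
    return Status();
  }

  // New shape of the interface: hand back a pointer to the Tensor the
  // frame already owns. No refcount traffic, no shape copy.
  Status GetArg(int index, const Tensor** val) const {
    *val = &args_[index];
    return Status();
  }

 private:
  std::vector<Tensor> args_;
};

int main() {
  CallFrame frame;
  frame.AddArg(Tensor{std::make_shared<int>(0), {2, 3}});

  Tensor copied;
  frame.GetArgByCopy(0, &copied);   // buffer refcount is now 2, shape copied

  const Tensor* borrowed = nullptr;
  frame.GetArg(0, &borrowed);       // refcount unchanged, no shape copy
  return borrowed->buffer.use_count() == 2 ? 0 : 1;
}
```

The trade-off, as the two fixes above note, is that a pointer-returning `GetArg()` cannot safely hand back a temporary created on demand; any tensor that must be materialized (such as `MapDefunOp`'s slices) has to be stored somewhere the frame owns so the returned pointer stays valid.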
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// This OpKernel implements the _Arg Op for XLA JIT devices. It
// associates its output with one of the arguments to a
// subcomputation.
class XlaArgOp : public XlaOpKernel {
 public:
  explicit XlaArgOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    // If 'frame' is non-null, this is a function call inside an outer JIT
    // compilation. Use the usual implementation of _Arg.
    auto frame = ctx->call_frame();
    if (frame != nullptr) {
      const Tensor* val;
      OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val));
      // Types that cannot be copied using memcpy (like DT_STRING) are wrapped
      // in a DT_UINT8 and hence the type mismatches. Skip the test in such
      // cases. See XlaOpKernelContext::SetOutputExpression for details.
      if (DataTypeCanUseMemcpy(dtype_)) {
        OP_REQUIRES(ctx, val->dtype() == dtype_,
                    errors::InvalidArgument(
                        "Type mismatch: actual ", DataTypeString(val->dtype()),
                        " vs. expect ", DataTypeString(dtype_)));
      }
      // Forwards the argument from the frame.
      ctx->op_kernel_context()->set_output(0, *val);
      return;
    }

    const XlaExpression& arg = ctx->xla_context()->args()[index_];
    OP_REQUIRES(ctx, arg.kind() != XlaExpression::Kind::kInvalid,
                errors::InvalidArgument("Invalid/missing argument expression"));
    if (ctx->expected_output_dtype(0) == DT_VARIANT) {
      ctx->SetTensorListOutput(0, arg.handle());
    } else {
      ctx->SetOutputExpression(0, arg);
    }
  }

 private:
  int index_;
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaArgOp);
};

REGISTER_XLA_OP(
    Name("_Arg").AllowResourceTypes().AllowVariantTypes().CompilationOnly(),
    XlaArgOp);

}  // namespace tensorflow