Switch CallFrameInterface::GetArg(int, Tensor*) to take (int, const Tensor**).
This change is a micro-optimization that speeds up execution of `ArgOp::Compute()` in every `tf.function` invocation or `DirectSession::Run()` call. Previously, we would copy-construct a Tensor in the implementation of `GetArg()`, and then move that into `OpKernelContext::set_output()`. In effect, this involved two refcount operations on the underlying buffer and three copies of the `TensorShape`. By instead outputting a pointer to the `const Tensor` in the frame, we avoid one of the refcount operations and two of the `TensorShape` copies.

One consequence of this change is that it becomes more difficult to create a `Tensor` on the fly in `GetArg()`. We were using that ability in two places:

1. In `DirectSession::RunCallable()`, when one of the arguments has type `DT_RESOURCE` and is converted into a tensor (part of the tfdbg functionality, rarely used via this API). We fix that by performing the conversion eagerly in `RunCallable()`, in the rare case it is necessary.

2. In the `MapDefunOp` implementation, when one of the arguments is to be sliced out of the tensor we're mapping over, the slice is created in `GetArg()`. We fix this by adding a mutable vector of slices to the specialized `CallFrame` implementation, storing the created tensor there, and returning a pointer to it. (Since `MapDefunOp` is only used in a graph-rewrite context, a better fix would be to add explicit `tf.slice()` ops to the graph instead of relying on the call frame to do this work, because those ops might be optimized further by Grappler.)

PiperOrigin-RevId: 305983898
Change-Id: I0834777c27cd97204e8e3df052a08faf0dcf68f9
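To illustrate the shape of the API change, here is a minimal, self-contained sketch. The `Tensor` and `CallFrame` types below are simplified stand-ins (not the real TensorFlow classes): copying a refcounted tensor out of the frame bumps the refcount and copies the shape, while handing back a pointer to the frame-owned tensor defers both costs to whichever caller actually needs a copy.

```cpp
#include <cstdint>
#include <cstdio>
#include <memory>
#include <utility>
#include <vector>

// Stand-in for tensorflow::Tensor (hypothetical, for illustration only):
// copying it bumps the refcount on the shared buffer and copies the shape.
struct Tensor {
  std::shared_ptr<std::vector<float>> buffer;  // refcounted storage
  std::vector<int64_t> shape;                  // copied whenever Tensor is copied
};

class CallFrame {
 public:
  explicit CallFrame(std::vector<Tensor> args) : args_(std::move(args)) {}

  // Old-style signature: copy-constructs into `*val` (refcount bump + shape
  // copy here, and again when the caller forwards the value to set_output).
  void GetArgByValue(int index, Tensor* val) const { *val = args_[index]; }

  // New-style signature: returns a pointer to the frame-owned tensor; no
  // refcount traffic or shape copy until the caller actually makes a copy.
  void GetArgByPointer(int index, const Tensor** val) { *val = &args_[index]; }

 private:
  std::vector<Tensor> args_;
};

int main() {
  CallFrame frame({Tensor{std::make_shared<std::vector<float>>(1024), {1024}}});

  Tensor by_value;
  frame.GetArgByValue(0, &by_value);
  std::printf("after copy-out:    refcount = %ld\n",
              by_value.buffer.use_count());  // 2: frame + local copy

  const Tensor* by_pointer = nullptr;
  frame.GetArgByPointer(0, &by_pointer);
  std::printf("after pointer-out: refcount = %ld\n",
              by_pointer->buffer.use_count());  // unchanged by this call
  return 0;
}
```

The trade-off is that a pointer-returning accessor pins the lifetime of the result to the frame itself, which is exactly why the `DT_RESOURCE` and `MapDefunOp` call sites in the diff below need somewhere persistent to store tensors they previously materialized on the fly inside `GetArg()`.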
commit fe19e92cea (parent 0a3f4e1533)
@@ -39,19 +39,19 @@ class XlaArgOp : public XlaOpKernel {
   // compilation. Use the usual implementation of _Arg.
   auto frame = ctx->call_frame();
   if (frame != nullptr) {
-    Tensor val;
+    const Tensor* val;
     OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val));
     // Types that cannot be copied using memcpy (like DT_STRING) are wrapped
     // in a DT_UINT8 and hence the type mismatches. Skip the test in such
     // cases. See XlaOpKernelContext::SetOutputExpression for details.
     if (DataTypeCanUseMemcpy(dtype_)) {
-      OP_REQUIRES(ctx, val.dtype() == dtype_,
+      OP_REQUIRES(ctx, val->dtype() == dtype_,
                   errors::InvalidArgument(
-                      "Type mismatch: actual ", DataTypeString(val.dtype()),
+                      "Type mismatch: actual ", DataTypeString(val->dtype()),
                       " vs. expect ", DataTypeString(dtype_)));
     }
     // Forwards the argument from the frame.
-    ctx->op_kernel_context()->set_output(0, val);
+    ctx->op_kernel_context()->set_output(0, *val);
     return;
   }
@@ -1867,14 +1867,11 @@ class DirectSession::RunCallableCallFrame : public CallFrameInterface {
     return executors_and_keys_->output_types.size();
   }

-  Status GetArg(int index, Tensor* val) const override {
-    if (index > feed_tensors_->size()) {
+  Status GetArg(int index, const Tensor** val) override {
+    if (TF_PREDICT_FALSE(index > feed_tensors_->size())) {
       return errors::Internal("Args index out of bounds: ", index);
-    } else if (executors_and_keys_->input_types[index] == DT_RESOURCE) {
-      TF_RETURN_IF_ERROR(
-          session_->ResourceHandleToInputTensor((*feed_tensors_)[index], val));
     } else {
-      *val = (*feed_tensors_)[index];
+      *val = &(*feed_tensors_)[index];
     }
     return Status::OK();
   }
@@ -1947,16 +1944,37 @@ class DirectSession::RunCallableCallFrame : public CallFrameInterface {
   }

   size_t input_size = 0;
+  bool any_resource_feeds = false;
   for (auto& tensor : feed_tensors) {
     input_size += tensor.AllocatedBytes();
+    any_resource_feeds = any_resource_feeds || tensor.dtype() == DT_RESOURCE;
   }
   metrics::RecordGraphInputTensors(input_size);

+  std::unique_ptr<std::vector<Tensor>> converted_feed_tensors;
+  const std::vector<Tensor>* actual_feed_tensors;
+
+  if (TF_PREDICT_FALSE(any_resource_feeds)) {
+    converted_feed_tensors = absl::make_unique<std::vector<Tensor>>();
+    converted_feed_tensors->reserve(feed_tensors.size());
+    for (const Tensor& t : feed_tensors) {
+      if (t.dtype() == DT_RESOURCE) {
+        converted_feed_tensors->emplace_back();
+        Tensor* tensor_from_handle = &converted_feed_tensors->back();
+        TF_RETURN_IF_ERROR(ResourceHandleToInputTensor(t, tensor_from_handle));
+      } else {
+        converted_feed_tensors->emplace_back(t);
+      }
+    }
+    actual_feed_tensors = converted_feed_tensors.get();
+  } else {
+    actual_feed_tensors = &feed_tensors;
+  }
+
   // A specialized CallFrame implementation that takes advantage of the
   // optimized RunCallable interface.
-  RunCallableCallFrame call_frame(this, executors_and_keys.get(), &feed_tensors,
-                                  fetch_tensors);
+  RunCallableCallFrame call_frame(this, executors_and_keys.get(),
+                                  actual_feed_tensors, fetch_tensors);

   if (LogMemory::IsEnabled()) {
     LogMemory::RecordStep(step_id, run_state_args.handle);
@@ -1434,9 +1434,9 @@ void ProcessFunctionLibraryRuntime::Run(
   std::vector<Tensor> args;
   args.reserve(frame->num_args());
   for (size_t i = 0; i < frame->num_args(); ++i) {
-    Tensor arg;
+    const Tensor* arg;
     Status s = frame->GetArg(i, &arg);
-    args.push_back(std::move(arg));
+    args.emplace_back(*arg);
    if (!s.ok()) {
      done(s);
    }
@@ -1118,12 +1118,12 @@ Status FunctionCallFrame::ConsumeRetvals(std::vector<Tensor>* rets,
   return Status::OK();
 }

-Status FunctionCallFrame::GetArg(int index, Tensor* val) const {
+Status FunctionCallFrame::GetArg(int index, const Tensor** val) {
   if (index < 0 || static_cast<size_t>(index) >= args_.size()) {
     return errors::InvalidArgument("GetArg ", index, " is not within [0, ",
                                    args_.size(), ")");
   }
-  *val = args_[index];
+  *val = &args_[index];
   return Status::OK();
 }
@@ -274,7 +274,7 @@ class CallFrameInterface {
   virtual size_t num_args() const = 0;
   virtual size_t num_retvals() const = 0;

-  virtual Status GetArg(int index, Tensor* val) const = 0;
+  virtual Status GetArg(int index, const Tensor** val) = 0;
   virtual Status SetRetval(int index, const Tensor& val) = 0;
 };
@@ -301,7 +301,7 @@ class FunctionCallFrame : public CallFrameInterface {
   size_t num_retvals() const override { return ret_types_.size(); }

   // Callee methods.
-  Status GetArg(int index, Tensor* val) const override;
+  Status GetArg(int index, const Tensor** val) override;
   Status SetRetval(int index, const Tensor& val) override;

  private:
@@ -912,9 +912,9 @@ TEST(FunctionCallFrame, Void_Void) {
   TF_EXPECT_OK(frame.SetArgs({}));
   auto a = test::AsTensor<float>({100});
   HasError(frame.SetArgs({a}), "Invalid argument");
-  Tensor v;
+  const Tensor* v;
   HasError(frame.GetArg(0, &v), "Invalid argument");
-  HasError(frame.SetRetval(0, v), "Invalid argument");
+  HasError(frame.SetRetval(0, *v), "Invalid argument");
   std::vector<Tensor> rets;
   TF_EXPECT_OK(frame.GetRetvals(&rets));
   EXPECT_EQ(rets.size(), 0);
@@ -930,28 +930,28 @@ TEST(FunctionCallFrame, Float_Float_Float) {
            "Invalid argument: Expects arg[1] to be float");
   TF_EXPECT_OK(frame.SetArgs({a, b}));

-  Tensor v;
+  const Tensor* v;
   HasError(frame.GetArg(-1, &v), "Invalid argument");
   HasError(frame.GetArg(2, &v), "Invalid argument");
   TF_EXPECT_OK(frame.GetArg(0, &v));
-  test::ExpectTensorEqual<float>(a, v);
+  test::ExpectTensorEqual<float>(a, *v);
   TF_EXPECT_OK(frame.GetArg(1, &v));
-  test::ExpectTensorEqual<float>(b, v);
+  test::ExpectTensorEqual<float>(b, *v);

-  v = test::AsTensor<float>({-100});
-  HasError(frame.SetRetval(-1, v), "Invalid argument");
-  HasError(frame.SetRetval(1, v), "Invalid argument");
+  Tensor w = test::AsTensor<float>({-100});
+  HasError(frame.SetRetval(-1, w), "Invalid argument");
+  HasError(frame.SetRetval(1, w), "Invalid argument");
   HasError(frame.SetRetval(0, test::AsTensor<int64>({-100})),
            "Invalid argument: Expects ret[0] to be float");

   std::vector<Tensor> rets;
   HasError(frame.GetRetvals(&rets), "does not have value");
-  TF_EXPECT_OK(frame.SetRetval(0, v));
-  HasError(frame.SetRetval(0, v), "has already been set");
+  TF_EXPECT_OK(frame.SetRetval(0, *v));
+  HasError(frame.SetRetval(0, *v), "has already been set");

   TF_EXPECT_OK(frame.GetRetvals(&rets));
   EXPECT_EQ(rets.size(), 1);
-  test::ExpectTensorEqual<float>(rets[0], v);
+  test::ExpectTensorEqual<float>(rets[0], *v);
 }

 TEST(Canonicalize, Basic) {
@@ -309,14 +309,12 @@ class OwnedArgsCallFrame : public CallFrameBase {
   }

   // Callee methods.
-  Status GetArg(int index, Tensor* val) const override {
+  Status GetArg(int index, const Tensor** val) override {
     if (index < args_.size()) {
-      // TODO(mrry): Consider making `CallFrameInterface::GetArg` non-const in
-      // order to be able to `std::move(args_[index])` into `*val`.
-      *val = args_[index];
+      *val = &args_[index];
       return Status::OK();
     } else if (index < args_.size() + captured_inputs_->size()) {
-      *val = (*captured_inputs_)[index - args_.size()];
+      *val = &(*captured_inputs_)[index - args_.size()];
       return Status::OK();
    } else {
      return errors::InvalidArgument("Argument ", index, " is out of range.");
@@ -342,12 +340,12 @@ class BorrowedArgsCallFrame : public CallFrameBase {
   }

   // Callee methods.
-  Status GetArg(int index, Tensor* val) const override {
+  Status GetArg(int index, const Tensor** val) override {
     if (index < args_.size()) {
-      *val = args_[index];
+      *val = &args_[index];
       return Status::OK();
     } else if (index < args_.size() + captured_inputs_->size()) {
-      *val = (*captured_inputs_)[index - args_.size()];
+      *val = &(*captured_inputs_)[index - args_.size()];
       return Status::OK();
    } else {
      return errors::InvalidArgument("Argument ", index, " is out of range.");
@@ -77,7 +77,10 @@ class MapDefunOp::MapFunctionCallFrame : public CallFrameInterface {
  public:
   MapFunctionCallFrame(ComputeOptions* compute_opts, OpKernel* kernel,
                        size_t iter)
-      : compute_opts_(compute_opts), kernel_(kernel), iter_(iter) {}
+      : compute_opts_(compute_opts),
+        kernel_(kernel),
+        iter_(iter),
+        sliced_args_(compute_opts_->args.size()) {}

   ~MapFunctionCallFrame() override = default;
@@ -87,7 +90,7 @@ class MapDefunOp::MapFunctionCallFrame : public CallFrameInterface {
     return static_cast<size_t>(kernel_->num_outputs());
   }

-  Status GetArg(int index, Tensor* val) const override {
+  Status GetArg(int index, const Tensor** val) override {
     if (index < 0 || index >= compute_opts_->args.size() +
                                   compute_opts_->captured_inputs.size()) {
       return errors::InvalidArgument("Mismatch in number of function inputs.");
@@ -95,19 +98,24 @@ class MapDefunOp::MapFunctionCallFrame : public CallFrameInterface {

     if (index >= compute_opts_->args.size()) {
       // The function is calling for a captured input
-      *val = compute_opts_->captured_inputs[index - compute_opts_->args.size()];
+      *val =
+          &compute_opts_->captured_inputs[index - compute_opts_->args.size()];
       return Status::OK();
     }

-    bool result =
-        val->CopyFrom(compute_opts_->args[index].Slice(iter_, iter_ + 1),
-                      compute_opts_->arg_shapes.at(index));
+    // NOTE: If contention on mu_ becomes problematic, we could create a vector
+    // of mutexes, each guarding a different element of sliced_args_.
+    mutex_lock l(mu_);
+    bool result = sliced_args_[index].CopyFrom(
+        compute_opts_->args[index].Slice(iter_, iter_ + 1),
+        compute_opts_->arg_shapes.at(index));
     if (!result) {
       return errors::Internal("GetArg failed.");
-    } else if (!val->IsAligned()) {
+    } else if (!sliced_args_[index].IsAligned()) {
       // Ensure alignment
-      *val = tensor::DeepCopy(*val);
+      sliced_args_[index] = tensor::DeepCopy(sliced_args_[index]);
     }
+    *val = &sliced_args_[index];
     return Status::OK();
   }
@@ -152,6 +160,8 @@ class MapDefunOp::MapFunctionCallFrame : public CallFrameInterface {
   ComputeOptions* const compute_opts_;  // Not owned
   const OpKernel* kernel_;
   const size_t iter_;
+  mutex mu_;
+  std::vector<Tensor> sliced_args_ TF_GUARDED_BY(mu_);
 };

 MapDefunOp::MapDefunOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
@@ -284,14 +284,11 @@ class SingleThreadedExecutorImpl : public Executor {
     for (size_t i = 0; i < arg_output_locations_.size(); ++i) {
       const size_t num_destinations = arg_output_locations_[i].size();
       if (num_destinations > 0) {
-        Tensor arg;
+        const Tensor* arg;
         TF_CHECK_OK(args.call_frame->GetArg(i, &arg));
-        for (size_t j = 0; j < num_destinations - 1; ++j) {
-          inputs[arg_output_locations_[i][j]].Init(arg);
+        for (size_t j = 0; j < num_destinations; ++j) {
+          inputs[arg_output_locations_[i][j]].Init(*arg);
         }
-        // Move `arg` to the last consumer to avoid the cost of copying it.
-        inputs[arg_output_locations_[i][num_destinations - 1]].Init(
-            std::move(arg));
       }
     }
@@ -44,13 +44,13 @@ ArgOp::ArgOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
 void ArgOp::Compute(OpKernelContext* ctx) {
   auto frame = ctx->call_frame();
   OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
-  Tensor val;
+  const Tensor* val;
   OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val));
-  OP_REQUIRES(ctx, val.dtype() == dtype_,
+  OP_REQUIRES(ctx, val->dtype() == dtype_,
               errors::InvalidArgument("Type mismatch: actual ",
-                                      DataTypeString(val.dtype()),
+                                      DataTypeString(val->dtype()),
                                       " vs. expect ", DataTypeString(dtype_)));
-  ctx->set_output(0, std::move(val));
+  ctx->set_output(0, *val);
 }

 RetvalOp::RetvalOp(OpKernelConstruction* ctx) : OpKernel(ctx) {