[ArgOp] Use CanConsumeArg()
and ConsumeArg()
when possible.
This change allows any graph containing an `ArgOp` to consume its arguments (i.e. avoid holding onto a reference to the arguments in the call frame), which enables buffer forwarding on the argument tensors. See the implementation of `WhileOp` for an example of when this is used. PiperOrigin-RevId: 310164509 Change-Id: I72952b33d5e99f93084c9bd95a6394f33f86fc98
This commit is contained in:
parent
5bb727ee34
commit
abbeddb86f
@ -434,7 +434,33 @@ class ConsumeArgumentCallFrame : public CallFrameInterface {
|
||||
Tensor* const retval_;
|
||||
};
|
||||
|
||||
TEST_F(FunctionLibraryRuntimeTest, XTimesTwo_ConsumeArgument) {
|
||||
TEST_F(FunctionLibraryRuntimeTest, XTimesTwo_ConsumeArgument_DefaultExecutor) {
|
||||
Init({test::function::XTimesTwo()});
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
TF_CHECK_OK(flr0_->Instantiate(
|
||||
"XTimesTwo", test::function::Attrs({{"T", DT_FLOAT}}), &handle));
|
||||
|
||||
auto x = test::AsTensor<float>({1, 2, 3, 4});
|
||||
float* x_base_ptr = &x.flat<float>()(0);
|
||||
Tensor y;
|
||||
ConsumeArgumentCallFrame frame(&x, &y);
|
||||
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
TF_CHECK_OK(Run(flr0_, handle, opts, &frame));
|
||||
|
||||
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
|
||||
|
||||
// Expect that the buffer for `x` has been forwarded to and used as the buffer
|
||||
// for `y`.
|
||||
float* y_base_ptr = &y.flat<float>()(0);
|
||||
EXPECT_EQ(x_base_ptr, y_base_ptr);
|
||||
EXPECT_FALSE(x.IsInitialized());
|
||||
|
||||
TF_CHECK_OK(flr0_->ReleaseHandle(handle));
|
||||
}
|
||||
|
||||
TEST_F(FunctionLibraryRuntimeTest,
|
||||
XTimesTwo_ConsumeArgument_SingleThreadedExecutor) {
|
||||
Init({test::function::XTimesTwo()});
|
||||
FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
|
||||
instantiate_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
|
||||
|
@ -45,12 +45,27 @@ void ArgOp::Compute(OpKernelContext* ctx) {
|
||||
auto frame = ctx->call_frame();
|
||||
OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
|
||||
const Tensor* val;
|
||||
OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val));
|
||||
OP_REQUIRES(ctx, val->dtype() == dtype_,
|
||||
errors::InvalidArgument("Type mismatch: actual ",
|
||||
DataTypeString(val->dtype()),
|
||||
" vs. expect ", DataTypeString(dtype_)));
|
||||
ctx->set_output(0, *val);
|
||||
|
||||
auto validate_type = [this](const Tensor& val) {
|
||||
if (val.dtype() == dtype_) {
|
||||
return Status::OK();
|
||||
} else {
|
||||
return errors::InvalidArgument("Type mismatch: actual ",
|
||||
DataTypeString(val.dtype()),
|
||||
" vs. expect ", DataTypeString(dtype_));
|
||||
}
|
||||
};
|
||||
|
||||
if (frame->CanConsumeArg(index_)) {
|
||||
Tensor val;
|
||||
frame->ConsumeArg(index_, &val);
|
||||
OP_REQUIRES_OK(ctx, validate_type(val));
|
||||
ctx->set_output(0, std::move(val));
|
||||
} else {
|
||||
OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val));
|
||||
OP_REQUIRES_OK(ctx, validate_type(*val));
|
||||
ctx->set_output(0, *val);
|
||||
}
|
||||
}
|
||||
|
||||
RetvalOp::RetvalOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
|
Loading…
Reference in New Issue
Block a user