Extend XRTExecute to accept vector-packed argument handles.

Rather than a list of scalars, the op now accepts a list of scalars-or-vectors, whose concatenation is treated as the argument list of input handles.

PiperOrigin-RevId: 217796092
Roy Frostig 2018-10-18 17:11:06 -07:00 committed by TensorFlower Gardener
parent 11c4c33f51
commit fda1b4a91f
3 changed files with 85 additions and 6 deletions
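As a usage sketch (hypothetical handle names h0, h1, h2; the pattern mirrors the new test below), a client may now mix scalar handles with a rank-1 packed vector of handles, and the op treats their concatenation as the argument list:

  // Sketch only: h0, h1, h2 are assumed int64 allocation handles from
  // ops::XRTAllocate; the executed computation sees the arguments [h0, h1, h2].
  auto packed = ops::Stack(root, {Output(h1), Output(h2)});  // rank-1 int64
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(h0), Output(packed)});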
tensorflow/compiler/xrt


@@ -70,14 +70,30 @@ Status GetComputationInputs(OpKernelContext* context, ResourceMgr* rm,
                             std::vector<XRTTupleAllocation*>* input_tuples,
                             std::vector<xla::ShapedBuffer>* input_allocations,
                             std::vector<xla::ShapedBuffer*>* input_pointers) {
+  std::vector<int64> input_uids;
   OpInputList arg_list;
   TF_RETURN_IF_ERROR(context->input_list("input_handles", &arg_list));
 
-  input_tuples->resize(arg_list.size());
-  input_pointers->resize(arg_list.size());
+  // Concatenate all input uids from list of scalars-or-vectors carrying them.
   for (int i = 0; i < arg_list.size(); ++i) {
-    TF_RET_CHECK(TensorShapeUtils::IsScalar(arg_list[i].shape()));
-    int64 input_uid = arg_list[i].scalar<int64>()();
+    const Tensor& arg = arg_list[i];
+    if (TensorShapeUtils::IsScalar(arg.shape())) {
+      input_uids.push_back(arg.scalar<int64>()());
+    } else {
+      TF_RET_CHECK(TensorShapeUtils::IsVector(arg.shape()));
+      auto arg_vec = arg.vec<int64>();
+      const int64 num_elts = arg.shape().dim_size(0);
+      for (int i = 0; i < num_elts; ++i) {
+        input_uids.push_back(arg_vec(i));
+      }
+    }
+  }
+
+  // Retrieve allocations for the uids.
+  input_tuples->resize(input_uids.size());
+  input_pointers->resize(input_uids.size());
+  for (int i = 0; i < input_uids.size(); ++i) {
+    const int64 input_uid = input_uids[i];
     TF_RETURN_IF_ERROR(
         XRTTupleAllocation::Lookup(rm, input_uid, &(*input_tuples)[i]));
     if (release_inputs) {
@@ -90,7 +106,7 @@ Status GetComputationInputs(OpKernelContext* context, ResourceMgr* rm,
     XRTTupleAllocation* tuple = (*input_tuples)[i];
     input_allocations->emplace_back(tuple->ToShapedBuffer());
   }
-  for (int i = 0; i < arg_list.size(); ++i) {
+  for (int i = 0; i < input_uids.size(); ++i) {
     (*input_pointers)[i] = &(*input_allocations)[i];
   }
   return Status::OK();
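
Viewed in isolation, the flattening rule the kernel now applies is plain ordered concatenation. A minimal standalone sketch (FlattenHandles is a hypothetical helper for illustration, not the kernel code):

  #include <cstdint>
  #include <vector>

  // Each argument carries one uid (scalar) or a packed run of uids (vector);
  // the effective argument list is their concatenation, in order.
  std::vector<int64_t> FlattenHandles(
      const std::vector<std::vector<int64_t>>& args) {
    std::vector<int64_t> uids;
    for (const auto& arg : args) {
      uids.insert(uids.end(), arg.begin(), arg.end());
    }
    return uids;
  }

  // FlattenHandles({{7}, {8, 9}, {10}}) yields {7, 8, 9, 10}.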


@@ -26,7 +26,16 @@ REGISTER_OP("XRTExecute")
     .Input("execution_config: string")
     .Input("input_handles: Ninputs * int64")
     .Output("output_handle: int64")
-    .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      std::vector<shape_inference::ShapeHandle> input_handle_shapes;
+      TF_RETURN_IF_ERROR(c->input("input_handles", &input_handle_shapes));
+      for (size_t i = 0; i < input_handle_shapes.size(); ++i) {
+        shape_inference::ShapeHandle unused;
+        TF_RETURN_IF_ERROR(
+            c->WithRankAtMost(input_handle_shapes[i], 1, &unused));
+      }
+      return tensorflow::shape_inference::ScalarShape(c);
+    })
     .Doc(
         R"(
 Runs a previously-compiled computation on a core. If
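
The new shape function relaxes the op's previous scalar-only contract: WithRankAtMost(shape, 1) admits each input_handles element at rank 0 (a single handle) or rank 1 (a packed vector of handles), rejects rank 2 and above, and the output handle remains a scalar as before.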


@@ -416,6 +416,60 @@ TEST(RawApiTest, CompileAndExecute) {
   EXPECT_EQ(program_shape.parameters_size(), 2);
 }
 
+TEST(RawApiTest, CompileAndExecuteWithArgumentVector) {
+  xrt::XLAAllocation p0;
+  p0.set_device_ordinal(0);
+  *p0.mutable_value() = FloatVector({1.0f, 2.0f});
+  xrt::XLAAllocation p1;
+  p1.set_device_ordinal(0);
+  *p1.mutable_value() = FloatVector({8.0f, 5.0f});
+
+  xrt::XLAComputation c;
+  auto config = c.mutable_config();
+  auto shapes = config->mutable_program_shape();
+  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
+  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
+  *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {2});
+  StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot());
+
+  xrt::XRTExecutionConfig e;
+  e.set_release_input_handles(true);
+  e.set_release_compilation_handle(true);
+
+  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+  auto e_config =
+      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
+  auto computation =
+      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
+  auto c_handle = ops::XRTCompile(root, computation);
+  auto p0_value =
+      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
+  auto p0_handle = ops::XRTAllocate(root, p0_value);
+  auto p1_value =
+      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
+  auto p1_handle = ops::XRTAllocate(root, p1_value);
+  auto packed_args = ops::Stack(root.WithDevice("/device:CPU:0"),
+                                {Output(p0_handle), Output(p1_handle)});
+  auto result =
+      ops::XRTExecute(root, c_handle.handle, e_config, {Output(packed_args)});
+  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
+  TF_ASSERT_OK(root.status());
+
+  ClientSession session(root);
+  std::vector<Tensor> outputs;
+  TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));
+
+  xla::LiteralProto response;
+  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+
+  auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
+  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
+
+  xla::ProgramShape program_shape;
+  EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
+  EXPECT_EQ(program_shape.parameters_size(), 2);
+}
+
 TEST(RawApiTest, CompileWithXlaReturnShapes) {
   xla::XlaBuilder builder("XrtXlaShapes");
   auto input_shape = xla::ShapeUtil::MakeShape(xla::BF16, {32, 3, 128, 128});
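
For reference, the expected literal {27.0f, 21.0f} in the new test is consistent with AddAndScale computing (p0 + p1) * 3, as in the existing CompileAndExecute test: (1 + 8) * 3 = 27 and (2 + 5) * 3 = 21.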