Extend XRTExecute to accept vector-packed argument handles.
Rather than a list of scalars, the op now accepts a list of scalars-or-vectors, whose concatenation is treated as the argument list of input handles.

PiperOrigin-RevId: 217796092
parent 11c4c33f51
commit fda1b4a91f
tensorflow/compiler/xrt
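For intuition, here is a minimal standalone sketch of the packing semantics described above. FlattenHandleArgs and the handle values are illustrative only and not part of the change; a scalar handle is modeled as a one-element vector.

#include <cstdint>
#include <vector>

// Each argument carries one handle or a vector of handles; XRTExecute
// treats their concatenation, in order, as the argument list.
std::vector<int64_t> FlattenHandleArgs(
    const std::vector<std::vector<int64_t>>& args) {
  std::vector<int64_t> uids;
  for (const auto& arg : args) {
    uids.insert(uids.end(), arg.begin(), arg.end());
  }
  return uids;
}

// Example: FlattenHandleArgs({{12}, {34, 56}, {78}}) -> {12, 34, 56, 78}.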
@@ -70,14 +70,30 @@ Status GetComputationInputs(OpKernelContext* context, ResourceMgr* rm,
                             std::vector<XRTTupleAllocation*>* input_tuples,
                             std::vector<xla::ShapedBuffer>* input_allocations,
                             std::vector<xla::ShapedBuffer*>* input_pointers) {
+  std::vector<int64> input_uids;
   OpInputList arg_list;
   TF_RETURN_IF_ERROR(context->input_list("input_handles", &arg_list));
 
-  input_tuples->resize(arg_list.size());
-  input_pointers->resize(arg_list.size());
+  // Concatenate all input uids from list of scalars-or-vectors carrying them.
   for (int i = 0; i < arg_list.size(); ++i) {
-    TF_RET_CHECK(TensorShapeUtils::IsScalar(arg_list[i].shape()));
-    int64 input_uid = arg_list[i].scalar<int64>()();
+    const Tensor& arg = arg_list[i];
+    if (TensorShapeUtils::IsScalar(arg.shape())) {
+      input_uids.push_back(arg.scalar<int64>()());
+    } else {
+      TF_RET_CHECK(TensorShapeUtils::IsVector(arg.shape()));
+      auto arg_vec = arg.vec<int64>();
+      const int64 num_elts = arg.shape().dim_size(0);
+      for (int i = 0; i < num_elts; ++i) {
+        input_uids.push_back(arg_vec(i));
+      }
+    }
+  }
+
+  // Retrieve allocations for the uids.
+  input_tuples->resize(input_uids.size());
+  input_pointers->resize(input_uids.size());
+  for (int i = 0; i < input_uids.size(); ++i) {
+    const int64 input_uid = input_uids[i];
     TF_RETURN_IF_ERROR(
         XRTTupleAllocation::Lookup(rm, input_uid, &(*input_tuples)[i]));
     if (release_inputs) {

@@ -90,7 +106,7 @@ Status GetComputationInputs(OpKernelContext* context, ResourceMgr* rm,
     XRTTupleAllocation* tuple = (*input_tuples)[i];
     input_allocations->emplace_back(tuple->ToShapedBuffer());
   }
-  for (int i = 0; i < arg_list.size(); ++i) {
+  for (int i = 0; i < input_uids.size(); ++i) {
    (*input_pointers)[i] = &(*input_allocations)[i];
   }
   return Status::OK();
@@ -26,7 +26,16 @@ REGISTER_OP("XRTExecute")
     .Input("execution_config: string")
     .Input("input_handles: Ninputs * int64")
     .Output("output_handle: int64")
-    .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      std::vector<shape_inference::ShapeHandle> input_handle_shapes;
+      TF_RETURN_IF_ERROR(c->input("input_handles", &input_handle_shapes));
+      for (size_t i = 0; i < input_handle_shapes.size(); ++i) {
+        shape_inference::ShapeHandle unused;
+        TF_RETURN_IF_ERROR(
+            c->WithRankAtMost(input_handle_shapes[i], 1, &unused));
+      }
+      return tensorflow::shape_inference::ScalarShape(c);
+    })
     .Doc(
         R"(
 Runs a previously-compiled computation on a core. If
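Because the shape function now only requires each handle input to have rank at most 1 (scalar or vector), a caller can mix plain handles and Stack-packed vectors in a single call. A hedged fragment, modeled on the test below; root, c_handle, e_config, and the p0_handle/p1_handle/p2_handle ops are assumed to be set up as in that test (p2_handle is hypothetical):

// Sketch only: one scalar handle plus one packed pair of handles; the op
// flattens them into the three-argument list {p0, p1, p2}.
auto packed = ops::Stack(root.WithDevice("/device:CPU:0"),
                         {Output(p1_handle), Output(p2_handle)});
auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                              {Output(p0_handle), Output(packed)});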
@@ -416,6 +416,60 @@ TEST(RawApiTest, CompileAndExecute) {
   EXPECT_EQ(program_shape.parameters_size(), 2);
 }
 
+TEST(RawApiTest, CompileAndExecuteWithArgumentVector) {
+  xrt::XLAAllocation p0;
+  p0.set_device_ordinal(0);
+  *p0.mutable_value() = FloatVector({1.0f, 2.0f});
+  xrt::XLAAllocation p1;
+  p1.set_device_ordinal(0);
+  *p1.mutable_value() = FloatVector({8.0f, 5.0f});
+
+  xrt::XLAComputation c;
+  auto config = c.mutable_config();
+  auto shapes = config->mutable_program_shape();
+  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
+  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
+  *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {2});
+  StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot());
+
+  xrt::XRTExecutionConfig e;
+  e.set_release_input_handles(true);
+  e.set_release_compilation_handle(true);
+
+  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+  auto e_config =
+      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
+  auto computation =
+      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
+  auto c_handle = ops::XRTCompile(root, computation);
+  auto p0_value =
+      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
+  auto p0_handle = ops::XRTAllocate(root, p0_value);
+  auto p1_value =
+      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
+  auto p1_handle = ops::XRTAllocate(root, p1_value);
+  auto packed_args = ops::Stack(root.WithDevice("/device:CPU:0"),
+                                {Output(p0_handle), Output(p1_handle)});
+  auto result =
+      ops::XRTExecute(root, c_handle.handle, e_config, {Output(packed_args)});
+  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
+  TF_ASSERT_OK(root.status());
+
+  ClientSession session(root);
+  std::vector<Tensor> outputs;
+  TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));
+
+  xla::LiteralProto response;
+  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+
+  auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
+  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
+
+  xla::ProgramShape program_shape;
+  EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
+  EXPECT_EQ(program_shape.parameters_size(), 2);
+}
+
 TEST(RawApiTest, CompileWithXlaReturnShapes) {
   xla::XlaBuilder builder("XrtXlaShapes");
   auto input_shape = xla::ShapeUtil::MakeShape(xla::BF16, {32, 3, 128, 128});
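As a sanity check on the expected literal, the test's expected output is consistent with AddAndScale (not shown in this diff) computing an elementwise (p0 + p1) * 3; the standalone snippet below reproduces that arithmetic under this assumption:

#include <cassert>
#include <vector>

int main() {
  // Assuming AddAndScale computes (p0 + p1) * 3 elementwise:
  // ({1, 2} + {8, 5}) * 3 = {9, 7} * 3 = {27, 21}.
  std::vector<float> p0 = {1.0f, 2.0f}, p1 = {8.0f, 5.0f};
  std::vector<float> expected;
  for (size_t i = 0; i < p0.size(); ++i) {
    expected.push_back((p0[i] + p1[i]) * 3.0f);
  }
  assert(expected[0] == 27.0f && expected[1] == 21.0f);
  return 0;
}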