Test the (general) dot operation via XRT with non-standard layouts, but disable it for now.
PiperOrigin-RevId: 217568591
This commit is contained in:
parent
cf2644ad05
commit
80ca171714
@ -89,6 +89,13 @@ xla::LiteralProto FloatVector(absl::Span<const float> v) {
|
||||
return array.ToProto();
|
||||
}
|
||||
|
||||
xla::LiteralProto FloatMatrix(
|
||||
std::initializer_list<std::initializer_list<float>> v,
|
||||
const xla::Layout& layout) {
|
||||
auto array = xla::LiteralUtil::CreateR2WithLayout<float>(v, layout);
|
||||
return array.ToProto();
|
||||
}
|
||||
|
||||
bool CompareLiteralProtos(const xla::LiteralProto& a,
|
||||
const xla::LiteralProto& b) {
|
||||
auto l_a = xla::Literal::CreateFromProto(a).ValueOrDie();
|
||||
@ -132,6 +139,21 @@ xla::XlaComputation AddAndScale() {
|
||||
return builder.Build().ValueOrDie();
|
||||
}
|
||||
|
||||
xla::XlaComputation Dot() {
|
||||
xla::XlaBuilder builder("Dot");
|
||||
auto p0 = xla::Parameter(
|
||||
&builder, 0,
|
||||
xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1}), "P0");
|
||||
auto p1 = xla::Parameter(
|
||||
&builder, 1,
|
||||
xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}), "P1");
|
||||
xla::DotDimensionNumbers ddn;
|
||||
ddn.add_lhs_contracting_dimensions(1);
|
||||
ddn.add_rhs_contracting_dimensions(0);
|
||||
xla::DotGeneral(p0, p1, ddn);
|
||||
return builder.Build().ValueOrDie();
|
||||
}
|
||||
|
||||
xla::XlaComputation AddS64() {
|
||||
xla::XlaBuilder builder("AddS64");
|
||||
auto p0 = xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(xla::S64, {}),
|
||||
@ -457,6 +479,62 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) {
|
||||
xla_program_shape.result().layout()));
|
||||
}
|
||||
|
||||
// Disabled because of failure on TPU (b/117876141)
|
||||
TEST(RawApiTest, DISABLED_DotGeneralWithLayoutTest) {
|
||||
auto layout = xla::LayoutUtil::MakeLayout({0, 1});
|
||||
|
||||
xrt::XLAAllocation p0;
|
||||
p0.set_device_ordinal(0);
|
||||
*p0.mutable_value() = FloatMatrix({{1.0f, 2.0f}, {3.0f, 4.0f}}, layout);
|
||||
xrt::XLAAllocation p1;
|
||||
p1.set_device_ordinal(0);
|
||||
*p1.mutable_value() = FloatMatrix({{8.0f}, {5.0f}}, layout);
|
||||
|
||||
xrt::XLAComputation c;
|
||||
auto config = c.mutable_config();
|
||||
auto shapes = config->mutable_program_shape();
|
||||
*shapes->add_parameters() =
|
||||
xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1});
|
||||
*shapes->add_parameters() =
|
||||
xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1});
|
||||
*shapes->mutable_result() =
|
||||
xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1});
|
||||
StoreComputationSnapshot(Dot(), c.mutable_hlo_snapshot());
|
||||
|
||||
xrt::XRTExecutionConfig e;
|
||||
e.set_release_input_handles(true);
|
||||
e.set_release_compilation_handle(true);
|
||||
|
||||
Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
|
||||
auto e_config =
|
||||
ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
|
||||
auto computation =
|
||||
ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
|
||||
auto c_handle = ops::XRTCompile(root, computation);
|
||||
auto p0_value =
|
||||
ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
|
||||
auto p0_handle = ops::XRTAllocate(root, p0_value);
|
||||
auto p1_value =
|
||||
ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
|
||||
auto p1_handle = ops::XRTAllocate(root, p1_value);
|
||||
auto result = ops::XRTExecute(root, c_handle.handle, e_config,
|
||||
{Output(p0_handle), Output(p1_handle)});
|
||||
auto read_back = ops::XRTReadLiteralAndRelease(root, result);
|
||||
TF_ASSERT_OK(root.status());
|
||||
|
||||
ClientSession session(root);
|
||||
std::vector<Tensor> outputs;
|
||||
TF_EXPECT_OK(session.Run({read_back}, &outputs));
|
||||
|
||||
xla::LiteralProto response;
|
||||
EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
|
||||
|
||||
auto expected =
|
||||
xla::LiteralUtil::CreateR2WithLayout<float>({{18.0f}, {44.0f}}, layout);
|
||||
|
||||
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
|
||||
}
|
||||
|
||||
TEST(RawApiTest, CompileAndExecuteZeroArg) {
|
||||
xrt::XLAComputation c;
|
||||
auto config = c.mutable_config();
|
||||
|
Loading…
Reference in New Issue
Block a user