Test the (general) dot operation via XRT with non-standard layouts, but disable it for now.

PiperOrigin-RevId: 217568591
Roy Frostig 2018-10-17 12:37:49 -07:00 committed by TensorFlower Gardener
parent cf2644ad05
commit 80ca171714

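For context, a sketch that is not part of this change: a layout's minor_to_major list fixes how a rank-2 literal linearizes in memory, which is exactly what the new test varies. Both calls below use only helpers already present in the diff.

  // With the default layout {1, 0}, [[1, 2], [3, 4]] linearizes as 1, 2, 3, 4
  // (row major); with {0, 1}, as used in this change, it linearizes as
  // 1, 3, 2, 4 (column major).
  auto row_major = xla::LiteralUtil::CreateR2WithLayout<float>(
      {{1.0f, 2.0f}, {3.0f, 4.0f}}, xla::LayoutUtil::MakeLayout({1, 0}));
  auto col_major = xla::LiteralUtil::CreateR2WithLayout<float>(
      {{1.0f, 2.0f}, {3.0f, 4.0f}}, xla::LayoutUtil::MakeLayout({0, 1}));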

@@ -89,6 +89,13 @@ xla::LiteralProto FloatVector(absl::Span<const float> v) {
  return array.ToProto();
}

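// Builds a rank-2 F32 literal with an explicit layout. The layout's
// minor_to_major list orders dimensions innermost first, so {0, 1} stores the
// matrix column major (XLA's default for rank 2 is {1, 0}, row major).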
xla::LiteralProto FloatMatrix(
    std::initializer_list<std::initializer_list<float>> v,
    const xla::Layout& layout) {
  auto array = xla::LiteralUtil::CreateR2WithLayout<float>(v, layout);
  return array.ToProto();
}

bool CompareLiteralProtos(const xla::LiteralProto& a,
                          const xla::LiteralProto& b) {
  auto l_a = xla::Literal::CreateFromProto(a).ValueOrDie();
@@ -132,6 +139,21 @@ xla::XlaComputation AddAndScale() {
  return builder.Build().ValueOrDie();
}

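// Multiplies a 2x2 matrix by a 2x1 matrix. Contracting lhs dimension 1
// against rhs dimension 0 makes DotGeneral compute the ordinary matrix
// product: result[i][j] = sum over k of lhs[i][k] * rhs[k][j].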
xla::XlaComputation Dot() {
  xla::XlaBuilder builder("Dot");
  auto p0 = xla::Parameter(
      &builder, 0,
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1}), "P0");
  auto p1 = xla::Parameter(
      &builder, 1,
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}), "P1");
  xla::DotDimensionNumbers ddn;
  ddn.add_lhs_contracting_dimensions(1);
  ddn.add_rhs_contracting_dimensions(0);
  xla::DotGeneral(p0, p1, ddn);
  return builder.Build().ValueOrDie();
}

xla::XlaComputation AddS64() {
  xla::XlaBuilder builder("AddS64");
  auto p0 = xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(xla::S64, {}),
@@ -457,6 +479,62 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) {
      xla_program_shape.result().layout()));
}

// Disabled because of failure on TPU (b/117876141)
TEST(RawApiTest, DISABLED_DotGeneralWithLayoutTest) {
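  // Use the non-default {0, 1} (column major) layout for every literal and
  // shape in this test; the rank-2 default would be {1, 0}.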
  auto layout = xla::LayoutUtil::MakeLayout({0, 1});

  xrt::XLAAllocation p0;
  p0.set_device_ordinal(0);
  *p0.mutable_value() = FloatMatrix({{1.0f, 2.0f}, {3.0f, 4.0f}}, layout);

  xrt::XLAAllocation p1;
  p1.set_device_ordinal(0);
  *p1.mutable_value() = FloatMatrix({{8.0f}, {5.0f}}, layout);
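
  // The program shape pins the {0, 1} layout on both parameters and on the
  // result, matching the parameter shapes declared in Dot() above.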
  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1});
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1});
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1});
  StoreComputationSnapshot(Dot(), c.mutable_hlo_snapshot());
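
  // Graph-level flow: XRTAllocate uploads each serialized literal and returns
  // a handle, XRTCompile compiles the serialized computation, XRTExecute runs
  // it over the allocation handles, and XRTReadLiteralAndRelease reads the
  // result back as a literal proto.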
  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(p0_handle), Output(p1_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
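
  // Expected: [[1, 2], [3, 4]] x [[8], [5]] = [[1*8 + 2*5], [3*8 + 4*5]]
  //         = [[18], [44]].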
  auto expected =
      xla::LiteralUtil::CreateR2WithLayout<float>({{18.0f}, {44.0f}}, layout);
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}

TEST(RawApiTest, CompileAndExecuteZeroArg) {
  xrt::XLAComputation c;
  auto config = c.mutable_config();