From bb354b5f649dc39b005c4a85d72c1ea79d00ad07 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 2 Apr 2020 10:40:52 -0700 Subject: [PATCH] [xla] Expose XlaComputation constructor in the Python API PiperOrigin-RevId: 304430744 Change-Id: Iaef4f9e749b226c821aafcae440cdfc590828473 --- tensorflow/compiler/xla/python/xla.cc | 6 ++++++ tensorflow/compiler/xla/python/xla_client_test.py | 15 +++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index 2affd4b30fa..0e6a3c8ac17 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -1206,6 +1206,12 @@ PYBIND11_MODULE(xla_extension, m) { py::return_value_policy::reference, py::keep_alive<1, 0>()); py::class_(m, "XlaComputation") + .def(py::init([](const py::bytes& serialized_hlo_module_proto) + -> std::unique_ptr { + HloModuleProto proto; + proto.ParseFromString(serialized_hlo_module_proto); + return absl::make_unique(proto); + })) .def("GetProgramShape", &XlaComputation::GetProgramShape) .def("GetSerializedProto", &GetComputationSerializedProto) .def("GetHloText", &GetComputationHloText) diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 36d5da2841b..95b760965d8 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -381,6 +381,21 @@ class ComputationsWithConstantsTest(ComputationTest): self._ExecuteAndCompareClose(c, expected=[0.75]) +class ComputationFromProtoTest(absltest.TestCase): + """Test computation execution from HLO proto.""" + + def testExecuteFromProto(self): + # Build the HLO proto + b = xla_client.ComputationBuilder("computation") + b.Add(b.Constant(np.int8(1)), b.Constant(np.int8(2))) + serialized_proto = b.Build().GetSerializedProto() + + # Load and execute the proto + c = xla_client.Computation(xla_client._xla.XlaComputation(serialized_proto)) + ans, = xla_client.execute_with_python_values(c.Compile()) + np.testing.assert_equal(ans, np.int8(3)) + + class ParametersTest(ComputationTest): """Tests focusing on Parameter ops and argument-passing."""