[XLA:Python] Remove the xla_python:: namespace. Refactoring only; no functional changes.

Rename LocalShapedBuffer to xla::PyLocalBuffer to more clearly indicate that it is a Python binding class. Move some XlaComputation helpers into an anonymous namespace inside xla.cc so they cannot collide with symbols outside the Python bindings; they didn't really fit in local_client.{cc,h} anyway.

PiperOrigin-RevId: 244392801
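Background on the collision point: names in an unnamed (anonymous) namespace get internal linkage, so they cannot clash with identically named symbols in other translation units. A minimal self-contained sketch of that pattern, with illustrative names only (not code from this commit):

#include <iostream>
#include <string>

namespace xla {
namespace {

// Internal linkage: this helper is invisible outside this translation unit,
// so a same-named function elsewhere in the binary cannot collide with it.
std::string BindingsOnlyHelper() { return "visible only in this .cc file"; }

}  // namespace

// Externally visible function that uses the internal helper.
void UseHelper() { std::cout << BindingsOnlyHelper() << std::endl; }

}  // namespace xla

int main() {
  xla::UseHelper();
  return 0;
}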
commit 1dca7db621
parent 96ce8df98a
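For orientation before the diff: pybind11 registers a C++ type together with the name Python sees, so renaming the class means updating both the C++ identifier and the string passed to py::class_. A minimal compilable sketch of that binding pattern, using a simplified stand-in class and a hypothetical module name (the real module is xla_extension):

#include <pybind11/pybind11.h>

namespace py = pybind11;

// Simplified stand-in for the real class, which wraps a ScopedShapedBuffer
// and a PyLocalClient*; reduced here to the minimum needed to bind it.
class PyLocalBuffer {
 public:
  int device_ordinal() const { return device_ordinal_; }

 private:
  int device_ordinal_ = 0;
};

// The string passed to py::class_ is the Python-visible name; the rename in
// this commit updates the C++ type and that string together.
PYBIND11_MODULE(xla_extension_sketch, m) {
  py::class_<PyLocalBuffer>(m, "PyLocalBuffer")
      .def(py::init<>())
      .def("device_ordinal", &PyLocalBuffer::device_ordinal);
}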
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -145,6 +145,7 @@ tf_pybind_extension(
         "//tensorflow/compiler/xla/client/lib:self_adjoint_eig",
         "//tensorflow/compiler/xla/client/lib:svd",
         "//tensorflow/compiler/xla/service:computation_placer",
+        "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_graph_dumper",
         "//tensorflow/compiler/xla/service:name_uniquer",
         "//tensorflow/compiler/xla/service:platform_util",
--- a/tensorflow/compiler/xla/python/local_client.cc
+++ b/tensorflow/compiler/xla/python/local_client.cc
@@ -32,7 +32,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/python/types.h"
 #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
-#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
 #include "tensorflow/compiler/xla/service/platform_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/util.h"
@@ -41,7 +40,6 @@ limitations under the License.
 #include "tensorflow/core/profiler/lib/traceme.h"
 
 namespace xla {
-namespace xla_python {
 
 namespace py = pybind11;
 
@@ -104,7 +102,7 @@ StatusOr<pybind11::object> PyLocalClient::TransferFromOutfeed(
   return LiteralToPython(absl::make_unique<Literal>(std::move(literal)));
 }
 
-static StatusOr<LocalShapedBuffer> TransferHostToDeviceAsync(
+static StatusOr<PyLocalBuffer> TransferHostToDeviceAsync(
     const PythonBufferTree& tree, int device_ordinal, PyLocalClient* client,
     se::Stream* stream) {
   DeviceMemoryAllocator* allocator =
@@ -132,37 +130,38 @@ static StatusOr<LocalShapedBuffer> TransferHostToDeviceAsync(
         transfer_manager->TransferLiteralToDeviceAsync(stream, *it, leaf));
     ++it;
   }
-  return LocalShapedBuffer(std::move(buffer), client);
+  return PyLocalBuffer(std::move(buffer), client);
 }
 
 /* static */
-StatusOr<LocalShapedBuffer> LocalShapedBuffer::FromPython(
-    const py::object& argument, PyLocalClient* client, int device_ordinal) {
-  tensorflow::profiler::TraceMe traceme("LocalShapedBuffer::FromPython");
+StatusOr<PyLocalBuffer> PyLocalBuffer::FromPython(const py::object& argument,
+                                                  PyLocalClient* client,
+                                                  int device_ordinal) {
+  tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromPython");
   TF_ASSIGN_OR_RETURN(PythonBufferTree tree, GetPythonBufferTree(argument));
 
   // We are done manipulating Python objects; release the GIL.
   py::gil_scoped_release gil_release;
-  VLOG(1) << "LocalShapedBuffer::FromPython: shape: " << tree.shape.ToString()
+  VLOG(1) << "PyLocalBuffer::FromPython: shape: " << tree.shape.ToString()
           << " device ordinal: " << device_ordinal;
 
   TF_ASSIGN_OR_RETURN(
       StreamPool::Ptr stream,
       client->client()->mutable_backend()->BorrowStream(device_ordinal));
   TF_ASSIGN_OR_RETURN(
-      LocalShapedBuffer buffer,
+      PyLocalBuffer buffer,
       TransferHostToDeviceAsync(tree, device_ordinal, client, stream.get()));
   stream->BlockHostUntilDone();
   return buffer;
 }
 
-/*static */ StatusOr<std::vector<LocalShapedBuffer>>
-LocalShapedBuffer::FromPythonValues(
+/*static */ StatusOr<std::vector<PyLocalBuffer>>
+PyLocalBuffer::FromPythonValues(
     const std::vector<std::pair<py::object, int>>& arguments,
     PyLocalClient* client) {
-  tensorflow::profiler::TraceMe traceme("LocalShapedBuffer::FromPythonValues");
+  tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromPythonValues");
   int num_arguments = static_cast<int>(arguments.size());
-  std::vector<LocalShapedBuffer> outputs(num_arguments);
+  std::vector<PyLocalBuffer> outputs(num_arguments);
   if (num_arguments == 0) {
     return outputs;
   }
@@ -170,7 +169,7 @@ LocalShapedBuffer::FromPythonValues(
   struct H2DTransfer {
     PythonBufferTree tree;
     StreamPool::Ptr stream;
-    StatusOr<LocalShapedBuffer> buffer;
+    StatusOr<PyLocalBuffer> buffer;
   };
 
   std::vector<H2DTransfer> transfers(num_arguments);
@@ -188,7 +187,7 @@ LocalShapedBuffer::FromPythonValues(
         client->client()->mutable_backend()->BorrowStream(device_ordinal));
   }
 
-  auto transfer_h2d = [&](int i) -> StatusOr<LocalShapedBuffer> {
+  auto transfer_h2d = [&](int i) -> StatusOr<PyLocalBuffer> {
     int device_ordinal = arguments[i].second;
     return TransferHostToDeviceAsync(transfers[i].tree, device_ordinal, client,
                                      transfers[i].stream.get());
@@ -225,26 +224,26 @@ LocalShapedBuffer::FromPythonValues(
   return outputs;
 }
 
-LocalShapedBuffer::LocalShapedBuffer(ScopedShapedBuffer shaped_buffer,
-                                     PyLocalClient* client)
+PyLocalBuffer::PyLocalBuffer(ScopedShapedBuffer shaped_buffer,
+                             PyLocalClient* client)
     : shaped_buffer_(std::move(shaped_buffer)), client_(client) {}
 
-const ScopedShapedBuffer* LocalShapedBuffer::shaped_buffer() const {
+const ScopedShapedBuffer* PyLocalBuffer::shaped_buffer() const {
   return &shaped_buffer_.value();
 }
 
-ScopedShapedBuffer LocalShapedBuffer::Release() {
+ScopedShapedBuffer PyLocalBuffer::Release() {
   ScopedShapedBuffer result = std::move(*shaped_buffer_);
   shaped_buffer_ = absl::nullopt;
   return result;
 }
 
-const Shape& LocalShapedBuffer::shape() const {
+const Shape& PyLocalBuffer::shape() const {
   return shaped_buffer()->on_device_shape();
 }
 
-StatusOr<py::object> LocalShapedBuffer::ToPython() const {
-  tensorflow::profiler::TraceMe traceme("LocalShapedBuffer::ToPython");
+StatusOr<py::object> PyLocalBuffer::ToPython() const {
+  tensorflow::profiler::TraceMe traceme("PyLocalBuffer::ToPython");
   auto literal = absl::make_unique<Literal>();
   {
     py::gil_scoped_release gil_release;
@@ -254,13 +253,13 @@ StatusOr<py::object> LocalShapedBuffer::ToPython() const {
   return LiteralToPython(std::move(literal));
 }
 
-StatusOr<std::vector<LocalShapedBuffer>> LocalShapedBuffer::DestructureTuple() {
-  tensorflow::profiler::TraceMe traceme("LocalShapedBuffer::DestructureTuple");
+StatusOr<std::vector<PyLocalBuffer>> PyLocalBuffer::DestructureTuple() {
+  tensorflow::profiler::TraceMe traceme("PyLocalBuffer::DestructureTuple");
   const Shape tuple_shape = shape();
 
   if (!tuple_shape.IsTuple()) {
     return InvalidArgument(
-        "Attemped to destructure a LocalShapedBuffer that did not have a tuple "
+        "Attemped to destructure a PyLocalBuffer that did not have a tuple "
         "shape; shape: %s",
         ShapeUtil::HumanString(tuple_shape));
   }
@@ -273,7 +272,7 @@ StatusOr<std::vector<LocalShapedBuffer>> LocalShapedBuffer::DestructureTuple() {
   int device_ordinal = tuple_buffer.device_ordinal();
 
   ShapeTree<se::DeviceMemoryBase>& shape_tree = tuple_buffer.buffers();
-  std::vector<LocalShapedBuffer> results;
+  std::vector<PyLocalBuffer> results;
   for (int64 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) {
     // Create a shaped buffer for this destructured tuple element.
     const Shape& subshape = ShapeUtil::GetSubshape(tuple_shape, {i});
@@ -291,7 +290,7 @@ StatusOr<std::vector<LocalShapedBuffer>> LocalShapedBuffer::DestructureTuple() {
     });
 
     VLOG(3) << "Completed tuple element: " << i;
-    results.push_back(LocalShapedBuffer(
+    results.push_back(PyLocalBuffer(
         ScopedShapedBuffer(std::move(shaped_buffer), allocator), client_));
   }
   return results;
@@ -314,8 +313,8 @@ std::vector<int> PyLocalExecutable::DeviceOrdinals() const {
   return device_ordinals;
 }
 
-StatusOr<LocalShapedBuffer> PyLocalExecutable::Execute(
-    absl::Span<LocalShapedBuffer* const> argument_handles) {
+StatusOr<PyLocalBuffer> PyLocalExecutable::Execute(
+    absl::Span<PyLocalBuffer* const> argument_handles) {
   tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute");
   if (num_replicas() != 1) {
     return InvalidArgument(
@@ -345,12 +344,11 @@ StatusOr<LocalShapedBuffer> PyLocalExecutable::Execute(
   if (!result_buffer_status.ok()) {
     return result_buffer_status.status();
   }
-  return LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie(),
-                           client_);
+  return PyLocalBuffer(std::move(result_buffer_status).ValueOrDie(), client_);
 }
 
-StatusOr<std::vector<LocalShapedBuffer>> PyLocalExecutable::ExecutePerReplica(
-    absl::Span<const std::vector<LocalShapedBuffer*>> argument_handles) {
+StatusOr<std::vector<PyLocalBuffer>> PyLocalExecutable::ExecutePerReplica(
+    absl::Span<const std::vector<PyLocalBuffer*>> argument_handles) {
   tensorflow::profiler::TraceMe traceme("LocalExecutable::ExecutePerReplica");
   const int num_devices = client_->device_count();
 
@@ -448,7 +446,7 @@ StatusOr<std::vector<LocalShapedBuffer>> PyLocalExecutable::ExecutePerReplica(
   }
   VLOG(1) << "Replicated execution complete.";
 
-  std::vector<LocalShapedBuffer> wrapped_results(num_replicas());
+  std::vector<PyLocalBuffer> wrapped_results(num_replicas());
   for (int replica = 0; replica < num_replicas(); ++replica) {
     auto& statusor = results[replica];
     if (!statusor.ok()) {
@@ -460,46 +458,11 @@ StatusOr<std::vector<LocalShapedBuffer>> PyLocalExecutable::ExecutePerReplica(
           replica));
     }
     wrapped_results[replica] =
-        LocalShapedBuffer(std::move(statusor).ValueOrDie(), client_);
+        PyLocalBuffer(std::move(statusor).ValueOrDie(), client_);
   }
   return wrapped_results;
 }
 
-StatusOr<py::bytes> GetComputationSerializedProto(
-    const XlaComputation& computation) {
-  std::string result;
-  if (!computation.proto().SerializeToString(&result)) {
-    return Unknown("Failed to serialize the HloModuleProto.");
-  }
-  return py::bytes(result);
-}
-
-StatusOr<std::string> GetComputationHloText(const XlaComputation& computation) {
-  TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config,
-                      HloModule::CreateModuleConfigFromProto(
-                          computation.proto(), GetDebugOptionsFromFlags()));
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<HloModule> hlo_module,
-      HloModule::CreateFromProto(computation.proto(), module_config));
-  HloPrintOptions options;
-  options = HloPrintOptions::ShortParsable();
-  options.set_print_large_constants(false);
-  return hlo_module->ToString(options);
-}
-
-StatusOr<std::string> GetComputationHloDotGraph(
-    const XlaComputation& computation) {
-  TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config,
-                      HloModule::CreateModuleConfigFromProto(
-                          computation.proto(), GetDebugOptionsFromFlags()));
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<HloModule> hlo_module,
-      HloModule::CreateFromProto(computation.proto(), module_config));
-  return RenderGraph(*hlo_module->entry_computation(), /*label=*/"",
-                     hlo_module->config().debug_options(),
-                     RenderedGraphFormat::kDot);
-}
-
 /*static*/ StatusOr<std::unique_ptr<PyLocalExecutable>>
 PyLocalExecutable::Compile(const XlaComputation& computation,
                            std::vector<Shape> argument_layouts,
@@ -559,5 +522,4 @@ PyLocalExecutable::Compile(const XlaComputation& computation,
       std::move(local_executable), std::move(device_assignment), client);
 }
 
-}  // namespace xla_python
 }  // namespace xla
--- a/tensorflow/compiler/xla/python/local_client.h
+++ b/tensorflow/compiler/xla/python/local_client.h
@@ -32,7 +32,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/statusor.h"
 
 namespace xla {
-namespace xla_python {
 
 // Registers a 'fn_capsule' as a CPU custom call target.
 // 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name
@@ -74,20 +73,20 @@ class PyLocalClient {
 // Represents a reference to literals that live in a device-allocated buffer via
 // XLA. Specifically, wraps a ScopedShapedBuffer produced by transferring a
 // literal to device via the local client.
-class LocalShapedBuffer {
+class PyLocalBuffer {
  public:
-  static StatusOr<LocalShapedBuffer> FromPython(
-      const pybind11::object& argument, PyLocalClient* client,
-      int device_ordinal);
+  static StatusOr<PyLocalBuffer> FromPython(const pybind11::object& argument,
+                                            PyLocalClient* client,
+                                            int device_ordinal);
 
   // Converts multiple (python object, device ordinal) pairs into
-  // LocalShapedBuffers in parallel.
-  static StatusOr<std::vector<LocalShapedBuffer>> FromPythonValues(
+  // PyLocalBuffers in parallel.
+  static StatusOr<std::vector<PyLocalBuffer>> FromPythonValues(
       const std::vector<std::pair<pybind11::object, int>>& argument,
       PyLocalClient* client);
 
-  LocalShapedBuffer() = default;
-  LocalShapedBuffer(ScopedShapedBuffer shaped_buffer, PyLocalClient* client);
+  PyLocalBuffer() = default;
+  PyLocalBuffer(ScopedShapedBuffer shaped_buffer, PyLocalClient* client);
   StatusOr<pybind11::object> ToPython() const;
   const Shape& shape() const;
   const ScopedShapedBuffer* shaped_buffer() const;
@@ -101,9 +100,8 @@ class LocalShapedBuffer {
     client_ = nullptr;
   }
 
-  // Destructures a tuple-valued LocalShapedBuffer into its constituent
-  // elements in LocalShapedBufferTuple form.
-  StatusOr<std::vector<LocalShapedBuffer>> DestructureTuple();
+  // Destructures a tuple-valued PyLocalBuffer into its constituent elements.
+  StatusOr<std::vector<PyLocalBuffer>> DestructureTuple();
 
  private:
   absl::optional<ScopedShapedBuffer> shaped_buffer_;
@@ -133,14 +131,14 @@ class PyLocalExecutable {
     return device_assignment_;
   }
 
-  StatusOr<LocalShapedBuffer> Execute(
-      absl::Span<LocalShapedBuffer* const> argument_handles);
+  StatusOr<PyLocalBuffer> Execute(
+      absl::Span<PyLocalBuffer* const> argument_handles);
 
   // Execute on many replicas. Takes a sequence of argument lists (one argument
   // list per replica) and returns a tuple of results (one result per replica).
   // The number of argument lists must be equal to the replica count.
-  StatusOr<std::vector<LocalShapedBuffer>> ExecutePerReplica(
-      absl::Span<const std::vector<LocalShapedBuffer*>> argument_handles);
+  StatusOr<std::vector<PyLocalBuffer>> ExecutePerReplica(
+      absl::Span<const std::vector<PyLocalBuffer*>> argument_handles);
 
   void Delete() { executable_ = nullptr; }
 
@@ -150,18 +148,6 @@ class PyLocalExecutable {
   PyLocalClient* const client_;
 };
 
-// Converts a computation to a serialized HloModuleProto
-StatusOr<pybind11::bytes> GetComputationSerializedProto(
-    const XlaComputation& computation);
-
-// Converts a computation to textual HLO form.
-StatusOr<std::string> GetComputationHloText(const XlaComputation& computation);
-
-// Converts a computation to HLO dot graph form.
-StatusOr<std::string> GetComputationHloDotGraph(
-    const XlaComputation& computation);
-
-}  // namespace xla_python
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_
--- a/tensorflow/compiler/xla/python/xla.cc
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <string>
+
 #include "absl/strings/string_view.h"
 #include "absl/synchronization/mutex.h"
 #include "absl/types/optional.h"
@@ -29,16 +31,21 @@ limitations under the License.
 #include "tensorflow/compiler/xla/python/types.h"
 #include "tensorflow/compiler/xla/python/xrt.h"
 #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
+#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/name_uniquer.h"
 #include "tensorflow/compiler/xla/service/platform_util.h"
 #include "tensorflow/compiler/xla/shape.h"
 #include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
 
 namespace xla {
-namespace xla_python {
 
 namespace py = pybind11;
 
+namespace {
+
 struct Uniquer {
   absl::Mutex mu;
   NameUniquer name_uniquer GUARDED_BY(mu);
@@ -55,6 +62,46 @@ static string UniquifyName(const string& name) {
   return uniquer->name_uniquer.GetUniqueName(name);
 }
 
+// Converts a computation to a serialized HloModuleProto.
+StatusOr<py::bytes> GetComputationSerializedProto(
+    const XlaComputation& computation) {
+  std::string result;
+  if (!computation.proto().SerializeToString(&result)) {
+    return Unknown("Failed to serialize the HloModuleProto.");
+  }
+  return py::bytes(result);
+}
+
+// Converts a computation to textual HLO form.
+StatusOr<std::string> GetComputationHloText(const XlaComputation& computation) {
+  TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config,
+                      HloModule::CreateModuleConfigFromProto(
+                          computation.proto(), GetDebugOptionsFromFlags()));
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<HloModule> hlo_module,
+      HloModule::CreateFromProto(computation.proto(), module_config));
+  HloPrintOptions options;
+  options = HloPrintOptions::ShortParsable();
+  options.set_print_large_constants(false);
+  return hlo_module->ToString(options);
+}
+
+// Converts a computation to HLO dot graph form.
+StatusOr<std::string> GetComputationHloDotGraph(
+    const XlaComputation& computation) {
+  TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config,
+                      HloModule::CreateModuleConfigFromProto(
+                          computation.proto(), GetDebugOptionsFromFlags()));
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<HloModule> hlo_module,
+      HloModule::CreateFromProto(computation.proto(), module_config));
+  return RenderGraph(*hlo_module->entry_computation(), /*label=*/"",
+                     hlo_module->config().debug_options(),
+                     RenderedGraphFormat::kDot);
+}
+
+}  // namespace
+
 PYBIND11_MODULE(xla_extension, m) {
   // Types
   py::enum_<PrimitiveType>(m, "PrimitiveType")
@@ -167,13 +214,13 @@ PYBIND11_MODULE(xla_extension, m) {
       .def("TransferToInfeed", &PyLocalClient::TransferToInfeed)
      .def("TransferFromOutfeed", &PyLocalClient::TransferFromOutfeed);
 
-  py::class_<LocalShapedBuffer>(m, "LocalShapedBuffer")
-      .def_static("FromPython", &LocalShapedBuffer::FromPython)
-      .def_static("FromPythonValues", &LocalShapedBuffer::FromPythonValues)
-      .def("Delete", &LocalShapedBuffer::Delete)
-      .def("DestructureTuple", &LocalShapedBuffer::DestructureTuple)
-      .def("ToPython", &LocalShapedBuffer::ToPython)
-      .def("shape", &LocalShapedBuffer::shape);
+  py::class_<PyLocalBuffer>(m, "PyLocalBuffer")
+      .def_static("FromPython", &PyLocalBuffer::FromPython)
+      .def_static("FromPythonValues", &PyLocalBuffer::FromPythonValues)
+      .def("Delete", &PyLocalBuffer::Delete)
+      .def("DestructureTuple", &PyLocalBuffer::DestructureTuple)
+      .def("ToPython", &PyLocalBuffer::ToPython)
+      .def("shape", &PyLocalBuffer::shape);
 
   py::class_<PyLocalExecutable>(m, "LocalExecutable")
       .def_static("Compile", &PyLocalExecutable::Compile,
@@ -433,5 +480,4 @@ PYBIND11_MODULE(xla_extension, m) {
   tensorflow::AddXrtSubmodule(&m);
 }
 
-}  // namespace xla_python
 }  // namespace xla
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -113,11 +113,10 @@ class LocalBackend(Backend):
     return self.client.DeviceCount()
 
   def buffer_from_pyval(self, pyval, device=0):
-    return _xla.LocalShapedBuffer.FromPython(pyval, self.client, device)
+    return _xla.PyLocalBuffer.FromPython(pyval, self.client, device)
 
   def buffers_from_pyvals(self, pyvals_and_devices):
-    return _xla.LocalShapedBuffer.FromPythonValues(pyvals_and_devices,
-                                                   self.client)
+    return _xla.PyLocalBuffer.FromPythonValues(pyvals_and_devices, self.client)
 
   def delete_buffer(self, c_buffer):
     c_buffer.Delete()