/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_
#include <string>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace swig {
// Sets the replica count that XLA will be initialized with when a handle to
// the local XLA service is first obtained. Returns an error if this is called
// after the handle to the local XLA service has already been established.
Status InitializeReplicaCount(int replica_count);
// Returns the replica count that is currently set, regardless of whether the
// local XLA service has been instantiated yet or not.
int GetReplicaCount();
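// Example (illustrative sketch, not part of the API surface): pin the replica
// count before the local service handle is created, then read it back.
//
//   Status status = InitializeReplicaCount(2);
//   int replica_count = GetReplicaCount();  // 2, even before instantiation.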
// Wraps the local client's infeed-transfer function.
//
// The default device ordinal (0) is used.
Status TransferToInfeedLocal(const Literal& literal);
// Transfers the given literal to the infeed of the given replica.
//
// The replica number is resolved to an appropriate device ordinal.
Status TransferToInfeedLocalReplica(const Literal& literal, int replica_number);
// Transfers a literal of the given shape from the outfeed of the given replica.
//
// The replica number is resolved to an appropriate device ordinal.
StatusOr<Literal> TransferFromOutfeedLocalReplica(const Shape& shape,
int replica_number);
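// Example (illustrative sketch; `literal` and `shape` stand in for a
// previously constructed Literal and the Shape expected on the outfeed):
//
//   Status transfer_status =
//       TransferToInfeedLocalReplica(literal, /*replica_number=*/0);
//   StatusOr<Literal> outfed =
//       TransferFromOutfeedLocalReplica(shape, /*replica_number=*/0);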
// Represents a reference to literals that live in a device-allocated buffer via
// XLA. Specifically, wraps a ScopedShapedBuffer produced by transferring a
// literal to device via the local client.
class LocalShapedBuffer {
public:
static StatusOr<LocalShapedBuffer*> FromLiteral(
const Literal& argument, const absl::optional<Shape>& shape_with_layout);
LocalShapedBuffer(ScopedShapedBuffer shaped_buffer);
StatusOr<Literal> ToLiteral() const;
const Shape& shape() const;
const ScopedShapedBuffer* shaped_buffer() const;
// Transfers ownership of the encapsulated ShapedBuffer to the caller,
// analogous to std::unique_ptr::release().
ShapedBuffer Release();
private:
ScopedShapedBuffer shaped_buffer_;
};
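// Example (illustrative sketch; `literal` stands in for a host-side Literal):
//
//   StatusOr<LocalShapedBuffer*> buffer_or = LocalShapedBuffer::FromLiteral(
//       literal, /*shape_with_layout=*/absl::nullopt);
//   LocalShapedBuffer* buffer = buffer_or.ValueOrDie();
//   StatusOr<Literal> round_trip = buffer->ToLiteral();
//   DeleteLocalShapedBuffer(buffer);  // Frees the device buffer.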
// Result of a tuple destructuring operation on a LocalShapedBuffer -- for the
// time being this is a simpler mechanism than using SWIG to transform
// std::vectors into Python lists of SWIG objects directly.
class LocalShapedBufferTuple {
public:
// Note: any LocalShapedBuffer elements that are not Release()'d will be
// deallocated in the destructor.
explicit LocalShapedBufferTuple(std::vector<LocalShapedBuffer*> elements);
~LocalShapedBufferTuple();
// Releases the ith element to the caller. Further attempts to release the ith
// element will return an invalid argument error.
StatusOr<LocalShapedBuffer*> Release(int i);
// Returns the number of elements in the destructured tuple.
int64 size() const;
private:
std::vector<LocalShapedBuffer*> elements_;
};
// Destructures a tuple-valued LocalShapedBuffer into its constituent elements
// in LocalShapedBufferTuple form.
StatusOr<LocalShapedBufferTuple*> DestructureLocalShapedBufferTuple(
LocalShapedBuffer* local_shaped_buffer);
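// Example (illustrative sketch; `tuple_buffer` stands in for a tuple-shaped
// LocalShapedBuffer*):
//
//   LocalShapedBufferTuple* elements =
//       DestructureLocalShapedBufferTuple(tuple_buffer).ValueOrDie();
//   for (int64 i = 0; i < elements->size(); ++i) {
//     LocalShapedBuffer* element = elements->Release(i).ValueOrDie();
//     // ... the caller now owns `element` ...
//   }
//   delete elements;  // Deallocates any elements that were not released.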
// Represents a reference to literals that live in a device-allocated buffer via
// XRT. Specifically, wraps an int64 handle produced by running the allocation
// graph, and an XLA shape to track the referent's shape.
class XrtAllocation {
public:
// Accepts a `session_target` argument, used in constructing the
// `tensorflow::ClientSession` instance in which allocation and deallocation
// graphs are run.
static StatusOr<XrtAllocation*> FromLiteral(const Literal& argument,
const string& session_target);
XrtAllocation(int64 handle, Shape shape, const string& session_target);
~XrtAllocation();
StatusOr<Literal> ToLiteral() const;
const Shape& shape() const;
const int64 handle() const;
private:
const int64 handle_;
const Shape shape_;
const string session_target_;
};
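// Example (illustrative sketch; `literal` stands in for a host-side Literal,
// and the empty session target is an assumption standing in for whatever
// target the caller's TensorFlow setup requires):
//
//   StatusOr<XrtAllocation*> alloc_or =
//       XrtAllocation::FromLiteral(literal, /*session_target=*/"");
//   XrtAllocation* alloc = alloc_or.ValueOrDie();
//   StatusOr<Literal> host_copy = alloc->ToLiteral();
//   DeleteXrtAllocation(alloc);  // Frees the device-side allocation.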
// Result of a tuple destructuring operation on an XrtAllocation.
class XrtAllocationTuple {
public:
// Note: any XrtAllocation elements that are not Release()'d will be
// deallocated in the destructor.
explicit XrtAllocationTuple(std::vector<XrtAllocation*> elements);
~XrtAllocationTuple();
// Releases the ith element to the caller. Further attempts to release the ith
// element will return an invalid argument error.
StatusOr<XrtAllocation*> Release(int i);
// Returns the number of elements in the destructured tuple.
int64 size() const;
private:
std::vector<XrtAllocation*> elements_;
};
// Destructures a tuple-valued XrtAllocation into its constituent elements
// in XrtAllocationTuple form.
//
// Accepts a `session_target` argument, used in constructing the
// `tensorflow::ClientSession` instance in which the sub-tupling graph is run,
// and passed along in constructing each constituent XrtAllocation.
StatusOr<XrtAllocationTuple*> DestructureXrtAllocationTuple(
XrtAllocation* allocation, const string& session_target);
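// Example (illustrative sketch; `tuple_alloc` stands in for a tuple-valued
// XrtAllocation*, and the empty session target is again an assumption):
//
//   XrtAllocationTuple* elements =
//       DestructureXrtAllocationTuple(tuple_alloc, /*session_target=*/"")
//           .ValueOrDie();
//   XrtAllocation* first = elements->Release(0).ValueOrDie();
//   delete elements;  // Deallocates any elements that were not released.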
// Represents a compiled computation that can be executed given handles to
// device-allocated literals. Specifically, wraps an XLA LocalExecutable.
class CompiledLocalComputation {
public:
CompiledLocalComputation(std::unique_ptr<LocalExecutable> executable);
StatusOr<LocalShapedBuffer*> Execute(
absl::Span<LocalShapedBuffer* const> argument_handles);
private:
std::unique_ptr<LocalExecutable> executable_;
};
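// Example (illustrative sketch; `compiled` and `arg` stand in for a
// CompiledLocalComputation* and a LocalShapedBuffer* argument):
//
//   std::vector<LocalShapedBuffer*> args = {arg};
//   StatusOr<LocalShapedBuffer*> result_or = compiled->Execute(args);
//   LocalShapedBuffer* result = result_or.ValueOrDie();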
// Represents a compiled computation that can be executed given handles to
// device-allocated literals. Specifically, wraps an XRT computation handle.
class CompiledXrtComputation {
public:
// Accepts a `session_target` argument, used in constructing the
// `tensorflow::ClientSession` instance in which the execution graph is run.
CompiledXrtComputation(const ProgramShape& program_shape, int64 handle,
const string& session_target);
~CompiledXrtComputation();
StatusOr<XrtAllocation*> Execute(
absl::Span<XrtAllocation* const> argument_handles);
const ProgramShape& program_shape() const;
int64 handle() const;
private:
const ProgramShape program_shape_;
const int64 handle_;
const string session_target_;
};
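// Example (illustrative sketch; `xrt_compiled` and `xrt_arg` stand in for a
// CompiledXrtComputation* and an XrtAllocation* argument):
//
//   std::vector<XrtAllocation*> args = {xrt_arg};
//   StatusOr<XrtAllocation*> result_or = xrt_compiled->Execute(args);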
// Wraps an XlaComputation produced by a LocalComputationBuilder. The
// Compile method compiles the computation to a (local) executable via
// the client library's local client. This class is intended to be
// made available to Python via SWIG.
class LocalComputation {
public:
LocalComputation(XlaComputation computation);
StatusOr<CompiledLocalComputation*> Compile(
const std::vector<Shape>& argument_shapes,
const ExecutableBuildOptions* build_options);
// Accepts a `session_target` argument, used in constructing the
// `tensorflow::ClientSession` instance in which the compilation graph is run.
StatusOr<CompiledXrtComputation*> CompileForXrt(
const std::vector<Shape>& argument_shapes, const string& session_target);
const XlaComputation& computation() const;
// Returns the HloModuleProto contained in the XlaComputation in the
// serialized binary format. Logs an internal error and returns an empty
// string on failure.
string GetSerializedProto() const;
// Returns the return-value shape for this computation.
StatusOr<Shape> GetReturnValueShape() const;
private:
XlaComputation computation_;
};
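// Example (illustrative sketch; `computation` stands in for a
// LocalComputation* whose single parameter has shape `arg_shape`):
//
//   ExecutableBuildOptions build_options;
//   StatusOr<CompiledLocalComputation*> compiled_or =
//       computation->Compile({arg_shape}, &build_options);
//   CompiledLocalComputation* compiled = compiled_or.ValueOrDie();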
// Wraps an XlaOp produced by a LocalComputationBuilder. This class is intended
// to be made available to Python via SWIG.
class LocalOp {
public:
LocalOp(const XlaOp& op);
const XlaOp& op() const;
private:
XlaOp op_;
};
// Wraps the XlaBuilder API in order to:
// - Support consumption by SWIG in order to be made available to
//   Python.
// - Set up the underlying builder to use the client library's
//   LocalClient.
// - Wrap XlaComputations in LocalComputations for Python access.
// - Correspondingly unwrap incoming LocalComputations.
class LocalComputationBuilder {
public:
LocalComputationBuilder(const string& computation_name);
void SetOpMetadata(const OpMetadata& metadata);
void ClearOpMetadata();
// Returns an owned LocalComputation to the caller on success.
StatusOr<LocalComputation*> Build();
// Returns an owned LocalComputation to the caller on success, with the given
// root.
StatusOr<LocalComputation*> BuildWithRoot(const LocalOp& root);
LocalOp Parameter(int64 parameter_number, const Shape& shape,
const string& name);
StatusOr<Shape> GetShape(const LocalOp& operand);
// Returns the shape of the current return value for the computation.
StatusOr<Shape> GetReturnValueShape();
LocalOp Infeed(const Shape& shape);
void Outfeed(const LocalOp& operand, const Shape& shape,
const string& outfeed_config);
LocalOp ConstantLiteral(const Literal& literal);
LocalOp Broadcast(const LocalOp& operand,
absl::Span<const int64> broadcast_sizes);
LocalOp BroadcastInDim(const LocalOp& operand, const Shape& shape,
absl::Span<const int64> broadcast_dimensions);
LocalOp Pad(const LocalOp& operand, const LocalOp& padding_value,
const PaddingConfig& padding_config);
LocalOp Reshape(const LocalOp& operand, absl::Span<const int64> dimensions,
absl::Span<const int64> new_sizes);
LocalOp Collapse(const LocalOp& operand, absl::Span<const int64> dimensions);
LocalOp CrossReplicaSum(const LocalOp& operand);
LocalOp Slice(const LocalOp& operand, absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices,
absl::Span<const int64> strides);
LocalOp SliceInDim(const LocalOp& operand, int64 start_index,
int64 limit_index, int64 stride, int64 dimno);
LocalOp DynamicSlice(const LocalOp& operand, const LocalOp& start_indices,
absl::Span<const int64> slice_sizes);
LocalOp DynamicUpdateSlice(const LocalOp& operand, const LocalOp& update,
const LocalOp& start_indices);
LocalOp ConcatInDim(absl::Span<const LocalOp> operands, int64 dimension);
LocalOp SelectAndScatterWithGeneralPadding(
const LocalOp& operand, const LocalComputation& select,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64> > padding, const LocalOp& source,
const LocalOp& init_value, const LocalComputation& scatter);
LocalOp Tuple(absl::Span<const LocalOp> elements);
LocalOp GetTupleElement(const LocalOp& tuple_data, int64 index);
LocalOp Dot(const LocalOp& lhs, const LocalOp& rhs);
LocalOp DotGeneral(const LocalOp& lhs, const LocalOp& rhs,
const DotDimensionNumbers& dimension_numbers);
LocalOp ConvGeneralDilated(
const LocalOp& lhs, const LocalOp& rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64> > padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count);
LocalOp ConvertElementType(const LocalOp& operand,
PrimitiveType new_element_type);
LocalOp BitcastConvertType(const LocalOp& operand,
PrimitiveType new_element_type);
LocalOp Call(const LocalComputation& local_computation,
absl::Span<const LocalOp> operands);
LocalOp Transpose(const LocalOp& operand,
absl::Span<const int64> permutation);
LocalOp Rev(const LocalOp& operand, absl::Span<const int64> dimensions);
LocalOp Map(absl::Span<const LocalOp> operands,
const LocalComputation& local_computation,
absl::Span<const int64> dimensions);
LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value,
const LocalComputation& local_computation,
absl::Span<const int64> dimensions_to_reduce);
LocalOp ReduceWindowWithGeneralPadding(
const LocalOp& operand, const LocalOp& init_value,
const LocalComputation& local_computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
absl::Span<const int64> base_dilations,
absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64> > padding);
LocalOp RngNormal(const LocalOp& mu, const LocalOp& sigma,
const Shape& shape);
LocalOp RngUniform(const LocalOp& a, const LocalOp& b, const Shape& shape);
LocalOp While(const LocalComputation& condition, const LocalComputation& body,
const LocalOp& init);
LocalOp Conditional(const LocalOp& predicate, const LocalOp& true_operand,
const LocalComputation& true_computation,
const LocalOp& false_operand,
const LocalComputation& false_computation);
StatusOr<bool> IsConstant(const LocalOp& operand);
LocalOp Sort(const LocalOp& operand, int64 dimension);
LocalOp SortKeyVal(const LocalOp& keys, const LocalOp& values,
int64 dimension);
StatusOr<LocalComputation*> BuildConstantSubGraph(const LocalOp& operand);
#define _FORWARD(method_name, return_sig, args_sig) \
return_sig method_name args_sig;
#define _FORWARD_UNOP(method_name) \
_FORWARD(method_name, LocalOp, (const LocalOp& operand))
#define _FORWARD_BINOP(method_name) \
_FORWARD(method_name, LocalOp, \
(const LocalOp& lhs, const LocalOp& rhs, \
absl::Span<const int64> broadcast_dimensions))
#define _FORWARD_TRIOP(method_name) \
_FORWARD(method_name, LocalOp, \
(const LocalOp& lhs, const LocalOp& rhs, const LocalOp& ehs))
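// For example, _FORWARD_BINOP(Add) declares:
//   LocalOp Add(const LocalOp& lhs, const LocalOp& rhs,
//               absl::Span<const int64> broadcast_dimensions);
// and _FORWARD_UNOP(Exp) declares:
//   LocalOp Exp(const LocalOp& operand);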
_FORWARD_TRIOP(Select)
_FORWARD_TRIOP(Clamp)
_FORWARD_BINOP(Eq)
_FORWARD_BINOP(Ne)
_FORWARD_BINOP(Ge)
_FORWARD_BINOP(Gt)
_FORWARD_BINOP(Lt)
_FORWARD_BINOP(Le)
_FORWARD_BINOP(Add)
_FORWARD_BINOP(Sub)
_FORWARD_BINOP(Mul)
_FORWARD_BINOP(Div)
_FORWARD_BINOP(Rem)
_FORWARD_BINOP(Max)
_FORWARD_BINOP(Min)
_FORWARD_BINOP(And)
_FORWARD_BINOP(Or)
_FORWARD_BINOP(Xor)
_FORWARD_BINOP(ShiftLeft)
_FORWARD_BINOP(ShiftRightArithmetic)
_FORWARD_BINOP(ShiftRightLogical)
_FORWARD_BINOP(Atan2)
_FORWARD_BINOP(Pow)
_FORWARD_BINOP(Complex)
_FORWARD_UNOP(Not)
_FORWARD_UNOP(Abs)
_FORWARD_UNOP(Exp)
_FORWARD_UNOP(Expm1)
_FORWARD_UNOP(Floor)
_FORWARD_UNOP(Ceil)
_FORWARD_UNOP(Round)
_FORWARD_UNOP(Log)
_FORWARD_UNOP(Log1p)
_FORWARD_UNOP(Sign)
_FORWARD_UNOP(Cos)
_FORWARD_UNOP(Sin)
_FORWARD_UNOP(Tanh)
_FORWARD_UNOP(IsFinite)
_FORWARD_UNOP(Neg)
_FORWARD_UNOP(Sqrt)
_FORWARD_UNOP(Rsqrt)
_FORWARD_UNOP(Square)
_FORWARD_UNOP(Reciprocal)
_FORWARD_UNOP(Erfc)
_FORWARD_UNOP(Erf)
_FORWARD_UNOP(ErfInv)
_FORWARD_UNOP(Lgamma)
_FORWARD_UNOP(Digamma)
_FORWARD_UNOP(Acos)
_FORWARD_UNOP(Asin)
_FORWARD_UNOP(Atan)
_FORWARD_UNOP(Tan)
_FORWARD_UNOP(Acosh)
_FORWARD_UNOP(Asinh)
_FORWARD_UNOP(Atanh)
_FORWARD_UNOP(Cosh)
_FORWARD_UNOP(Sinh)
_FORWARD_UNOP(Real)
_FORWARD_UNOP(Imag)
_FORWARD_UNOP(Conj)
#undef _FORWARD
#undef _FORWARD_UNOP
#undef _FORWARD_BINOP
#undef _FORWARD_TRIOP
private:
XlaBuilder builder_;
};
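// Example (illustrative sketch): build a computation that adds two f32[2]
// parameters. ShapeUtil::MakeShape is assumed to be visible via the XLA
// headers included above.
//
//   LocalComputationBuilder builder("add_two_vectors");
//   Shape shape = ShapeUtil::MakeShape(F32, {2});
//   LocalOp x = builder.Parameter(0, shape, "x");
//   LocalOp y = builder.Parameter(1, shape, "y");
//   builder.Add(x, y, /*broadcast_dimensions=*/{});
//   StatusOr<LocalComputation*> computation_or = builder.Build();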
// Functions for freeing resources from the Python side.
void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer);
void DeleteXrtAllocation(XrtAllocation* allocation);
void DeleteCompiledLocalComputation(CompiledLocalComputation* computation);
void DeleteCompiledXrtComputation(CompiledXrtComputation* computation);
void DeleteLocalComputation(LocalComputation* computation);
} // namespace swig
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_