[XLA] Rename all (Mutable)ArraySlice to absl::Span.
PiperOrigin-RevId: 210998142
Parent: 5d5591fbd4
Commit: 6f879f891a
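The diff below applies one mechanical type mapping across the XLA/tf2xla sources. As a reader's aid, here is a minimal, self-contained C++ sketch of that mapping — this is not code from the commit, and the function names are hypothetical: read-only gtl::ArraySlice<T> becomes absl::Span<const T>, gtl::MutableArraySlice<T> becomes absl::Span<T>, and for pointer elements gtl::ArraySlice<T*> becomes absl::Span<T* const>, with the const on the pointer itself.

// Hypothetical sketch of the ArraySlice -> Span mapping (not from this commit).
#include <vector>
#include "absl/types/span.h"

// Read-only view: gtl::ArraySlice<T> maps to absl::Span<const T>.
int SumAll(absl::Span<const int> xs) {
  int sum = 0;
  for (int x : xs) sum += x;
  return sum;
}

// Mutable view: gtl::MutableArraySlice<T> maps to absl::Span<T>.
void ZeroAll(absl::Span<int> xs) {
  for (int& x : xs) x = 0;
}

// Pointer elements: gtl::ArraySlice<T*> maps to absl::Span<T* const>.
// The pointers themselves cannot be reseated, but the pointees stay mutable.
void BumpAll(absl::Span<int* const> ptrs) {
  for (int* p : ptrs) ++*p;
}

int main() {
  std::vector<int> v = {1, 2, 3};
  int sum = SumAll(v);            // std::vector converts implicitly to Span.
  std::vector<int*> ps = {&v[0]};
  BumpAll(ps);                    // vector<int*> converts to Span<int* const>.
  ZeroAll(absl::MakeSpan(v));     // MakeSpan yields a mutable view.
  (void)sum;
  return 0;
}

The design point is that absl::Span folds the const/mutable distinction into the element type instead of providing two separate view classes, which is why the diff rewrites ArraySlice<Predicate*> as Span<Predicate* const> rather than Span<const Predicate*>.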
@@ -111,7 +111,7 @@ GetTargetMachineFromTriple(StringPiece target_triple) {
 
 StatusOr<EmbeddedProtocolBuffers> CreateEmbeddedProtocolBuffers(
     StringPiece target_triple,
-    gtl::ArraySlice<ProtobufToEmbed> protobufs_to_embed) {
+    absl::Span<const ProtobufToEmbed> protobufs_to_embed) {
   TF_ASSIGN_OR_RETURN(std::unique_ptr<llvm::TargetMachine> target_machine,
                       GetTargetMachineFromTriple(target_triple));
 
@@ -84,7 +84,7 @@ struct ProtobufToEmbed {
 // EmbeddedProtocolBuffers instance.
 StatusOr<EmbeddedProtocolBuffers> CreateEmbeddedProtocolBuffers(
     StringPiece target_triple,
-    gtl::ArraySlice<ProtobufToEmbed> protobufs_to_embed);
+    absl::Span<const ProtobufToEmbed> protobufs_to_embed);
 
 }  // namespace tfcompile
 }  // namespace tensorflow
@@ -108,7 +108,7 @@ class Predicate {
 
   virtual string ToString() const = 0;
   int64 hash() const { return hash_; }
-  virtual gtl::ArraySlice<Predicate*> GetOperands() const = 0;
+  virtual absl::Span<Predicate* const> GetOperands() const = 0;
 
   virtual Kind kind() const = 0;
   virtual ~Predicate() {}
@@ -129,7 +129,7 @@ class Predicate {
 };
 
 int64 HashPredicateSequence(Predicate::Kind kind,
-                            gtl::ArraySlice<Predicate*> preds) {
+                            absl::Span<Predicate* const> preds) {
   int64 hash = ::tensorflow::hash<Predicate::Kind>()(kind);
   for (Predicate* pred : preds) {
     hash = Hash64Combine(hash, pred->hash());
@@ -159,8 +159,10 @@ class AndPredicate : public Predicate {
 
   Kind kind() const override { return Kind::kAnd; }
 
-  gtl::ArraySlice<Predicate*> GetOperands() const override { return operands_; }
-  gtl::ArraySlice<Predicate*> operands() const { return operands_; }
+  absl::Span<Predicate* const> GetOperands() const override {
+    return operands_;
+  }
+  absl::Span<Predicate* const> operands() const { return operands_; }
 
  private:
   std::vector<Predicate*> operands_;
@@ -187,8 +189,10 @@ class OrPredicate : public Predicate {
   }
 
   Kind kind() const override { return Kind::kOr; }
-  gtl::ArraySlice<Predicate*> GetOperands() const override { return operands_; }
-  gtl::ArraySlice<Predicate*> operands() const { return operands_; }
+  absl::Span<Predicate* const> GetOperands() const override {
+    return operands_;
+  }
+  absl::Span<Predicate* const> operands() const { return operands_; }
 
  private:
   std::vector<Predicate*> operands_;
@@ -207,7 +211,9 @@ class NotPredicate : public Predicate {
 
   Kind kind() const override { return Kind::kNot; }
   Predicate* operand() const { return operands_[0]; }
-  gtl::ArraySlice<Predicate*> GetOperands() const override { return operands_; }
+  absl::Span<Predicate* const> GetOperands() const override {
+    return operands_;
+  }
 
  private:
   std::array<Predicate*, 1> operands_;
@@ -240,7 +246,9 @@ class AndRecurrencePredicate : public Predicate {
 
   Kind kind() const override { return Kind::kAndRecurrence; }
 
-  gtl::ArraySlice<Predicate*> GetOperands() const override { return operands_; }
+  absl::Span<Predicate* const> GetOperands() const override {
+    return operands_;
+  }
 
  private:
   std::array<Predicate*, 2> operands_;
@@ -264,7 +272,7 @@ class SymbolPredicate : public Predicate {
   }
 
   Kind kind() const override { return Kind::kSymbol; }
-  gtl::ArraySlice<Predicate*> GetOperands() const override { return {}; }
+  absl::Span<Predicate* const> GetOperands() const override { return {}; }
 
   // If `must_be_true()` is true this SymbolPredicate represents the proposition
   // "tensor_id() is live and evaluates to true".
@@ -313,11 +321,11 @@ template <typename FunctionTy>
 // them.
 class PredicateFactory {
  public:
-  Predicate* MakeAndPredicate(gtl::ArraySlice<Predicate*> operands) {
+  Predicate* MakeAndPredicate(absl::Span<Predicate* const> operands) {
     return MakeAndOrImpl(operands, /*is_and=*/true);
   }
 
-  Predicate* MakeOrPredicate(gtl::ArraySlice<Predicate*> operands) {
+  Predicate* MakeOrPredicate(absl::Span<Predicate* const> operands) {
     return MakeAndOrImpl(operands, /*is_and=*/false);
   }
 
@@ -374,7 +382,7 @@ class PredicateFactory {
         new PredicateT(std::forward<Args>(args)...));
   }
 
-  Predicate* MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands, bool is_and);
+  Predicate* MakeAndOrImpl(absl::Span<Predicate* const> operands, bool is_and);
 
   // Predicate instances are interned, meaning that there is only a single
   // instance of a Predicate object with a given content. This makes checking
@@ -387,7 +395,7 @@ class PredicateFactory {
   // for the owning pointers to predicate instances.
 
   using SignatureForAndOr =
-      std::pair<Predicate::Kind, gtl::ArraySlice<Predicate*>>;
+      std::pair<Predicate::Kind, absl::Span<Predicate* const>>;
   using SignatureForNot = Predicate*;
   using SignatureForAndRec = std::pair<Predicate*, Predicate*>;
   using SignatureForSymbol = std::pair<SafeTensorId, bool>;
@@ -422,8 +430,8 @@ class PredicateFactory {
 };
 
 // Common code to create AndPredicate or OrPredicate instances.
-Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands,
-                                           bool is_and) {
+Predicate* PredicateFactory::MakeAndOrImpl(
+    absl::Span<Predicate* const> operands, bool is_and) {
   Predicate::Kind pred_kind =
       is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr;
   gtl::FlatSet<Predicate*> simplified_ops_set;
@@ -474,7 +482,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands,
   // NB! Because we'll use a non-owning reference to simplified_ops in the
   // key for interned_and_or_instances_ we need to be careful to std::move()
   // it all the way through.
-  gtl::ArraySlice<Predicate*> operands_slice = simplified_ops;
+  absl::Span<Predicate* const> operands_slice = simplified_ops;
   std::unique_ptr<Predicate> new_pred =
       is_and ? Make<AndPredicate>(std::move(simplified_ops))
              : Make<OrPredicate>(std::move(simplified_ops));
@@ -496,7 +504,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
       : graph_(*graph), vlog_(VLOG_IS_ON(2)) {}
 
   Status Populate();
-  Status PopulateWithReversePostOrder(gtl::ArraySlice<Node*> rpo);
+  Status PopulateWithReversePostOrder(absl::Span<Node* const> rpo);
   bool HasInputsWithMismatchingDeadness(const Node& node) override;
   void Print() const override;
   gtl::FlatMap<TensorId, string, TensorId::Hasher> PredicateMapAsString() const;
@@ -527,7 +535,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
     }
   }
 
-  void SetPredicate(Node* n, gtl::ArraySlice<int> output_idxs, Predicate* pred,
+  void SetPredicate(Node* n, absl::Span<const int> output_idxs, Predicate* pred,
                     std::vector<bool>* should_revisit) {
     for (int output_idx : output_idxs) {
       SetPredicate(n, output_idx, pred, should_revisit);
@@ -625,7 +633,7 @@ Predicate* DeduceStepPredicate(PredicateFactory* predicate_factory,
   }
 
   std::vector<Predicate*> and_ops;
-  gtl::ArraySlice<Predicate*> recurrent_pred_ops =
+  absl::Span<Predicate* const> recurrent_pred_ops =
       backedge_predicate->GetOperands();
 
   bool found_sym = false;
@@ -784,7 +792,7 @@ Status DeadnessAnalysisImpl::Populate() {
 }
 
 Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
-    gtl::ArraySlice<Node*> rpo) {
+    absl::Span<Node* const> rpo) {
   // This an abstract interpretation over the deadness propagation semantics of
   // the graph executor.
   //
@@ -924,7 +932,7 @@ Status ComputePredicates(const Graph& graph,
 }
 
 Status ComputePredicates(const Graph& graph,
-                         gtl::ArraySlice<Node*> reverse_post_order,
+                         absl::Span<Node* const> reverse_post_order,
                          PredicateMapTy* out_predicate_map) {
   DeadnessAnalysisImpl impl(&graph);
   TF_RETURN_IF_ERROR(impl.PopulateWithReversePostOrder(reverse_post_order));
@@ -32,7 +32,7 @@ Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map);
 // specified in `reverse_post_order` which must be a valid RPO for the graph
 // minus NextIteration->Merge edges.
 Status ComputePredicates(const Graph& graph,
-                         gtl::ArraySlice<Node*> reverse_post_order,
+                         absl::Span<Node* const> reverse_post_order,
                          PredicateMapTy* out_predicate_map);
 }  // namespace deadness_analysis_internal
 }  // namespace tensorflow
@@ -379,7 +379,7 @@ Node* InputShaped(const GraphDefBuilder::Options& opts) {
   return ops::SourceOp("InputTestShaped", opts);
 }
 
-Node* KnownShapeBase(DataType dtype, const gtl::ArraySlice<int>& shape,
+Node* KnownShapeBase(DataType dtype, absl::Span<const int> shape,
                      const GraphDefBuilder::Options& opts) {
   if (opts.HaveError()) return nullptr;
   NodeBuilder node_builder(opts.GetNameForOp("Const"), "Const",
@@ -394,7 +394,7 @@ Node* KnownShapeBase(DataType dtype, const gtl::ArraySlice<int>& shape,
       .FinalizeBuilder(&node_builder);
 }
 
-Node* KnownShape(const gtl::ArraySlice<int>& shape,
+Node* KnownShape(absl::Span<const int> shape,
                  const GraphDefBuilder::Options& opts) {
   return KnownShapeBase(DT_FLOAT, shape, opts);
 }
@@ -417,8 +417,7 @@ Node* KeyPlaceholder(const string& call_node,
 }
 
 Node* RecvAtHost(ops::NodeOut key_input, const string& cluster,
-                 const string& oc_cluster,
-                 const gtl::ArraySlice<DataType>& dtypes,
+                 const string& oc_cluster, absl::Span<const DataType> dtypes,
                  const GraphDefBuilder::Options& opts) {
   if (opts.HaveError()) return nullptr;
   string key =
@@ -892,13 +891,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"C:o:0", "c:o:0"},
-          {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
-           {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
-           {"ancestors", gtl::ArraySlice<string>({})},
+          {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
+           {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+           {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_O1"},
            {"shape_inference_graph",
             "_outside_compilation_shape_inference_F1_O1"},
-           {"shapes", gtl::ArraySlice<DataType>({})},
+           {"shapes", absl::Span<const DataType>({})},
            {"_outside_compilation_subgraph", "O1"}},
           {"c"}},
      },
@@ -1038,26 +1037,26 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
          {{"outside_compilation_O2_host_compute"},
           "XlaHostCompute",
           {"F:o:0", "D:o:0"},
-          {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
-           {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
+          {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
+           {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
           {"ancestors",
-            gtl::ArraySlice<string>({"outside_compilation_O1_host_compute"})},
+            absl::Span<const string>({"outside_compilation_O1_host_compute"})},
            {"key", "host_compute_channel_F1_O2"},
            {"shape_inference_graph",
             "_outside_compilation_shape_inference_F1_O2"},
-           {"shapes", gtl::ArraySlice<DataType>({})},
+           {"shapes", absl::Span<const DataType>({})},
            {"_outside_compilation_subgraph", "O2"}},
           {"F", "outside_compilation_O1_host_compute"}},
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"C:o:0", "D:o:0"},
-          {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
-           {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
-           {"ancestors", gtl::ArraySlice<string>({})},
+          {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
+           {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+           {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_O1"},
            {"shape_inference_graph",
             "_outside_compilation_shape_inference_F1_O1"},
-           {"shapes", gtl::ArraySlice<DataType>({})},
+           {"shapes", absl::Span<const DataType>({})},
            {"_outside_compilation_subgraph", "O1"}},
           {"D"}},
      },
@@ -1190,13 +1189,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"C:o:0", "D:o:0"},
-          {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
-           {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
-           {"ancestors", gtl::ArraySlice<string>({})},
+          {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
+           {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+           {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_O1"},
            {"shape_inference_graph",
            "_outside_compilation_shape_inference_F1_O1"},
-           {"shapes", gtl::ArraySlice<DataType>({})},
+           {"shapes", absl::Span<const DataType>({})},
            {"_outside_compilation_subgraph", "O1"}},
           {"D"}},
      },
@@ -1213,13 +1212,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"G:o:0"},
-          {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
-           {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
-           {"ancestors", gtl::ArraySlice<string>({})},
+          {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+           {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+           {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F2_O1"},
            {"shape_inference_graph", ""},
            {"shapes",
-            gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})},
+            absl::Span<const TensorShapeProto>({shape_proto_expected})},
            {"_outside_compilation_subgraph", "O1"}}},
      },
      {{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}});
@@ -1364,13 +1363,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"C:o:0", "D:o:0"},
-          {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
-           {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
-           {"ancestors", gtl::ArraySlice<string>({})},
+          {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
+           {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+           {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_O1"},
            {"shape_inference_graph",
            "_outside_compilation_shape_inference_F1_O1"},
-           {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+           {"shapes", absl::Span<const TensorShapeProto>({})},
            {"_outside_compilation_subgraph", "O1"}},
           {"D"}},
      },
@@ -1386,13 +1385,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"G:o:0"},
-          {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
-           {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
-           {"ancestors", gtl::ArraySlice<string>({})},
+          {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+           {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+           {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F2_O1"},
            {"shape_inference_graph",
            "_outside_compilation_shape_inference_F2_O1"},
-           {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+           {"shapes", absl::Span<const TensorShapeProto>({})},
            {"_outside_compilation_subgraph", "O1"}}},
      },
      {{"i_0_retval", "I:o:0"}});
@@ -1495,13 +1494,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {},
-          {{"Tinputs", gtl::ArraySlice<DataType>({})},
-           {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
-           {"ancestors", gtl::ArraySlice<string>({})},
+          {{"Tinputs", absl::Span<const DataType>({})},
+           {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+           {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_O1"},
            {"shape_inference_graph", ""},
            {"shapes",
-            gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})},
+            absl::Span<const TensorShapeProto>({shape_proto_expected})},
            {"_outside_compilation_subgraph", "O1"}}},
      },
      {{"f_0_retval", "F:o:0"}});
@@ -1579,13 +1578,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {},
-          {{"Tinputs", gtl::ArraySlice<DataType>({})},
-           {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
-           {"ancestors", gtl::ArraySlice<string>({})},
+          {{"Tinputs", absl::Span<const DataType>({})},
+           {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+           {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_O1"},
            {"shape_inference_graph", ""},
            {"shapes",
-            gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})},
+            absl::Span<const TensorShapeProto>({shape_proto_expected})},
            {"_outside_compilation_subgraph", "O1"}},
           {"D"}},
      },
@@ -1661,12 +1660,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"D:o:0"},
-          {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
-           {"Toutputs", gtl::ArraySlice<DataType>({})},
-           {"ancestors", gtl::ArraySlice<string>({})},
+          {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+           {"Toutputs", absl::Span<const DataType>({})},
+           {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_O1"},
            {"shape_inference_graph", ""},
-           {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+           {"shapes", absl::Span<const TensorShapeProto>({})},
            {"_outside_compilation_subgraph", "O1"}}},
      },
      {{"f_0_retval", "F:o:0"}});
@@ -1742,12 +1741,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"D:o:0"},
-          {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
-           {"Toutputs", gtl::ArraySlice<DataType>({})},
-           {"ancestors", gtl::ArraySlice<string>({})},
+          {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+           {"Toutputs", absl::Span<const DataType>({})},
+           {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_O1"},
            {"shape_inference_graph", ""},
-           {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+           {"shapes", absl::Span<const TensorShapeProto>({})},
            {"_outside_compilation_subgraph", "O1"}}},
      },
      {{"f_0_retval", "F:o:0"}});
@@ -1846,13 +1845,13 @@ TEST(EncapsulateSubgraphsTest,
          {{"outside_compilation_O2_host_compute"},
           "XlaHostCompute",
           {"F:o:0"},
-          {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
-           {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
-           {"ancestors", gtl::ArraySlice<string>({})},
+          {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+           {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+           {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_O2"},
            {"shape_inference_graph",
            "_outside_compilation_shape_inference_F1_O2"},
-           {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+           {"shapes", absl::Span<const TensorShapeProto>({})},
            {"_outside_compilation_subgraph", "O2"}}},
      },
      {{"h_0_retval", "H:o:0"}});
@@ -1955,13 +1954,13 @@ TEST(EncapsulateSubgraphsTest,
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"D:o:0"},
-          {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
-           {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
-           {"ancestors", gtl::ArraySlice<string>({})},
+          {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+           {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+           {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_O1"},
            {"shape_inference_graph",
            "_outside_compilation_shape_inference_F1_O1"},
-           {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+           {"shapes", absl::Span<const TensorShapeProto>({})},
            {"_outside_compilation_subgraph", "O1"}}},
      },
      {{"h_0_retval", "H:o:0"}});
@@ -2066,37 +2065,37 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"D:o:0"},
-          {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
-           {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
-           {"ancestors", gtl::ArraySlice<string>({})},
+          {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+           {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+           {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_O1"},
            {"shape_inference_graph",
            "_outside_compilation_shape_inference_F1_O1"},
-           {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+           {"shapes", absl::Span<const TensorShapeProto>({})},
            {"_outside_compilation_subgraph", "O1"}}},
          {{"outside_compilation_O2_host_compute"},
           "XlaHostCompute",
           {"D:o:0"},
-          {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
-           {"Toutputs", gtl::ArraySlice<DataType>({})},
+          {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+           {"Toutputs", absl::Span<const DataType>({})},
           {"ancestors",
-            gtl::ArraySlice<string>({"outside_compilation_O1_host_compute"})},
+            absl::Span<const string>({"outside_compilation_O1_host_compute"})},
            {"key", "host_compute_channel_F1_O2"},
            {"shape_inference_graph", ""},
-           {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+           {"shapes", absl::Span<const TensorShapeProto>({})},
           {"_outside_compilation_subgraph", "O2"}},
          {"outside_compilation_O1_host_compute"}},
          {{"outside_compilation_O3_host_compute"},
           "XlaHostCompute",
           {"D:o:0"},
-          {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
-           {"Toutputs", gtl::ArraySlice<DataType>({})},
+          {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+           {"Toutputs", absl::Span<const DataType>({})},
           {"ancestors",
-            gtl::ArraySlice<string>({"outside_compilation_O1_host_compute",
+            absl::Span<const string>({"outside_compilation_O1_host_compute",
                                      "outside_compilation_O2_host_compute"})},
            {"key", "host_compute_channel_F1_O3"},
            {"shape_inference_graph", ""},
-           {"shapes", gtl::ArraySlice<TensorShapeProto>({})},
+           {"shapes", absl::Span<const TensorShapeProto>({})},
           {"_outside_compilation_subgraph", "O3"}},
          {"outside_compilation_O1_host_compute",
           "outside_compilation_O2_host_compute"}}},
@@ -2272,13 +2271,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"c:o:0"},
-          {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
-           {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
-           {"ancestors", gtl::ArraySlice<string>({})},
+          {{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
+           {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
+           {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_O1"},
            {"shape_inference_graph",
            "_outside_compilation_shape_inference_F1_O1"},
-           {"shapes", gtl::ArraySlice<DataType>({})},
+           {"shapes", absl::Span<const DataType>({})},
            {"_outside_compilation_subgraph", "O1"}},
           {"c"}},
      },
@@ -22,7 +22,7 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet<Node*>* result,
-                            gtl::ArraySlice<Node*> post_order) {
+                            absl::Span<Node* const> post_order) {
   // Find nodes that have at least one user outside their cluster that expects
   // hostmem output. These nodes should be cloned to outside the cluster to
   // avoid the device-host copy we'd otherwise need.
@@ -275,13 +275,13 @@ class OpTest : public ::testing::Test {
 
   // Select a random element from 'candidates'.
   template <typename T>
-  T Choose(gtl::ArraySlice<T> candidates);
+  T Choose(absl::Span<const T> candidates);
 
   static constexpr int kDefaultMaxRank = 5;
   static constexpr int64 kDefaultMaxDimensionSize = 256LL;
 
   // Returns true if 'dims' have a size less than tf_xla_max_tensor_size.
-  bool TensorSizeIsOk(gtl::ArraySlice<int64> dims);
+  bool TensorSizeIsOk(absl::Span<const int64> dims);
 
   // Returns a random dimension size, in the range [min, max).
   int64 RandomDim(int64 min = 0, int64 max = kDefaultMaxDimensionSize);
@@ -307,11 +307,11 @@ class OpTest : public ::testing::Test {
   // of the type's range. If the shape is omitted, a random shape is used.
   // TODO(phawkins): generalize this code to a caller-supplied distribution.
   Tensor RandomTensor(DataType dtype, bool needs_unique_values,
-                      gtl::ArraySlice<int64> shape);
+                      absl::Span<const int64> shape);
   Tensor RandomTensor(DataType dtype);
 
   // Like RandomTensor, but uses values >= 0.
-  Tensor RandomNonNegativeTensor(DataType dtype, gtl::ArraySlice<int64> shape);
+  Tensor RandomNonNegativeTensor(DataType dtype, absl::Span<const int64> shape);
   Tensor RandomNonNegativeTensor(DataType dtype);
 
   // Returns a random subset of the integers in the range [0, rank), suitable
@@ -415,7 +415,7 @@ void OpTest::Repeatedly(const std::function<TestResult(void)>& fn) {
 }
 
 template <typename T>
-T OpTest::Choose(gtl::ArraySlice<T> candidates) {
+T OpTest::Choose(absl::Span<const T> candidates) {
   std::uniform_int_distribution<size_t> d(0, candidates.size() - 1);
   return candidates[d(generator())];
 }
@@ -425,7 +425,7 @@ int64 OpTest::RandomDim(int64 min, int64 max) {
   return size_distribution(generator());
 }
 
-bool OpTest::TensorSizeIsOk(gtl::ArraySlice<int64> dims) {
+bool OpTest::TensorSizeIsOk(absl::Span<const int64> dims) {
   int64 size = 1LL;
   for (int64 dim : dims) {
     size *= dim;
@@ -451,7 +451,7 @@ std::vector<int64> OpTest::RandomDims(int min_rank, int max_rank,
 }
 
 Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
-                            gtl::ArraySlice<int64> shape) {
+                            absl::Span<const int64> shape) {
   Tensor tensor(dtype, TensorShape(shape));
   switch (dtype) {
     case DT_FLOAT: {
@@ -548,7 +548,7 @@ Tensor OpTest::RandomTensor(DataType dtype) {
 }
 
 Tensor OpTest::RandomNonNegativeTensor(DataType dtype,
-                                       gtl::ArraySlice<int64> shape) {
+                                       absl::Span<const int64> shape) {
   Tensor tensor(dtype, TensorShape(shape));
   switch (dtype) {
     case DT_FLOAT: {
@@ -1884,7 +1884,7 @@ TEST_F(OpTest, DynamicStitch) {
     for (int i = 0; i < n; ++i) {
       TensorShape shape(index_dims[i]);
       Tensor t = test::AsTensor<int32>(
-          gtl::ArraySlice<int32>(indices).subspan(pos, shape.num_elements()),
+          absl::Span<const int32>(indices).subspan(pos, shape.num_elements()),
          shape);
       builder.Input(t);
       pos += t.NumElements();
@@ -805,10 +805,10 @@ TEST(FunctionalizeControlFlow, Complex) {
     auto assign = ops::AssignAddVariableOp(
         scope.WithOpName("outer/inner/assign_add"), enter_var, add_jkx);
 
-    auto one =
-        ops::Const<int32>(scope.WithOpName("outer/inner/One")
+    auto one = ops::Const<int32>(
+        scope.WithOpName("outer/inner/One")
            .WithControlDependencies(
-                gtl::ArraySlice<Operation>{assign.operation}),
+                absl::Span<const Operation>{assign.operation}),
         1);
     auto add_j =
         ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);
@@ -823,7 +823,7 @@ TEST(FunctionalizeControlFlow, Complex) {
         scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
     auto add_i =
         ops::Add(scope.WithOpName("outer/add")
-                     .WithControlDependencies(gtl::ArraySlice<Operation>{
+                     .WithControlDependencies(absl::Span<const Operation>{
                          exit_j.output.op(), exit_k.output.op()}),
                  identity_i, one_outer);
     auto next_iteration_i =
@@ -929,7 +929,7 @@ TEST(FunctionalizeControlFlow, Complex) {
         scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
     auto add_i =
         ops::Add(scope.WithOpName("outer/add")
-                     .WithControlDependencies(gtl::ArraySlice<Operation>{
+                     .WithControlDependencies(absl::Span<const Operation>{
                          while_op[0].op(), while_op[1].op()}),
                  identity_i, one_outer);
 
@@ -991,10 +991,10 @@ TEST(FunctionalizeControlFlow, Complex) {
     auto assign = ops::AssignAddVariableOp(
         scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx);
 
-    auto one =
-        ops::Const<int32>(scope.WithOpName("outer/inner/One")
+    auto one = ops::Const<int32>(
+        scope.WithOpName("outer/inner/One")
            .WithControlDependencies(
-                gtl::ArraySlice<Operation>{assign.operation}),
+                absl::Span<const Operation>{assign.operation}),
         1);
     auto add_j =
         ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);
@@ -23,7 +23,7 @@ namespace {
 
 void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input,
                   DataType input_dtype, const TensorShape& input_tensor_shape,
-                  gtl::ArraySlice<int64> block_shape,
+                  absl::Span<const int64> block_shape,
                   const xla::Literal& crops) {
   const int input_rank = input_tensor_shape.dims();
   const gtl::InlinedVector<int64, 4> input_shape =
@@ -34,7 +34,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input,
       ctx, input_rank >= 1 + block_rank,
       errors::InvalidArgument("input rank should be >= ", 1 + block_rank,
                               " instead of ", input_rank));
-  gtl::ArraySlice<int64> remainder_shape(input_shape);
+  absl::Span<const int64> remainder_shape(input_shape);
   remainder_shape.remove_prefix(1 + block_rank);
 
   OP_REQUIRES(
@@ -36,8 +36,8 @@ namespace {
     explicit NAME##Op(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {} \
     xla::XlaOp Computation( \
        XlaOpKernelContext* ctx, const xla::XlaOp& lhs, \
-        const gtl::ArraySlice<int64>& lhs_shape, const xla::XlaOp& rhs, \
-        const gtl::ArraySlice<int64>& rhs_shape, \
+        const absl::Span<const int64>& lhs_shape, const xla::XlaOp& rhs, \
+        const absl::Span<const int64>& rhs_shape, \
        const BCast& broadcast_helper, \
        const std::vector<int64>& extend_dimensions) override { \
       xla::XlaBuilder* b = ctx->builder(); \
@@ -57,8 +57,8 @@ class XlaBinaryOp : public XlaOpKernel {
   // in the XLA documentation.
   virtual xla::XlaOp Computation(
       XlaOpKernelContext* ctx, const xla::XlaOp& lhs,
-      const gtl::ArraySlice<int64>& lhs_shape, const xla::XlaOp& rhs,
-      const gtl::ArraySlice<int64>& rhs_shape, const BCast& broadcast_helper,
+      const absl::Span<const int64>& lhs_shape, const xla::XlaOp& rhs,
+      const absl::Span<const int64>& rhs_shape, const BCast& broadcast_helper,
       const std::vector<int64>& extend_dimensions) = 0;
 
   void Compile(XlaOpKernelContext* ctx) override;
@@ -29,7 +29,7 @@ namespace {
 
 // Create a diagonal / batch diagonal matrix with 'input' on the diagonal.
 xla::XlaOp CreateDiagonal(xla::XlaOp input, int64 last_dim_size,
-                          gtl::ArraySlice<int64> other_dims,
+                          absl::Span<const int64> other_dims,
                           xla::PrimitiveType element_type) {
   xla::XlaBuilder* builder = input.builder();
   // Create two matrices that have the following forms, and compare them:
@@ -177,7 +177,7 @@ class MatrixDiagOp : public XlaOpKernel {
 
     int last_dim = dims.size() - 1;
     int64 last_dim_size = input_shape.dim_size(last_dim);
-    tensorflow::gtl::ArraySlice<int64> other_dims(dims);
+    absl::Span<const int64> other_dims(dims);
     other_dims.remove_suffix(1);
 
    xla::XlaOp input = ctx->Input(0);
@@ -78,7 +78,7 @@ struct ResizeConvolutionDims {
   std::vector<int64> stride;
 };
 ResizeConvolutionDims ComputeResizeConvolutionParameters(
-    gtl::ArraySlice<int64> in_size, gtl::ArraySlice<int64> out_size,
+    absl::Span<const int64> in_size, absl::Span<const int64> out_size,
     bool align_corners) {
   CHECK_EQ(in_size.size(), out_size.size());
   int num_spatial_dims = in_size.size();
@@ -147,7 +147,7 @@ std::vector<float> Make1DKernel(int64 n) {
 const int64 kMax2DKernelSize = 16;
 
 xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder,
-                                    gtl::ArraySlice<int64> kernel_size,
+                                    absl::Span<const int64> kernel_size,
                                     int64 channels) {
   xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels);
 
@@ -165,7 +165,7 @@ xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder,
 }
 
 xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder,
-                                         gtl::ArraySlice<int64> kernel_size,
+                                         absl::Span<const int64> kernel_size,
                                          int64 channels, int64 dim) {
   xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels);
 
@@ -155,7 +155,8 @@ class RandomShuffleOp : public XlaOpKernel {
       xla::XlaOp indices = xla::Iota(builder, xla::S32, n);
 
       // Swap the indices at i and swaps[i].
-      auto swap_body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
+      auto swap_body_fn = [&](xla::XlaOp i,
+                              absl::Span<const xla::XlaOp> loop_vars,
                               xla::XlaBuilder* builder)
           -> xla::StatusOr<std::vector<xla::XlaOp>> {
         auto swaps = loop_vars[0];
@@ -66,7 +66,7 @@ class SelectOp : public XlaOpKernel {
       // XLA. It seems we have to broadcast on the left and then Reshape
      // to get the dimensions in the right order.
       const auto dim_sizes = then_shape.dim_sizes();
-      gtl::ArraySlice<int64> bdims = dim_sizes;
+      absl::Span<const int64> bdims = dim_sizes;
       bdims.remove_prefix(1);
       cond_handle = xla::Broadcast(cond_handle, bdims);
 
@@ -23,7 +23,7 @@ namespace {
 
 void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input,
                   DataType input_dtype, const TensorShape& input_tensor_shape,
-                  gtl::ArraySlice<int64> block_shape,
+                  absl::Span<const int64> block_shape,
                   const xla::Literal& paddings) {
   const int input_rank = input_tensor_shape.dims();
   const gtl::InlinedVector<int64, 4> input_shape =
@@ -34,7 +34,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input,
       ctx, input_rank >= 1 + block_rank,
       errors::InvalidArgument("input rank should be >= ", 1 + block_rank,
                               " instead of ", input_rank));
-  gtl::ArraySlice<int64> remainder_shape(input_shape);
+  absl::Span<const int64> remainder_shape(input_shape);
   remainder_shape.remove_prefix(1 + block_rank);
 
   OP_REQUIRES(
@@ -122,7 +122,7 @@ Status GetTensorArrayShape(const XlaResource* resource,
 // relevant slice of 'operand'.
 xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand,
                            const xla::XlaOp& update,
-                           const gtl::ArraySlice<int64>& update_dims,
+                           absl::Span<const int64> update_dims,
                            const xla::XlaOp& start_indices) {
   xla::XlaOp current = xla::DynamicSlice(operand, start_indices, update_dims);
   xla::XlaOp sum = xla::Add(current, update);
@@ -64,7 +64,7 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a,
   xla::XlaOp l = xla::ZerosLike(a);
 
   // Construct the for loop body to iterate over rows.
-  auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
+  auto body_fn = [&](xla::XlaOp i, absl::Span<const xla::XlaOp> loop_vars,
                      xla::XlaBuilder* body_builder)
       -> xla::StatusOr<std::vector<xla::XlaOp>> {
     xla::Shape col_shape;
@@ -65,9 +65,9 @@ namespace {
 //   return (v, tau, beta)
 // TODO(phawkins): LAPACK's xLARFG implementation has code for handling
 // overflows in the norm/beta calculations. Perhaps do the same here.
-xla::Status House(xla::XlaOp x, xla::XlaOp k, gtl::ArraySlice<int64> batch_dims,
-                  const int64 m, xla::XlaOp* v, xla::XlaOp* tau,
-                  xla::XlaOp* beta) {
+xla::Status House(xla::XlaOp x, xla::XlaOp k,
+                  absl::Span<const int64> batch_dims, const int64 m,
+                  xla::XlaOp* v, xla::XlaOp* tau, xla::XlaOp* beta) {
   xla::XlaBuilder* const builder = x.builder();
   TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x));
   const xla::PrimitiveType type = x_shape.element_type();
@@ -173,7 +173,7 @@ xla::StatusOr<QRBlockResult> QRBlock(
   std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
 
   auto qr_body_fn =
-      [&](xla::XlaOp j, gtl::ArraySlice<xla::XlaOp> values,
+      [&](xla::XlaOp j, absl::Span<const xla::XlaOp> values,
           xla::XlaBuilder* builder) -> xla::StatusOr<std::vector<xla::XlaOp>> {
     auto a = values[0];
     auto vs = values[1];
@@ -255,7 +255,7 @@ xla::StatusOr<QRBlockResult> QRBlock(
 // There is no need to return Y since at termination of the loop it is equal to
 // vs.
 xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
-    xla::PrimitiveType type, gtl::ArraySlice<int64> batch_dims, xla::XlaOp vs,
+    xla::PrimitiveType type, absl::Span<const int64> batch_dims, xla::XlaOp vs,
     xla::XlaOp taus, int64 m, int64 n,
     xla::PrecisionConfigProto::Precision precision) {
   std::vector<int64> batch_dim_indices(batch_dims.size());
@@ -263,7 +263,7 @@ xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
   int64 n_index = batch_dims.size() + 1;
 
   auto body_fn =
-      [&](xla::XlaOp j, gtl::ArraySlice<xla::XlaOp> values,
+      [&](xla::XlaOp j, absl::Span<const xla::XlaOp> values,
           xla::XlaBuilder* builder) -> xla::StatusOr<std::vector<xla::XlaOp>> {
     auto w = values[0];
     auto y = values[1];
@@ -40,9 +40,9 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
   TF_ASSIGN_OR_RETURN(xla::Shape buffer_shape, builder->GetShape(buffer));
   TF_RETURN_IF_ERROR(builder->GetShape(updates).status());
   TF_ASSIGN_OR_RETURN(xla::Shape indices_shape, builder->GetShape(indices));
-  gtl::ArraySlice<int64> indices_dims =
+  absl::Span<const int64> indices_dims =
       xla::AsInt64Slice(indices_shape.dimensions());
-  gtl::ArraySlice<int64> buffer_dims =
+  absl::Span<const int64> buffer_dims =
       xla::AsInt64Slice(buffer_shape.dimensions());
 
   // If the indices are N-dimensional, the minor dimension of indices contains
@@ -107,7 +107,7 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
   //   index = dynamic-slice(indices, i)
   //   update = dynamic-slice(updates, i)
   //   buffer = dynamic-update-slice(buffer, update, index)
-  auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
+  auto body_fn = [&](xla::XlaOp i, absl::Span<const xla::XlaOp> loop_vars,
                      xla::XlaBuilder* body_builder) {
     auto indices = loop_vars[0];
     auto updates = loop_vars[1];
@@ -113,8 +113,8 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
   return xla::ConstantLiteral(builder, literal);
 }
 
-xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice<int64> start,
-                            gtl::ArraySlice<int64> end) {
+xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span<const int64> start,
+                            absl::Span<const int64> end) {
   xla::XlaBuilder* builder = x.builder();
   return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
     TF_RET_CHECK(start.size() == end.size());
@@ -144,8 +144,8 @@ xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice<int64> start,
   });
 }
 
-std::vector<int64> ConcatVectors(gtl::ArraySlice<int64> xs,
-                                 gtl::ArraySlice<int64> ys) {
+std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
+                                 absl::Span<const int64> ys) {
   std::vector<int64> output(xs.size() + ys.size());
   std::copy(xs.begin(), xs.end(), output.begin());
   std::copy(ys.begin(), ys.end(), output.begin() + xs.size());
@@ -153,8 +153,8 @@ std::vector<int64> ConcatVectors(gtl::ArraySlice<int64> xs,
 }
 
 xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x,
-                                   gtl::ArraySlice<xla::XlaOp> starts,
-                                   gtl::ArraySlice<int64> sizes) {
+                                   absl::Span<const xla::XlaOp> starts,
+                                   absl::Span<const int64> sizes) {
   xla::XlaBuilder* builder = x.builder();
   return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
     TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
@@ -173,7 +173,7 @@ xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x,
 }
 
 xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update,
-                       gtl::ArraySlice<int64> start) {
+                       absl::Span<const int64> start) {
   xla::XlaBuilder* builder = x.builder();
   return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
     // TODO(phawkins): make int64 work on all backends, remove the int32 cast.
@@ -191,7 +191,7 @@ xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update,
 }
 
 xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
-                                  gtl::ArraySlice<int64> start) {
+                                  absl::Span<const int64> start) {
   xla::XlaBuilder* builder = x.builder();
   return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
     TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
@@ -206,13 +206,13 @@ xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
 }
 
 xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
-                                         gtl::ArraySlice<xla::XlaOp> starts) {
+                                         absl::Span<const xla::XlaOp> starts) {
   auto padded_starts = PrependZerosInMajorDims(x, starts);
   return xla::DynamicUpdateSlice(x, update, padded_starts);
 }
 
 xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x,
-                                   gtl::ArraySlice<xla::XlaOp> starts) {
+                                   absl::Span<const xla::XlaOp> starts) {
   xla::XlaBuilder* builder = x.builder();
   return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
     TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
@@ -31,7 +31,7 @@ xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
 // Makes a 1D tensor [0, ..., x, y] from two tensors x and y with zeros
 // prepended until the array is length n_dims.
 xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x,
-                                   gtl::ArraySlice<xla::XlaOp> starts);
+                                   absl::Span<const xla::XlaOp> starts);
 
 // Returns a integer scalar constant of 'type' with 'value'.
 // If 'type' is complex, returns a real value with zero imaginary component.
@@ -41,33 +41,33 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
 // Builds a vector of zeros of length rank(x) with the last values being
 // those in `starts`.
 xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x,
-                                   gtl::ArraySlice<xla::XlaOp> starts);
+                                   absl::Span<const xla::XlaOp> starts);
 
 // Performs a slice in the minor dimensions of a Tensor.
-xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice<int64> start,
-                            gtl::ArraySlice<int64> end);
+xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span<const int64> start,
+                            absl::Span<const int64> end);
 
 // Returns the concatenation of `xs` and `ys`.
-std::vector<int64> ConcatVectors(gtl::ArraySlice<int64> xs,
-                                 gtl::ArraySlice<int64> ys);
+std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
+                                 absl::Span<const int64> ys);
 
 // Performs a dynamic slice in the minor dimensions of a Tensor.
 xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x,
-                                   gtl::ArraySlice<xla::XlaOp> starts,
-                                   gtl::ArraySlice<int64> sizes);
+                                   absl::Span<const xla::XlaOp> starts,
+                                   absl::Span<const int64> sizes);
 
 // Updates a slice of 'x', i.e.,
 // x[start[0], ..., start[n]] = update
 xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update,
-                       gtl::ArraySlice<int64> start);
+                       absl::Span<const int64> start);
 
 // Updates a slice of 'x', where 'start' contains a list of minor dimensions:
 // x[..., start[0], ..., start[n]] = update
 xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
-                                  gtl::ArraySlice<int64> start);
+                                  absl::Span<const int64> start);
 
 xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
-                                         gtl::ArraySlice<xla::XlaOp> starts);
+                                         absl::Span<const xla::XlaOp> starts);
 
 // Transposes a stack of matrices `x` by swapping the last two dimensions.
 xla::XlaOp TransposeInMinorDims(xla::XlaOp x);
@@ -24,7 +24,7 @@ namespace tensorflow {
 xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
     const LoopConditionFunction& condition_function,
     const LoopBodyFunction& body_function,
-    gtl::ArraySlice<xla::XlaOp> initial_values, StringPiece name,
+    absl::Span<const xla::XlaOp> initial_values, StringPiece name,
     xla::XlaBuilder* builder) {
   int arity = initial_values.size();
   std::vector<xla::Shape> var_shapes;
@@ -84,15 +84,15 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
 xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
     int64 num_iterations, xla::PrimitiveType num_iterations_type,
     const ForEachIndexBodyFunction& body_function,
-    gtl::ArraySlice<xla::XlaOp> initial_values, StringPiece name,
+    absl::Span<const xla::XlaOp> initial_values, StringPiece name,
     xla::XlaBuilder* builder) {
   auto while_cond_fn =
-      [&](gtl::ArraySlice<xla::XlaOp> values,
+      [&](absl::Span<const xla::XlaOp> values,
           xla::XlaBuilder* cond_builder) -> xla::StatusOr<xla::XlaOp> {
     return xla::Lt(values[0], IntegerLiteral(cond_builder, num_iterations_type,
                                              num_iterations));
   };
-  auto while_body_fn = [&](gtl::ArraySlice<xla::XlaOp> values,
+  auto while_body_fn = [&](absl::Span<const xla::XlaOp> values,
                            xla::XlaBuilder* body_builder)
       -> xla::StatusOr<std::vector<xla::XlaOp>> {
     xla::XlaOp iteration = values[0];
@@ -29,14 +29,14 @@ namespace tensorflow {
 
 // Function that builds a loop condition. Takes as input a sequence of input
 // values, and returns a boolean value representing if the condition succeeds.
-typedef std::function<xla::StatusOr<xla::XlaOp>(gtl::ArraySlice<xla::XlaOp>,
+typedef std::function<xla::StatusOr<xla::XlaOp>(absl::Span<const xla::XlaOp>,
                                                 xla::XlaBuilder*)>
     LoopConditionFunction;
 
 // Function that builds a loop body. Takes as input a sequence of input values
 // and returns a sequence of output values.
 typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>(
-    gtl::ArraySlice<xla::XlaOp>, xla::XlaBuilder*)>
+    absl::Span<const xla::XlaOp>, xla::XlaBuilder*)>
     LoopBodyFunction;
 
 // Helper function for building an XLA while loop, where the values carried by
@@ -50,7 +50,7 @@ typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>(
 xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
     const LoopConditionFunction& condition_function,
     const LoopBodyFunction& body_function,
-    gtl::ArraySlice<xla::XlaOp> initial_values, StringPiece name,
+    absl::Span<const xla::XlaOp> initial_values, StringPiece name,
     xla::XlaBuilder* builder);
 
 // Builds an XLA loop that repeats a computation `num_iterations` times.
@@ -59,13 +59,13 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
 // (current iteration number, loop-carried values), and returns an updated
 // vector of the loop-carried values.
 typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>(
-    xla::XlaOp, gtl::ArraySlice<xla::XlaOp>, xla::XlaBuilder*)>
+    xla::XlaOp, absl::Span<const xla::XlaOp>, xla::XlaBuilder*)>
    ForEachIndexBodyFunction;
 
 xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
     int64 num_iterations, xla::PrimitiveType num_iterations_type,
     const ForEachIndexBodyFunction& body_function,
-    gtl::ArraySlice<xla::XlaOp> initial_values, StringPiece name,
+    absl::Span<const xla::XlaOp> initial_values, StringPiece name,
     xla::XlaBuilder* builder);
 
 }  // namespace tensorflow
@@ -49,8 +49,7 @@ Status HostTensorToMutableBorrowingLiteral(
   return Status::OK();
 }
 
-Status HostTensorsToBorrowingLiteralTuple(
-    tensorflow::gtl::ArraySlice<Tensor> host_tensors,
+Status HostTensorsToBorrowingLiteralTuple(absl::Span<const Tensor> host_tensors,
                                           xla::BorrowingLiteral* literal) {
   std::vector<const char*> buf_ptrs;
   buf_ptrs.reserve(host_tensors.size());
@@ -43,8 +43,7 @@ Status HostTensorToMutableBorrowingLiteral(
 
 // Returns a BorrowingLiteral tuple that utilizes the same underlying buffers
 // owned by 'host_tensors'.
-Status HostTensorsToBorrowingLiteralTuple(
-    tensorflow::gtl::ArraySlice<Tensor> host_tensors,
+Status HostTensorsToBorrowingLiteralTuple(absl::Span<const Tensor> host_tensors,
                                           xla::BorrowingLiteral* literal);
 
 // Copies 'literal' to freshly allocated 'host_tensor', which is allocated of
@@ -28,7 +28,7 @@ TEST(LiteralUtil, LiteralToHostTensor) {
   {
     std::vector<int64> int64_values = {1, 2, 3};
     std::unique_ptr<xla::Literal> int64_values_literal =
-        xla::LiteralUtil::CreateR1(gtl::ArraySlice<int64>(int64_values));
+        xla::LiteralUtil::CreateR1(absl::Span<const int64>(int64_values));
     Tensor host_tensor;
     EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32",
               LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor)
@@ -49,7 +49,7 @@ TEST(LiteralUtil, LiteralToHostTensor) {
     Tensor host_tensor;
     std::vector<int32> int32_values = {10, 11};
     std::unique_ptr<xla::Literal> int32_values_literal =
-        xla::LiteralUtil::CreateR1(gtl::ArraySlice<int32>(int32_values));
+        xla::LiteralUtil::CreateR1(absl::Span<const int32>(int32_values));
     EXPECT_TRUE(
         LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor)
             .ok());
@@ -835,8 +835,8 @@ Status XlaCompiler::GetDeviceToHostChannelHandle(const string& key,
 
 namespace {
 
-void SetTransfer(const string& key, gtl::ArraySlice<DataType> types,
-                 gtl::ArraySlice<TensorShape> shapes,
+void SetTransfer(const string& key, absl::Span<const DataType> types,
+                 absl::Span<const TensorShape> shapes,
                  tf2xla::HostTransferMetadata* transfer) {
   transfer->set_key(key);
   CHECK(types.size() == shapes.size());
@@ -850,8 +850,8 @@ void SetTransfer(const string& key, gtl::ArraySlice<DataType> types,
 }  // namespace
 
 Status XlaCompiler::SetDeviceToHostMetadata(
-    const string& key, gtl::ArraySlice<DataType> types,
-    gtl::ArraySlice<TensorShape> shapes) {
+    const string& key, absl::Span<const DataType> types,
+    absl::Span<const TensorShape> shapes) {
   if (host_compute_sends_.find(key) != host_compute_sends_.end()) {
     return errors::InvalidArgument(
         "Duplicate calls to SetDeviceToHostMetadata with key ", key);
@@ -877,8 +877,8 @@ Status XlaCompiler::GetDeviceToHostShapes(
 }
 
 Status XlaCompiler::SetHostToDeviceMetadata(
-    const string& key, gtl::ArraySlice<DataType> types,
-    gtl::ArraySlice<TensorShape> shapes) {
+    const string& key, absl::Span<const DataType> types,
+    absl::Span<const TensorShape> shapes) {
   if (host_compute_recvs_.find(key) != host_compute_sends_.end()) {
     return errors::InvalidArgument(
         "Duplicate calls to SetHostToDeviceMetadata with key ", key);
@@ -351,8 +351,8 @@ class XlaCompiler {
   // Sets the shapes and types for the device to host transfer associated with
   // 'key'.
   Status SetDeviceToHostMetadata(const string& key,
-                                 gtl::ArraySlice<DataType> types,
-                                 gtl::ArraySlice<TensorShape> shapes);
+                                 absl::Span<const DataType> types,
+                                 absl::Span<const TensorShape> shapes);
 
   // Gets the shapes the device to host transfer associated with 'key'.
   Status GetDeviceToHostShapes(const string& key,
@@ -361,8 +361,8 @@ class XlaCompiler {
   // Sets the shapes and types for the host to device transfer associated with
   // 'key'.
   Status SetHostToDeviceMetadata(const string& key,
-                                 gtl::ArraySlice<DataType> types,
-                                 gtl::ArraySlice<TensorShape> shapes);
+                                 absl::Span<const DataType> types,
+                                 absl::Span<const TensorShape> shapes);
 
   // In order to avoid deadlocks from dependencies in host computations, it can
   // be necessary to enforce a partial order on the execution of HostCompute
@@ -119,7 +119,7 @@ xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type,
 }
 
 /* static */ Status XlaHelpers::ReshapeLiteral(
-    const xla::Literal& input, gtl::ArraySlice<int64> dimensions,
+    const xla::Literal& input, absl::Span<const int64> dimensions,
     xla::Literal* output) {
   if (xla::ShapeUtil::IsTuple(input.shape())) {
     return errors::InvalidArgument("ReshapeLiteral does not support tuples.");
@@ -50,7 +50,7 @@ class XlaHelpers {
   // Reshapes literal 'input' to have 'shape'. Both the original shape and
   // 'shape' must contain the same number of elements.
   static Status ReshapeLiteral(const xla::Literal& input,
-                               gtl::ArraySlice<int64> shape,
+                               absl::Span<const int64> shape,
                                xla::Literal* output);
 
   // Returns the argmax of `input` along `axis`. `output_type` is the type to
@@ -119,7 +119,7 @@ Status XlaOpKernelContext::ConstantInput(StringPiece name,
 }
 
 Status XlaOpKernelContext::ConstantInputReshaped(
-    int index, gtl::ArraySlice<int64> new_dims,
+    int index, absl::Span<const int64> new_dims,
     xla::Literal* constant_literal) {
   const Tensor& tensor = context_->input(index);
   TensorShape new_shape(new_dims);
@@ -113,7 +113,7 @@ class XlaOpKernelContext {
   // cannot be evaluated, e.g., because it depends on unbound parameters,
   // returns a non-Ok status. If InputShape(index).num_elements() !=
   // new_shape.num_elements(), returns an error status.
-  Status ConstantInputReshaped(int index, gtl::ArraySlice<int64> new_shape,
+  Status ConstantInputReshaped(int index, absl::Span<const int64> new_dims,
                                xla::Literal* constant_literal);
 
   // Converts a constant scalar int32 or int64 tensor into an int64.
@ -105,7 +105,7 @@ XlaOpRegistry::~XlaOpRegistry() = default;
|
||||
|
||||
/* static */ void XlaOpRegistry::RegisterBackend(
|
||||
const string& compilation_device_name,
|
||||
gtl::ArraySlice<DataType> supported_types, BackendOpFilter op_filter) {
|
||||
absl::Span<const DataType> supported_types, BackendOpFilter op_filter) {
|
||||
XlaOpRegistry& registry = Instance();
|
||||
mutex_lock lock(registry.mutex_);
|
||||
auto result = registry.backends_.emplace(compilation_device_name, Backend());
|
||||
@ -382,7 +382,7 @@ XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(StringPiece name) {
|
||||
}
|
||||
|
||||
XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(
|
||||
gtl::ArraySlice<StringPiece> devices) {
|
||||
absl::Span<const StringPiece> devices) {
|
||||
registration_->has_device_whitelist = true;
|
||||
for (StringPiece device : devices) {
|
||||
registration_->device_whitelist.emplace(device);
|
||||
@ -415,7 +415,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint(
|
||||
}
|
||||
|
||||
XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint(
|
||||
StringPiece attr_name, gtl::ArraySlice<DataType> allowed) {
|
||||
StringPiece attr_name, absl::Span<const DataType> allowed) {
|
||||
std::set<DataType>& types =
|
||||
registration_->type_constraints[string(attr_name)];
|
||||
for (DataType t : allowed) {
|
||||
@ -452,7 +452,7 @@ XlaOpRegistrar::XlaOpRegistrar(
|
||||
}
|
||||
|
||||
XlaBackendRegistrar::XlaBackendRegistrar(
|
||||
StringPiece name, gtl::ArraySlice<DataType> types,
|
||||
StringPiece name, absl::Span<const DataType> types,
|
||||
XlaOpRegistry::BackendOpFilter op_filter) {
|
||||
XlaOpRegistry& registry = XlaOpRegistry::Instance();
|
||||
registry.RegisterBackend(string(name), types, op_filter);
|
||||
|
@ -94,7 +94,7 @@ class XlaOpRegistry {
|
||||
// the device; it may optionally modify the KernelDef.
|
||||
typedef bool (*BackendOpFilter)(KernelDef* kdef);
|
||||
static void RegisterBackend(const string& compilation_device_name,
|
||||
gtl::ArraySlice<DataType> supported_types,
|
||||
absl::Span<const DataType> supported_types,
|
||||
BackendOpFilter op_filter);
|
||||
|
||||
// Returns the names of the registered backends.
|
||||
@ -236,7 +236,7 @@ class XlaOpRegistrationBuilder {
|
||||
|
||||
// Specifies a whitelist of devices on which the operator may run.
|
||||
XlaOpRegistrationBuilder& Device(StringPiece devices);
|
||||
XlaOpRegistrationBuilder& Device(gtl::ArraySlice<StringPiece> devices);
|
||||
XlaOpRegistrationBuilder& Device(absl::Span<const StringPiece> devices);
|
||||
|
||||
// Specifies a type constraint for a type variable attribute. Each constraint
|
||||
// specifies the set of types that the type variable may assume.
|
||||
@ -244,7 +244,7 @@ class XlaOpRegistrationBuilder {
|
||||
DataType allowed);
|
||||
|
||||
XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name,
|
||||
gtl::ArraySlice<DataType> allowed);
|
||||
absl::Span<const DataType> allowed);
|
||||
|
||||
// Specifies that a dummy copy of this operator should not be registered on
|
||||
// XLA_* devices, but may be used during compilation.
|
||||
@ -288,7 +288,7 @@ class XlaOpRegistrar {
|
||||
|
||||
class XlaBackendRegistrar {
|
||||
public:
|
||||
XlaBackendRegistrar(StringPiece name, gtl::ArraySlice<DataType> types,
|
||||
XlaBackendRegistrar(StringPiece name, absl::Span<const DataType> types,
|
||||
XlaOpRegistry::BackendOpFilter op_filter = nullptr);
|
||||
};
|
||||
|
||||
|
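Since TypeConstraint and RegisterBackend now take absl::Span<const DataType>, braced lists of DataType values continue to work at registration sites. A hypothetical registration illustrating the builder interface above, not code from this commit (the op and kernel names are invented):

// The braced list converts to absl::Span<const DataType> for the
// duration of the call, exactly as it did for gtl::ArraySlice.
REGISTER_XLA_OP(Name("MyHypotheticalOp")
                    .TypeConstraint("T", {DT_FLOAT, DT_INT32}),
                MyHypotheticalOpKernel);
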
@ -97,12 +97,11 @@ class Array {
using value_type = T;

// Creates a new array with the specified dimensions.
explicit Array(tensorflow::gtl::ArraySlice<int64> sizes)
: Array(sizes, T()) {}
explicit Array(absl::Span<const int64> sizes) : Array(sizes, T()) {}

// Creates a new array with the specified dimensions and specified value for
// every cell.
Array(tensorflow::gtl::ArraySlice<int64> sizes, T value)
Array(absl::Span<const int64> sizes, T value)
: sizes_(sizes.begin(), sizes.end()), values_(new T[num_elements()]) {
Fill(value);
}
@ -301,7 +300,7 @@ class Array {

// Invokes a callback with the (indices, value_ptr) for each cell in the
// array.
void Each(std::function<void(tensorflow::gtl::ArraySlice<int64>, T*)> f) {
void Each(std::function<void(absl::Span<const int64>, T*)> f) {
std::vector<int64> index(sizes_.size());
for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
f(index, &values_[i]);
@ -309,8 +308,7 @@ class Array {
}

// Invokes a callback with the (indices, value) for each cell in the array.
void Each(
std::function<void(tensorflow::gtl::ArraySlice<int64>, T)> f) const {
void Each(std::function<void(absl::Span<const int64>, T)> f) const {
std::vector<int64> index(sizes_.size());
for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
f(index, values_[i]);
@ -320,8 +318,7 @@ class Array {
// Invokes a callback with the (indices, value_ptr) for each cell in the
// array. If a callback returns a non-OK status, returns that else returns
// Status::OK().
Status EachStatus(
std::function<Status(tensorflow::gtl::ArraySlice<int64>, T*)> f) {
Status EachStatus(std::function<Status(absl::Span<const int64>, T*)> f) {
std::vector<int64> index(sizes_.size());
for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
Status s = f(index, &values_[i]);
@ -335,8 +332,7 @@ class Array {
// Invokes a callback with the (indices, value) for each cell in the array.
// If a callback returns a non-OK status, returns that else returns
// Status::OK().
Status EachStatus(
std::function<Status(tensorflow::gtl::ArraySlice<int64>, T)> f) const {
Status EachStatus(std::function<Status(absl::Span<const int64>, T)> f) const {
std::vector<int64> index(sizes_.size());
for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
Status s = f(index, values_[i]);
@ -377,13 +373,13 @@ class Array {

// Returns the value at the cell specified by the indexes. The number of
// arguments have to match with the number of dimensions for the array.
const T& operator()(tensorflow::gtl::ArraySlice<int64> indexes) const {
const T& operator()(absl::Span<const int64> indexes) const {
return values_[calculate_index(indexes)];
}

// Returns the value at the cell specified by the indexes. The number of
// arguments have to match with the number of dimensions for the array.
T& operator()(tensorflow::gtl::ArraySlice<int64> indexes) {
T& operator()(absl::Span<const int64> indexes) {
return values_[calculate_index(indexes)];
}

@ -438,8 +434,8 @@ class Array {
bool operator!=(const Array<T>& other) const { return !(*this == other); }

// Performs the equivalent of a slice operation on this array.
Array<T> Slice(tensorflow::gtl::ArraySlice<int64> starts,
tensorflow::gtl::ArraySlice<int64> limits) const {
Array<T> Slice(absl::Span<const int64> starts,
absl::Span<const int64> limits) const {
CHECK_EQ(starts.size(), num_dimensions());
CHECK_EQ(limits.size(), num_dimensions());

@ -464,7 +460,7 @@ class Array {

// Performs the equivalent of a DynamicUpdateSlice in-place on this array.
void UpdateSlice(const Array<T>& from,
tensorflow::gtl::ArraySlice<int64> start_indices) {
absl::Span<const int64> start_indices) {
CHECK_EQ(from.num_dimensions(), num_dimensions());
std::vector<int64> limit_indices;
std::transform(start_indices.begin(), start_indices.end(),
@ -484,7 +480,7 @@ class Array {

// Performs an in-place reshape, modifying the dimensions but not the
// underlying data.
void Reshape(tensorflow::gtl::ArraySlice<int64> new_dimensions) {
void Reshape(absl::Span<const int64> new_dimensions) {
int64 old_num_elements = num_elements();
sizes_ = std::vector<int64>(new_dimensions.begin(), new_dimensions.end());
CHECK_EQ(num_elements(), old_num_elements);

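The Each overloads above hand the visitor the current index vector as absl::Span<const int64>, so no per-cell copy of the indices is made; the span is only valid for the duration of the callback. A small usage sketch against this interface (sizes and values invented):

std::vector<int64> sizes = {2, 3};
xla::Array<int> arr(sizes, /*value=*/1);  // 2x3 array filled with 1
int64 total = 0;
// `idx` views the iteration index vector; it must not be stored past the call.
arr.Each([&](absl::Span<const int64> idx, int cell) { total += cell; });
// total == 6: the callback ran once per cell.
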
@ -27,8 +27,7 @@ namespace {
// Given an Array4D and a 4-tuple index, computes the linear index into the
// array idx represents.
template <typename T>
int64 Array4DLinearIndex(const Array4D<T>& arr,
tensorflow::gtl::ArraySlice<int64> idx) {
int64 Array4DLinearIndex(const Array4D<T>& arr, absl::Span<const int64> idx) {
EXPECT_EQ(4, idx.size());
return (idx[3] + idx[2] * arr.n4() + idx[1] * arr.n3() * arr.n4() +
idx[0] * arr.n2() * arr.n3() * arr.n4());
@ -51,9 +50,8 @@ TEST(Array4dTest, FillCtor) {
EXPECT_EQ(fullof7.n3(), 4);
EXPECT_EQ(fullof7.n4(), 5);

fullof7.Each([](tensorflow::gtl::ArraySlice<int64> idx, int* cell) {
EXPECT_EQ(*cell, 7);
});
fullof7.Each(
[](absl::Span<const int64> idx, int* cell) { EXPECT_EQ(*cell, 7); });
}

TEST(Array4dTest, ContainerCtor) {
@ -69,7 +67,7 @@ TEST(Array4dTest, ContainerCtor) {
EXPECT_EQ(arr.n3(), 4);
EXPECT_EQ(arr.n4(), 5);

arr.Each([&arr](tensorflow::gtl::ArraySlice<int64> idx, int* cell) {
arr.Each([&arr](absl::Span<const int64> idx, int* cell) {
EXPECT_EQ(*cell, Array4DLinearIndex(arr, idx));
});
}
@ -129,21 +127,19 @@ TEST(Array3dTest, InitializerListCtorHalf) {

TEST(Array4dTest, Fill) {
Array4D<int> fullof7(2, 3, 4, 5, 7);
fullof7.Each([](tensorflow::gtl::ArraySlice<int64> idx, int* cell) {
EXPECT_EQ(*cell, 7);
});
fullof7.Each(
[](absl::Span<const int64> idx, int* cell) { EXPECT_EQ(*cell, 7); });

fullof7.Fill(11);
fullof7.Each([](tensorflow::gtl::ArraySlice<int64> idx, int* cell) {
EXPECT_EQ(*cell, 11);
});
fullof7.Each(
[](absl::Span<const int64> idx, int* cell) { EXPECT_EQ(*cell, 11); });
}

TEST(Array4dTest, FillWithMultiples) {
Array4D<float> arr(2, 3, 4, 5);
arr.FillWithMultiples(2.0f);

arr.Each([&arr](tensorflow::gtl::ArraySlice<int64> idx, float* cell) {
arr.Each([&arr](absl::Span<const int64> idx, float* cell) {
EXPECT_EQ(*cell, 2.0f * Array4DLinearIndex(arr, idx));
});
}

@ -163,7 +163,7 @@ TEST(ArrayTest, Each) {
arr.FillWithMultiples(1);

int64 each_count = 0, each_sum = 0;
arr.Each([&](tensorflow::gtl::ArraySlice<int64> idx, int cell) {
arr.Each([&](absl::Span<const int64> idx, int cell) {
int64 lin_idx = idx[0] * 12 + idx[1] * 4 + idx[2];
EXPECT_EQ(lin_idx, cell);
each_count++;

@ -163,8 +163,7 @@ Status Client::ResetDevice() {
}

StatusOr<std::unique_ptr<Literal>> Client::ExecuteAndTransfer(
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const ExecutionOptions* execution_options,
ExecutionProfile* execution_profile) {
TF_ASSIGN_OR_RETURN(
@ -212,8 +211,7 @@ StatusOr<XlaComputation> Client::LoadSnapshot(const HloSnapshot& module) {
}

StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const ExecutionOptions* execution_options,
ExecutionProfile* execution_profile) {
ExecuteGraphRequest request;
@ -252,7 +250,7 @@ StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
}

StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
tensorflow::gtl::ArraySlice<XlaComputationInstance> computations) {
absl::Span<const XlaComputationInstance> computations) {
ExecuteGraphParallelRequest request;

for (const XlaComputationInstance& computation : computations) {

@ -53,7 +53,7 @@ class Client {
// will be filled with profile data from the execution.
StatusOr<std::unique_ptr<GlobalData>> Execute(
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
absl::Span<GlobalData* const> arguments,
const ExecutionOptions* execution_options = nullptr,
ExecutionProfile* execution_profile = nullptr);

@ -82,7 +82,7 @@ class Client {
// from each computation.
//
StatusOr<std::vector<std::unique_ptr<GlobalData>>> ExecuteParallel(
tensorflow::gtl::ArraySlice<XlaComputationInstance> computations);
absl::Span<const XlaComputationInstance> computations);

// Requests device_count device handles available on the target. The returned
// device handles are used to specify the devices to execute the computations
@ -134,7 +134,7 @@ class Client {
// Execute() and Transfer().
StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
absl::Span<GlobalData* const> arguments,
const ExecutionOptions* execution_options = nullptr,
ExecutionProfile* execution_profile = nullptr);

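Note the element type for spans of pointers: gtl::ArraySlice<GlobalData*> becomes absl::Span<GlobalData* const>, a view whose elements (the pointers themselves) are immutable while the pointed-to GlobalData objects remain mutable. Spelled the other way, absl::Span<const GlobalData*> would be a span of pointer-to-const, which is a different type and not what these signatures mean. A one-line sketch of the intended spelling:

// Elements are const: the span cannot reseat a pointer, but callees may
// still mutate the GlobalData that each pointer refers to.
void Use(absl::Span<GlobalData* const> arguments);
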
@ -23,7 +23,7 @@ namespace xla {

StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileOnlyClient::CompileAheadOfTime(
const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
const absl::Span<const AotXlaComputationInstance> computations,
const AotCompilationOptions& options,
std::unique_ptr<AotCompilationMetadata>* metadata) {
std::vector<CompileOnlyService::AotXlaComputationInstance> service_instances;

@ -52,7 +52,7 @@ class CompileOnlyClient : public Client {
// code. |metadata|, if provided, is populated during compilation.
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(
const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
const absl::Span<const AotXlaComputationInstance> computations,
const AotCompilationOptions& options,
std::unique_ptr<AotCompilationMetadata>* metadata = nullptr);

@ -86,7 +86,7 @@ class ExecutableBuildOptions {
void add_disabled_hlo_pass(absl::string_view pass_name) {
disabled_hlo_passes_.push_back(std::string(pass_name));
}
const tensorflow::gtl::ArraySlice<std::string> disabled_hlo_passes() const {
const absl::Span<const std::string> disabled_hlo_passes() const {
return disabled_hlo_passes_;
}

@ -69,8 +69,7 @@ std::array<float, 6> kErfUCoefficient = {

// Evaluate the polynomial given coefficients and `x`.
// N.B. Coefficients should be supplied in decreasing order.
XlaOp EvaluatePolynomial(XlaOp x,
tensorflow::gtl::ArraySlice<float> coefficients) {
XlaOp EvaluatePolynomial(XlaOp x, absl::Span<const float> coefficients) {
XlaOp poly = ScalarLike(x, 0.0);
for (float c : coefficients) {
poly = poly * x + ScalarLike(x, c);

@ -34,8 +34,7 @@ XlaOp Reciprocal(XlaOp operand);

// Evaluates a polynomial given coefficients and `x`.
// N.B. Coefficients should be supplied in decreasing order.
XlaOp EvaluatePolynomial(XlaOp x,
tensorflow::gtl::ArraySlice<float> coefficients);
XlaOp EvaluatePolynomial(XlaOp x, absl::Span<const float> coefficients);

// Computes an approximation of the error function complement (1 - erf(x)).
XlaOp Erfc(XlaOp x);

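EvaluatePolynomial above is Horner's rule: with coefficients supplied in decreasing order of degree, each iteration folds one coefficient in via poly = poly * x + c, one multiply-add per coefficient. The same recurrence in plain scalar form, a sketch rather than the XLA builder version:

#include "absl/types/span.h"

float EvaluatePolynomialScalar(float x, absl::Span<const float> coefficients) {
  float poly = 0.0f;
  for (float c : coefficients) {
    poly = poly * x + c;  // Horner step
  }
  return poly;
}
// EvaluatePolynomialScalar(2.0f, {1.0f, 0.0f, -3.0f}) == 1.0f, i.e. 2*2 - 3.
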
@ -39,7 +39,7 @@ XlaOp GetMatrixDiagonal(XlaOp x) {
TF_RET_CHECK(n_dims >= 2);
const int64 m = shape.dimensions(n_dims - 2);
const int64 n = shape.dimensions(n_dims - 1);
tensorflow::gtl::ArraySlice<int64> major_dims =
absl::Span<const int64> major_dims =
AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2);
auto a = Iota(builder, U32, n);
auto b = Iota(builder, U32, m);
@ -66,7 +66,7 @@ XlaOp Triangle(XlaOp x, bool lower) {
TF_RET_CHECK(n_dims >= 2);
const int64 m = shape.dimensions(n_dims - 2);
const int64 n = shape.dimensions(n_dims - 1);
tensorflow::gtl::ArraySlice<int64> major_dims =
absl::Span<const int64> major_dims =
AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2);
auto a = Iota(builder, U32, n);
auto b = Iota(builder, U32, m);

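The subspan call here replaces ArraySlice's slice(pos, len): it returns a view of a sub-range of the same underlying storage, with no copy. For example (data invented):

std::vector<int64> dims = {4, 5, 6, 7};
absl::Span<const int64> all(dims);
// Drop the trailing two (matrix) dimensions, as GetMatrixDiagonal does.
absl::Span<const int64> major = all.subspan(/*pos=*/0, /*len=*/dims.size() - 2);
// major views {4, 5}; it stays valid only as long as `dims` does.
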
@ -26,11 +26,9 @@ namespace {
// element of an image by the count of elements that contributed to that
// element during pooling.
XlaOp AvgPoolDivideByCountWithGeneralPadding(
XlaOp sums, PrimitiveType dtype,
tensorflow::gtl::ArraySlice<int64> input_shape,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> spatial_padding,
tensorflow::gtl::ArraySlice<int64> ksize,
tensorflow::gtl::ArraySlice<int64> stride,
XlaOp sums, PrimitiveType dtype, absl::Span<const int64> input_shape,
absl::Span<const std::pair<int64, int64>> spatial_padding,
absl::Span<const int64> ksize, absl::Span<const int64> stride,
const TensorFormat& data_format) {
// The padding shouldn't be included in the counts. We use another
// ReduceWindow to find the right counts.
@ -73,8 +71,8 @@ XlaOp AvgPoolDivideByCountWithGeneralPadding(

// Sums all elements in the window specified by 'kernel_size' and 'stride'.
XlaOp ComputeSums(XlaOp operand, XlaOp init_value,
tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride,
absl::Span<const int64> kernel_size,
absl::Span<const int64> stride,
const TensorFormat& data_format) {
XlaBuilder* b = operand.builder();
return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@ -89,8 +87,8 @@ XlaOp ComputeSums(XlaOp operand, XlaOp init_value,

// Creates a padding configuration out of spatial padding values.
PaddingConfig MakeSpatialPaddingConfig(
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> spatial_padding,
int num_spatial_dims, tensorflow::gtl::ArraySlice<int64> stride,
absl::Span<const std::pair<int64, int64>> spatial_padding,
int num_spatial_dims, absl::Span<const int64> stride,
const TensorFormat& data_format) {
PaddingConfig padding_config;
for (int i = 0; i < 2 + num_spatial_dims; ++i) {
@ -107,11 +105,10 @@ PaddingConfig MakeSpatialPaddingConfig(
return padding_config;
}

XlaOp AvgPoolDivideByCount(
XlaOp pooled, tensorflow::gtl::ArraySlice<int64> input_size,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
XlaOp AvgPoolDivideByCount(XlaOp pooled, absl::Span<const int64> input_size,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
PrimitiveType dtype, const TensorFormat& data_format,
bool counts_include_padding) {
if (counts_include_padding) {
@ -133,8 +130,8 @@ XlaOp AvgPoolDivideByCount(

} // namespace

XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
XlaOp MaxPool(XlaOp operand, absl::Span<const int64> kernel_size,
absl::Span<const int64> stride, Padding padding,
const TensorFormat& data_format) {
XlaBuilder* b = operand.builder();
return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@ -147,9 +144,9 @@ XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
});
}

XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
XlaOp AvgPool(XlaOp operand, absl::Span<const int64> kernel_size,
absl::Span<const int64> stride,
absl::Span<const std::pair<int64, int64>> padding,
const TensorFormat& data_format,
const bool counts_include_padding) {
XlaBuilder* b = operand.builder();
@ -173,9 +170,8 @@ XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
}

std::vector<std::pair<int64, int64>> MakeSpatialPadding(
tensorflow::gtl::ArraySlice<int64> input_size,
tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
absl::Span<const int64> input_size, absl::Span<const int64> kernel_size,
absl::Span<const int64> stride, Padding padding,
const TensorFormat& data_format) {
const int num_spatial_dims = kernel_size.size() - 2;
std::vector<int64> input_spatial_dimensions;
@ -193,12 +189,12 @@ std::vector<std::pair<int64, int64>> MakeSpatialPadding(
stride_spatial_dimensions, padding);
}

XlaOp AvgPoolGrad(
XlaOp out_backprop, tensorflow::gtl::ArraySlice<int64> gradients_size,
tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> spatial_padding,
const TensorFormat& data_format, const bool counts_include_padding) {
XlaOp AvgPoolGrad(XlaOp out_backprop, absl::Span<const int64> gradients_size,
absl::Span<const int64> kernel_size,
absl::Span<const int64> stride,
absl::Span<const std::pair<int64, int64>> spatial_padding,
const TensorFormat& data_format,
const bool counts_include_padding) {
XlaBuilder* b = out_backprop.builder();
return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
const int num_dims = kernel_size.size();

@ -25,7 +25,7 @@ namespace xla {
class TensorFormat {
public:
TensorFormat(int batch_dimension, int feature_dimension,
tensorflow::gtl::ArraySlice<int64> spatial_dimensions)
absl::Span<const int64> spatial_dimensions)
: batch_dimension_(batch_dimension),
feature_dimension_(feature_dimension),
spatial_dimensions_(spatial_dimensions.begin(),
@ -49,32 +49,31 @@ class TensorFormat {
};

// Computes the max pool of 'operand'.
XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
XlaOp MaxPool(XlaOp operand, absl::Span<const int64> kernel_size,
absl::Span<const int64> stride, Padding padding,
const TensorFormat& data_format);

// Computes the average pool of 'operand'.
XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
XlaOp AvgPool(XlaOp operand, absl::Span<const int64> kernel_size,
absl::Span<const int64> stride,
absl::Span<const std::pair<int64, int64>> padding,
const TensorFormat& data_format,
const bool counts_include_padding);

// Returns the list of low and high padding elements in each spatial dimension
// for the given 'padding' specification.
std::vector<std::pair<int64, int64>> MakeSpatialPadding(
tensorflow::gtl::ArraySlice<int64> input_size,
tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
absl::Span<const int64> input_size, absl::Span<const int64> kernel_size,
absl::Span<const int64> stride, Padding padding,
const TensorFormat& data_format);

// Computes the average pool gradient.
XlaOp AvgPoolGrad(
XlaOp out_backprop, tensorflow::gtl::ArraySlice<int64> gradients_size,
tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> spatial_padding,
const TensorFormat& data_format, const bool counts_include_padding);
XlaOp AvgPoolGrad(XlaOp out_backprop, absl::Span<const int64> gradients_size,
absl::Span<const int64> kernel_size,
absl::Span<const int64> stride,
absl::Span<const std::pair<int64, int64>> spatial_padding,
const TensorFormat& data_format,
const bool counts_include_padding);

} // namespace xla

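Spatial padding travels through this pooling API as a span of (low, high) pairs, one per spatial dimension. A hypothetical call site wiring the helpers above together (window, stride, and padding values invented; `operand` is assumed to be an XlaOp built elsewhere):

xla::TensorFormat data_format = xla::MakeNCHWFormat(/*num_spatial_dims=*/2);
std::vector<int64> kernel_size =
    xla::ExpandWithBatchAndFeatureDimensions({3, 3}, data_format);
std::vector<int64> stride =
    xla::ExpandWithBatchAndFeatureDimensions({2, 2}, data_format);
// One (low, high) pair per spatial dimension.
std::vector<std::pair<int64, int64>> padding = {{1, 1}, {1, 1}};
xla::XlaOp pooled =
    xla::AvgPool(operand, kernel_size, stride, padding, data_format,
                 /*counts_include_padding=*/false);
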
@ -32,8 +32,8 @@ TensorFormat MakeNCHWFormat(int num_spatial_dims) {
}

std::vector<std::pair<int64, int64>> MakeGeneralPadding(
XlaOp input, tensorflow::gtl::ArraySlice<int64> kernel_size,
tensorflow::gtl::ArraySlice<int64> stride, Padding padding,
XlaOp input, absl::Span<const int64> kernel_size,
absl::Span<const int64> stride, Padding padding,
const xla::TensorFormat& data_format) {
XlaBuilder* b = input.builder();
Shape operand_shape = b->GetShape(input).ValueOrDie();
@ -46,7 +46,7 @@ std::vector<std::pair<int64, int64>> MakeGeneralPadding(
// Add singleton batch and feature dimensions to spatial dimensions, according
// to 'data_format' specification.
std::vector<int64> ExpandWithBatchAndFeatureDimensions(
tensorflow::gtl::ArraySlice<int64> spatial_dim_sizes,
absl::Span<const int64> spatial_dim_sizes,
const xla::TensorFormat& data_format) {
const int num_spatial_dims = spatial_dim_sizes.size();
std::vector<int64> tensor_sizes(num_spatial_dims + 2, 1);

@ -51,7 +51,7 @@ LocalExecutable::LocalExecutable(std::unique_ptr<Executable> executable,
}

Status LocalExecutable::ValidateExecutionOptions(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
const absl::Span<const ShapedBuffer* const> arguments,
const ExecutableRunOptions& run_options, const Backend& backend) {
const ComputationLayout& computation_layout =
executable_->module_config().entry_computation_layout();
@ -140,7 +140,7 @@ Status LocalExecutable::ValidateExecutionOptions(
}

StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
const absl::Span<const ShapedBuffer* const> arguments,
ExecutableRunOptions run_options) {
TF_RETURN_IF_ERROR(
ValidateExecutionOptions(arguments, run_options, *backend_));
@ -177,7 +177,7 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::Run(

StatusOr<ScopedShapedBuffer> LocalExecutable::ExecuteAndDump(
const ServiceExecutableRunOptions* run_options,
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
const absl::Span<const ShapedBuffer* const> arguments) {
executable_->hlo_snapshot()->set_execution_platform(
backend_->platform()->Name());
TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->hlo_snapshot()));
@ -191,7 +191,7 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::ExecuteAndDump(
}

Status LocalExecutable::RecordArguments(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
const absl::Span<const ShapedBuffer* const> arguments,
HloSnapshot* hlo_snapshot) {
hlo_snapshot->clear_arguments();
for (const ShapedBuffer* argument : arguments) {
@ -245,7 +245,7 @@ Backend* LocalClient::mutable_backend() {

StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
const XlaComputation& computation,
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
const absl::Span<const Shape* const> argument_layouts,
const ExecutableBuildOptions& options) {
ExecutableBuildOptions updated_options = options;
if (options.device_ordinal() == -1) {

@ -40,7 +40,7 @@ class LocalExecutable {
// Run the compiled computation with the given arguments and options and
// return the result.
StatusOr<ScopedShapedBuffer> Run(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
const absl::Span<const ShapedBuffer* const> arguments,
ExecutableRunOptions run_options);

// Return the options used to build the executable.
@ -63,7 +63,7 @@ class LocalExecutable {
// The given ExecutableRunOptions override any values from legacy_flags
// (TF_XLA_FLAGS environment variable).
Status ValidateExecutionOptions(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
const absl::Span<const ShapedBuffer* const> arguments,
const ExecutableRunOptions& run_options, const Backend& backend);

// Records the computation in a SessionModule proto with the arguments used to
@ -73,12 +73,11 @@ class LocalExecutable {
// (TF_XLA_FLAGS environment variable).
StatusOr<ScopedShapedBuffer> ExecuteAndDump(
const ServiceExecutableRunOptions* run_options,
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
const absl::Span<const ShapedBuffer* const> arguments);

// Records the arguments used to invoke the computation in a SessionModule
// proto.
Status RecordArguments(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
Status RecordArguments(const absl::Span<const ShapedBuffer* const> arguments,
HloSnapshot* hlo_snapshot);

// Records the result of the computation in a SessionModule proto.
@ -120,7 +119,7 @@ class LocalClient : public Client {
// (TF_XLA_FLAGS environment variable).
StatusOr<std::unique_ptr<LocalExecutable>> Compile(
const XlaComputation& computation,
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
const absl::Span<const Shape* const> argument_layouts,
const ExecutableBuildOptions& options);

// Copy the literal data to the device with the given ordinal and return as a

@ -23,10 +23,9 @@ limitations under the License.

namespace xla {

Status ValidatePaddingValues(
tensorflow::gtl::ArraySlice<int64> input_dimensions,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides) {
Status ValidatePaddingValues(absl::Span<const int64> input_dimensions,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides) {
bool ok = input_dimensions.size() == window_dimensions.size() &&
input_dimensions.size() == window_strides.size();
if (!ok) {
@ -40,9 +39,9 @@ Status ValidatePaddingValues(
}

std::vector<std::pair<int64, int64>> MakePadding(
tensorflow::gtl::ArraySlice<int64> input_dimensions,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
absl::Span<const int64> input_dimensions,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides, Padding padding) {
TF_CHECK_OK(ValidatePaddingValues(input_dimensions, window_dimensions,
window_strides));
std::vector<std::pair<int64, int64>> low_high_padding;

@ -41,10 +41,9 @@ enum class Padding {
// Validates that the slices are acceptable for determining padding -- this can
// be used to check the preconditions of MakePadding below to produce an error
// message that can be returned to the user.
Status ValidatePaddingValues(
tensorflow::gtl::ArraySlice<int64> input_dimensions,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides);
Status ValidatePaddingValues(absl::Span<const int64> input_dimensions,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides);

// Returns the padding needed for the base area, given the base area dimensions,
// window dimensions, strides, and the type of padding.
@ -58,9 +57,9 @@ Status ValidatePaddingValues(
// window_dimensions, and strides must match, which is equal to the number
// of elements in the result vector.
std::vector<std::pair<int64, int64>> MakePadding(
tensorflow::gtl::ArraySlice<int64> input_dimensions,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding);
absl::Span<const int64> input_dimensions,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides, Padding padding);

} // namespace xla

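For Padding::kSame, MakePadding picks low/high amounts so each output dimension keeps ceil(input / stride) elements. A simplified one-dimensional restatement of that arithmetic, an approximation for intuition rather than a copy of the real implementation:

#include <algorithm>
#include <cstdint>
#include <utility>

std::pair<int64_t, int64_t> SamePadding1D(int64_t input, int64_t window,
                                          int64_t stride) {
  const int64_t output = (input + stride - 1) / stride;  // ceil(input/stride)
  const int64_t total =
      std::max<int64_t>((output - 1) * stride + window - input, 0);
  // Split the shortfall between the two sides.
  return {total / 2, total - total / 2};
}
// SamePadding1D(5, 3, 2) == {1, 1}: output size 3, total padding 2.
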
@ -90,7 +90,7 @@ StatusOr<Shape> XlaBuilder::GetShape(const XlaOp& op) const {
}

StatusOr<std::vector<Shape>> XlaBuilder::GetOperandShapes(
tensorflow::gtl::ArraySlice<XlaOp> operands) const {
absl::Span<const XlaOp> operands) const {
std::vector<Shape> operand_shapes;
for (const XlaOp& operand : operands) {
TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
@ -291,7 +291,7 @@ StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id) {

StatusOr<XlaOp> XlaBuilder::InDimBroadcast(
const Shape& shape, const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
absl::Span<const int64> broadcast_dimensions) {
TF_RETURN_IF_ERROR(first_error_);

HloInstructionProto instr;
@ -352,9 +352,8 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) {
});
}

XlaOp XlaBuilder::BinaryOp(
HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
@ -448,12 +447,12 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
}

XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kAdd, lhs, rhs, broadcast_dimensions);
}

XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kMultiply, lhs, rhs, broadcast_dimensions);
}

@ -480,7 +479,7 @@ XlaOp XlaBuilder::Iota(PrimitiveType type, int64 size) {
}

XlaOp XlaBuilder::Call(const XlaComputation& computation,
tensorflow::gtl::ArraySlice<XlaOp> operands) {
absl::Span<const XlaOp> operands) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
std::vector<const Shape*> operand_shape_ptrs;
@ -515,8 +514,8 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape,
});
}

XlaOp XlaBuilder::Broadcast(
const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
XlaOp XlaBuilder::Broadcast(const XlaOp& operand,
absl::Span<const int64> broadcast_sizes) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
@ -541,7 +540,7 @@ XlaOp XlaBuilder::Broadcast(

XlaOp XlaBuilder::BroadcastInDim(
const XlaOp& operand, const Shape& shape,
const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
const absl::Span<const int64> broadcast_dimensions) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
return InDimBroadcast(shape, operand, broadcast_dimensions);
});
@ -556,9 +555,9 @@ StatusOr<XlaOp> XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) {
}

XlaOp XlaBuilder::Slice(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices,
tensorflow::gtl::ArraySlice<int64> strides) {
absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices,
absl::Span<const int64> strides) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@ -593,7 +592,7 @@ XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index,
}

XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
tensorflow::gtl::ArraySlice<int64> slice_sizes) {
absl::Span<const int64> slice_sizes) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;

@ -631,7 +630,7 @@ XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
});
}

XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice<XlaOp> operands,
XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands,
int64 dimension) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
@ -671,8 +670,8 @@ XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value,
}

XlaOp XlaBuilder::Reshape(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<int64> new_sizes) {
absl::Span<const int64> dimensions,
absl::Span<const int64> new_sizes) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(const Shape& shape,
@ -686,7 +685,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand,
}

XlaOp XlaBuilder::Reshape(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> new_sizes) {
absl::Span<const int64> new_sizes) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(auto shape, GetShape(operand));
std::vector<int64> dimensions(shape.dimensions_size());
@ -696,7 +695,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand,
}

XlaOp XlaBuilder::Collapse(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> dimensions) {
absl::Span<const int64> dimensions) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
if (dimensions.size() <= 1) {
// Not collapsing anything, trivially we can return the operand versus
@ -706,8 +705,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand,

// Out-of-order collapse is not supported.
// Checks that the collapsed dimensions are in order and consecutive.
for (tensorflow::gtl::ArraySlice<int64>::size_type i = 1;
i < dimensions.size(); ++i) {
for (absl::Span<const int64>::size_type i = 1; i < dimensions.size(); ++i) {
if (dimensions[i] - 1 != dimensions[i - 1]) {
return InvalidArgument(
"Collapsed dimensions are not in consecutive order.");
@ -758,7 +756,7 @@ XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true,
});
}

XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice<XlaOp> elements) {
XlaOp XlaBuilder::Tuple(absl::Span<const XlaOp> elements) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
std::vector<const Shape*> operand_shape_ptrs;
@ -792,32 +790,32 @@ XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) {
}

XlaOp XlaBuilder::Eq(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kEq, lhs, rhs, broadcast_dimensions);
}

XlaOp XlaBuilder::Ne(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kNe, lhs, rhs, broadcast_dimensions);
}

XlaOp XlaBuilder::Ge(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kGe, lhs, rhs, broadcast_dimensions);
}

XlaOp XlaBuilder::Gt(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kGt, lhs, rhs, broadcast_dimensions);
}

XlaOp XlaBuilder::Le(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kLe, lhs, rhs, broadcast_dimensions);
}

XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kLt, lhs, rhs, broadcast_dimensions);
}

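All of the elementwise binary ops above funnel into BinaryOp, whose broadcast_dimensions span maps the lower-rank operand's dimensions into the higher-rank operand's shape. A sketch of the intended call pattern against this builder interface (builder name and shapes invented):

xla::XlaBuilder b("broadcast_example");
xla::XlaOp matrix =
    b.Parameter(0, xla::ShapeUtil::MakeShape(xla::F32, {2, 3}), "m");
xla::XlaOp vec =
    b.Parameter(1, xla::ShapeUtil::MakeShape(xla::F32, {3}), "v");
// Map the vector's single dimension onto dimension 1 of the matrix, so the
// vector is added to every row; the braced {1} converts to absl::Span.
xla::XlaOp sum = b.Add(matrix, vec, /*broadcast_dimensions=*/{1});
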
@ -899,8 +897,8 @@ Status XlaBuilder::VerifyConvolution(
}

XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
Padding padding, int64 feature_group_count,
absl::Span<const int64> window_strides, Padding padding,
int64 feature_group_count,
const PrecisionConfigProto* precision_config_proto) {
return ConvWithGeneralDimensions(
lhs, rhs, window_strides, padding,
@ -909,9 +907,8 @@ XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs,
}

XlaOp XlaBuilder::ConvWithGeneralPadding(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
int64 feature_group_count,
const PrecisionConfigProto* precision_config_proto) {
return ConvGeneral(lhs, rhs, window_strides, padding,
@ -920,9 +917,8 @@ XlaOp XlaBuilder::ConvWithGeneralPadding(
}

XlaOp XlaBuilder::ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
const ConvolutionDimensionNumbers& dimension_numbers,
const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count,
const PrecisionConfigProto* precision_config_proto) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@ -957,9 +953,8 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions(
}

XlaOp XlaBuilder::ConvGeneral(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count,
const PrecisionConfigProto* precision_config_proto) {
@ -969,11 +964,9 @@ XlaOp XlaBuilder::ConvGeneral(
}

XlaOp XlaBuilder::ConvGeneralDilated(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count,
const PrecisionConfigProto* precision_config_proto) {
@ -1013,11 +1006,11 @@ XlaOp XlaBuilder::ConvGeneralDilated(
}

StatusOr<Window> XlaBuilder::MakeWindow(
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation) const {
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation) const {
const auto verify_size = [&](const size_t x, const char* x_name) {
if (x == 0 || x == window_dimensions.size()) {
return Status::OK();
@ -1067,7 +1060,7 @@ StatusOr<Window> XlaBuilder::MakeWindow(
}

XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type,
const tensorflow::gtl::ArraySlice<int64> fft_length) {
const absl::Span<const int64> fft_length) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@ -1276,7 +1269,7 @@ XlaOp XlaBuilder::CreateToken() {
});
}

XlaOp XlaBuilder::AfterAll(tensorflow::gtl::ArraySlice<XlaOp> tokens) {
XlaOp XlaBuilder::AfterAll(absl::Span<const XlaOp> tokens) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
if (tokens.empty()) {
return InvalidArgument("AfterAll requires at least one operand");
@ -1288,7 +1281,7 @@ XlaOp XlaBuilder::AfterAll(tensorflow::gtl::ArraySlice<XlaOp> tokens) {
}

XlaOp XlaBuilder::CustomCall(const string& call_target_name,
tensorflow::gtl::ArraySlice<XlaOp> operands,
absl::Span<const XlaOp> operands,
const Shape& shape) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
@ -1304,9 +1297,8 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name,
});
}

XlaOp XlaBuilder::Complex(
const XlaOp& real, const XlaOp& imag,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
XlaOp XlaBuilder::Complex(const XlaOp& real, const XlaOp& imag,
absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kComplex, real, imag, broadcast_dimensions);
}

@ -1315,42 +1307,42 @@ XlaOp XlaBuilder::Conj(const XlaOp& operand) {
}

XlaOp XlaBuilder::Sub(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kSubtract, lhs, rhs, broadcast_dimensions);
}

XlaOp XlaBuilder::Div(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kDivide, lhs, rhs, broadcast_dimensions);
}

XlaOp XlaBuilder::Rem(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kRemainder, lhs, rhs, broadcast_dimensions);
}

XlaOp XlaBuilder::Max(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kMaximum, lhs, rhs, broadcast_dimensions);
}

XlaOp XlaBuilder::Min(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kMinimum, lhs, rhs, broadcast_dimensions);
}

XlaOp XlaBuilder::And(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kAnd, lhs, rhs, broadcast_dimensions);
}

XlaOp XlaBuilder::Or(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kOr, lhs, rhs, broadcast_dimensions);
}

XlaOp XlaBuilder::Xor(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kXor, lhs, rhs, broadcast_dimensions);
}

@ -1358,22 +1350,21 @@ XlaOp XlaBuilder::Not(const XlaOp& operand) {
return UnaryOp(HloOpcode::kNot, operand);
}

XlaOp XlaBuilder::ShiftLeft(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
XlaOp XlaBuilder::ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kShiftLeft, lhs, rhs, broadcast_dimensions);
}

XlaOp XlaBuilder::ShiftRightArithmetic(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs,
broadcast_dimensions);
}

XlaOp XlaBuilder::ShiftRightLogical(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs,
broadcast_dimensions);
}
@ -1382,9 +1373,8 @@ XlaOp XlaBuilder::Abs(const XlaOp& operand) {
return UnaryOp(HloOpcode::kAbs, operand);
}

XlaOp XlaBuilder::Atan2(
const XlaOp& y, const XlaOp& x,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
XlaOp XlaBuilder::Atan2(const XlaOp& y, const XlaOp& x,
absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kAtan2, y, x, broadcast_dimensions);
}

@ -1449,7 +1439,7 @@ XlaOp XlaBuilder::IsFinite(const XlaOp& operand) {
}

XlaOp XlaBuilder::Transpose(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> permutation) {
absl::Span<const int64> permutation) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@ -1464,7 +1454,7 @@ XlaOp XlaBuilder::Transpose(const XlaOp& operand,
}

XlaOp XlaBuilder::Rev(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> dimensions) {
absl::Span<const int64> dimensions) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@ -1506,7 +1496,7 @@ XlaOp XlaBuilder::Sort(XlaOp keys, absl::optional<XlaOp> values,
}

XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
absl::Span<const int64> broadcast_dimensions) {
return BinaryOp(HloOpcode::kPower, lhs, rhs, broadcast_dimensions);
}

@ -1544,10 +1534,10 @@ XlaOp XlaBuilder::Clamp(const XlaOp& min, const XlaOp& operand,
return TernaryOp(HloOpcode::kClamp, min, operand, max);
}

XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice<XlaOp> operands,
XlaOp XlaBuilder::Map(absl::Span<const XlaOp> operands,
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<XlaOp> static_operands) {
absl::Span<const int64> dimensions,
absl::Span<const XlaOp> static_operands) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
if (!static_operands.empty()) {
return Unimplemented("static_operands is not supported in Map");
@ -1588,7 +1578,7 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice<XlaOp> operands,
}

XlaOp XlaBuilder::RngOp(RandomDistribution distribution,
tensorflow::gtl::ArraySlice<XlaOp> parameters,
absl::Span<const XlaOp> parameters,
const Shape& shape) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
@ -1649,7 +1639,7 @@ XlaOp XlaBuilder::While(const XlaComputation& condition,

XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
tensorflow::gtl::ArraySlice<int64> slice_sizes) {
absl::Span<const int64> slice_sizes) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;

@ -1729,20 +1719,18 @@ XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand,
});
}

XlaOp XlaBuilder::Reduce(
const XlaOp& operand, const XlaOp& init_value,
XlaOp XlaBuilder::Reduce(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
return Reduce(tensorflow::gtl::ArraySlice<XlaOp>({operand}),
tensorflow::gtl::ArraySlice<XlaOp>({init_value}), computation,
absl::Span<const int64> dimensions_to_reduce) {
return Reduce(absl::Span<const XlaOp>({operand}),
absl::Span<const XlaOp>({init_value}), computation,
dimensions_to_reduce);
}

XlaOp XlaBuilder::Reduce(
tensorflow::gtl::ArraySlice<XlaOp> operands,
tensorflow::gtl::ArraySlice<XlaOp> init_values,
XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
absl::Span<const XlaOp> init_values,
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
absl::Span<const int64> dimensions_to_reduce) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;

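The single-operand Reduce above forwards to the variadic overload by wrapping operand and init_value in absl::Span<const XlaOp> temporaries. One caution that pattern illustrates: a span constructed from a braced list views a temporary array that only lives to the end of the full expression, so it is safe as a function argument but not as a named local. A sketch of the distinction (illustrative, not from this commit):

void Consume(absl::Span<const int> values);

void Fine() {
  Consume({1, 2, 3});  // the temporary backing array outlives the call
}

void Dangles() {
  absl::Span<const int> s = {1, 2, 3};  // backing array dies at the ';'
  // Any later use of `s` reads destroyed storage.
}
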
@ -1785,11 +1773,11 @@ XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value,
});
}

XlaOp XlaBuilder::ReduceWindow(
const XlaOp& operand, const XlaOp& init_value,
XlaOp XlaBuilder::ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
Padding padding) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;

@ -1810,9 +1798,9 @@ XlaOp XlaBuilder::ReduceWindow(
XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;

@ -1907,8 +1895,7 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
}

XlaOp XlaBuilder::CrossReplicaSum(
const XlaOp& operand,
tensorflow::gtl::ArraySlice<ReplicaGroup> replica_groups) {
const XlaOp& operand, absl::Span<const ReplicaGroup> replica_groups) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {});
@ -1923,7 +1910,7 @@ XlaOp XlaBuilder::CrossReplicaSum(

XlaOp XlaBuilder::CrossReplicaSum(
const XlaOp& operand, const XlaComputation& computation,
tensorflow::gtl::ArraySlice<ReplicaGroup> replica_groups,
absl::Span<const ReplicaGroup> replica_groups,
const absl::optional<ChannelHandle>& channel_id) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
@ -2023,11 +2010,12 @@ XlaOp XlaBuilder::CollectivePermute(
});
}

XlaOp XlaBuilder::SelectAndScatter(
const XlaOp& operand, const XlaComputation& select,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
const XlaOp& source, const XlaOp& init_value,
XlaOp XlaBuilder::SelectAndScatter(const XlaOp& operand,
const XlaComputation& select,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
Padding padding, const XlaOp& source,
const XlaOp& init_value,
const XlaComputation& scatter) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@ -2041,11 +2029,10 @@ XlaOp XlaBuilder::SelectAndScatter(

XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding(
const XlaOp& operand, const XlaComputation& select,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
const XlaOp& source, const XlaOp& init_value,
const XlaComputation& scatter) {
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
const XlaOp& init_value, const XlaComputation& scatter) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
|
||||
@ -2428,9 +2415,9 @@ XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> XlaBuilder::AddInstruction(
|
||||
HloInstructionProto&& instr, HloOpcode opcode,
|
||||
tensorflow::gtl::ArraySlice<XlaOp> operands) {
|
||||
StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
|
||||
HloOpcode opcode,
|
||||
absl::Span<const XlaOp> operands) {
|
||||
TF_RETURN_IF_ERROR(first_error_);
|
||||
|
||||
const int64 handle = instructions_.size();
|
||||
@ -2504,14 +2491,12 @@ XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal) {
|
||||
return builder->ConstantLiteral(literal);
|
||||
}
|
||||
|
||||
XlaOp Broadcast(const XlaOp& operand,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
|
||||
XlaOp Broadcast(const XlaOp& operand, absl::Span<const int64> broadcast_sizes) {
|
||||
return operand.builder()->Broadcast(operand, broadcast_sizes);
|
||||
}
|
||||
|
||||
XlaOp BroadcastInDim(
|
||||
const XlaOp& operand, const Shape& shape,
|
||||
const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||
XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape,
|
||||
const absl::Span<const int64> broadcast_dimensions) {
|
||||
return operand.builder()->BroadcastInDim(operand, shape,
|
||||
broadcast_dimensions);
|
||||
}
|
||||
@ -2521,26 +2506,22 @@ XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
|
||||
return operand.builder()->Pad(operand, padding_value, padding_config);
|
||||
}
|
||||
|
||||
XlaOp Reshape(const XlaOp& operand,
|
||||
tensorflow::gtl::ArraySlice<int64> dimensions,
|
||||
tensorflow::gtl::ArraySlice<int64> new_sizes) {
|
||||
XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
|
||||
absl::Span<const int64> new_sizes) {
|
||||
return operand.builder()->Reshape(operand, dimensions, new_sizes);
|
||||
}
|
||||
|
||||
XlaOp Reshape(const XlaOp& operand,
|
||||
tensorflow::gtl::ArraySlice<int64> new_sizes) {
|
||||
XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes) {
|
||||
return operand.builder()->Reshape(operand, new_sizes);
|
||||
}
|
||||
|
||||
XlaOp Collapse(const XlaOp& operand,
|
||||
tensorflow::gtl::ArraySlice<int64> dimensions) {
|
||||
XlaOp Collapse(const XlaOp& operand, absl::Span<const int64> dimensions) {
|
||||
return operand.builder()->Collapse(operand, dimensions);
|
||||
}
|
||||
|
||||
XlaOp Slice(const XlaOp& operand,
|
||||
tensorflow::gtl::ArraySlice<int64> start_indices,
|
||||
tensorflow::gtl::ArraySlice<int64> limit_indices,
|
||||
tensorflow::gtl::ArraySlice<int64> strides) {
|
||||
XlaOp Slice(const XlaOp& operand, absl::Span<const int64> start_indices,
|
||||
absl::Span<const int64> limit_indices,
|
||||
absl::Span<const int64> strides) {
|
||||
return operand.builder()->Slice(operand, start_indices, limit_indices,
|
||||
strides);
|
||||
}
|
||||
@ -2552,7 +2533,7 @@ XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
|
||||
}
|
||||
|
||||
XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
|
||||
tensorflow::gtl::ArraySlice<int64> slice_sizes) {
|
||||
absl::Span<const int64> slice_sizes) {
|
||||
return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes);
|
||||
}
|
||||
|
||||
@ -2561,8 +2542,7 @@ XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
|
||||
return operand.builder()->DynamicUpdateSlice(operand, update, start_indices);
|
||||
}
|
||||
|
||||
XlaOp ConcatInDim(XlaBuilder* builder,
|
||||
tensorflow::gtl::ArraySlice<XlaOp> operands,
|
||||
XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
|
||||
int64 dimension) {
|
||||
return builder->ConcatInDim(operands, dimension);
|
||||
}
|
||||
@ -2575,7 +2555,7 @@ XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false) {
|
||||
return pred.builder()->Select(pred, on_true, on_false);
|
||||
}
|
||||
|
||||
XlaOp Tuple(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> elements) {
|
||||
XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements) {
|
||||
return builder->Tuple(elements);
|
||||
}
|
||||
|
||||
@ -2584,32 +2564,32 @@ XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index) {
|
||||
}
|
||||
|
||||
XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return lhs.builder()->Eq(lhs, rhs, broadcast_dimensions);
|
||||
}
|
||||
|
||||
XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return lhs.builder()->Ne(lhs, rhs, broadcast_dimensions);
|
||||
}
|
||||
|
||||
XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return lhs.builder()->Ge(lhs, rhs, broadcast_dimensions);
|
||||
}
|
||||
|
||||
XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return lhs.builder()->Gt(lhs, rhs, broadcast_dimensions);
|
||||
}
|
||||
|
||||
XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return lhs.builder()->Lt(lhs, rhs, broadcast_dimensions);
|
||||
}
|
||||
|
||||
XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return lhs.builder()->Le(lhs, rhs, broadcast_dimensions);
|
||||
}
|
||||
|
||||
@ -2626,7 +2606,7 @@ XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
|
||||
}
|
||||
|
||||
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
|
||||
absl::Span<const int64> window_strides, Padding padding,
|
||||
int64 feature_group_count,
|
||||
const PrecisionConfigProto* precision_config_proto) {
|
||||
return lhs.builder()->Conv(lhs, rhs, window_strides, padding,
|
||||
@ -2634,9 +2614,8 @@ XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
|
||||
}
|
||||
|
||||
XlaOp ConvWithGeneralPadding(
|
||||
const XlaOp& lhs, const XlaOp& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> window_strides,
|
||||
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
|
||||
const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
int64 feature_group_count,
|
||||
const PrecisionConfigProto* precision_config_proto) {
|
||||
return lhs.builder()->ConvWithGeneralPadding(lhs, rhs, window_strides,
|
||||
@ -2645,9 +2624,8 @@ XlaOp ConvWithGeneralPadding(
|
||||
}
|
||||
|
||||
XlaOp ConvWithGeneralDimensions(
|
||||
const XlaOp& lhs, const XlaOp& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
|
||||
Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count,
|
||||
const PrecisionConfigProto* precision_config_proto) {
|
||||
return lhs.builder()->ConvWithGeneralDimensions(
|
||||
@ -2656,8 +2634,8 @@ XlaOp ConvWithGeneralDimensions(
|
||||
}
|
||||
|
||||
XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> window_strides,
|
||||
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count,
|
||||
const PrecisionConfigProto* precision_config_proto) {
|
||||
@ -2666,12 +2644,11 @@ XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
|
||||
precision_config_proto);
|
||||
}
|
||||
|
||||
XlaOp ConvGeneralDilated(
|
||||
const XlaOp& lhs, const XlaOp& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> window_strides,
|
||||
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
|
||||
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
|
||||
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
|
||||
XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
|
||||
absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation,
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count,
|
||||
const PrecisionConfigProto* precision_config_proto) {
|
||||
@ -2681,7 +2658,7 @@ XlaOp ConvGeneralDilated(
|
||||
}
|
||||
|
||||
XlaOp Fft(const XlaOp& operand, FftType fft_type,
|
||||
tensorflow::gtl::ArraySlice<int64> fft_length) {
|
||||
absl::Span<const int64> fft_length) {
|
||||
return operand.builder()->Fft(operand, fft_type, fft_length);
|
||||
}
|
||||
|
||||
@ -2695,105 +2672,102 @@ void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
|
||||
}
|
||||
|
||||
XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
|
||||
tensorflow::gtl::ArraySlice<XlaOp> operands) {
|
||||
absl::Span<const XlaOp> operands) {
|
||||
return builder->Call(computation, operands);
|
||||
}
|
||||
|
||||
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
|
||||
tensorflow::gtl::ArraySlice<XlaOp> operands,
|
||||
const Shape& shape) {
|
||||
absl::Span<const XlaOp> operands, const Shape& shape) {
|
||||
return builder->CustomCall(call_target_name, operands, shape);
|
||||
}
|
||||
|
||||
XlaOp Complex(const XlaOp& real, const XlaOp& imag,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return real.builder()->Complex(real, imag, broadcast_dimensions);
|
||||
}
|
||||
|
||||
XlaOp Conj(const XlaOp& operand) { return operand.builder()->Conj(operand); }
|
||||
|
||||
XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return lhs.builder()->Add(lhs, rhs, broadcast_dimensions);
|
||||
}
|
||||
|
||||
XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return lhs.builder()->Sub(lhs, rhs, broadcast_dimensions);
|
||||
}
|
||||
|
||||
XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return lhs.builder()->Mul(lhs, rhs, broadcast_dimensions);
|
||||
}
|
||||
|
||||
XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return lhs.builder()->Div(lhs, rhs, broadcast_dimensions);
|
||||
}
|
||||
|
||||
XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return lhs.builder()->Rem(lhs, rhs, broadcast_dimensions);
|
||||
}
|
||||
|
||||
XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return lhs.builder()->Max(lhs, rhs, broadcast_dimensions);
|
||||
}
|
||||
|
||||
XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return lhs.builder()->Min(lhs, rhs, broadcast_dimensions);
|
||||
}
|
||||
|
||||
XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return lhs.builder()->And(lhs, rhs, broadcast_dimensions);
|
||||
}
|
||||
|
||||
XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return lhs.builder()->Or(lhs, rhs, broadcast_dimensions);
|
||||
}
|
||||
|
||||
XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return lhs.builder()->Xor(lhs, rhs, broadcast_dimensions);
|
||||
}
|
||||
|
||||
XlaOp Not(const XlaOp& operand) { return operand.builder()->Not(operand); }
|
||||
|
||||
XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return lhs.builder()->ShiftLeft(lhs, rhs, broadcast_dimensions);
|
||||
}
|
||||
|
||||
XlaOp ShiftRightArithmetic(
|
||||
const XlaOp& lhs, const XlaOp& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||
XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs,
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return lhs.builder()->ShiftRightArithmetic(lhs, rhs, broadcast_dimensions);
|
||||
}
|
||||
|
||||
XlaOp ShiftRightLogical(
|
||||
const XlaOp& lhs, const XlaOp& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||
XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs,
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return lhs.builder()->ShiftRightLogical(lhs, rhs, broadcast_dimensions);
|
||||
}
|
||||
|
||||
XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
|
||||
const XlaComputation& computation,
|
||||
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
|
||||
absl::Span<const int64> dimensions_to_reduce) {
|
||||
return operand.builder()->Reduce(operand, init_value, computation,
|
||||
dimensions_to_reduce);
|
||||
}
|
||||
|
||||
// Reduces several arrays simultaneously among the provided dimensions, given
|
||||
// "computation" as a reduction operator.
|
||||
XlaOp Reduce(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> operands,
|
||||
tensorflow::gtl::ArraySlice<XlaOp> init_values,
|
||||
XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
|
||||
absl::Span<const XlaOp> init_values,
|
||||
const XlaComputation& computation,
|
||||
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
|
||||
absl::Span<const int64> dimensions_to_reduce) {
|
||||
return builder->Reduce(operands, init_values, computation,
|
||||
dimensions_to_reduce);
|
||||
}
|
||||
@ -2805,9 +2779,8 @@ XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
|
||||
|
||||
XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
|
||||
const XlaComputation& computation,
|
||||
tensorflow::gtl::ArraySlice<int64> window_dimensions,
|
||||
tensorflow::gtl::ArraySlice<int64> window_strides,
|
||||
Padding padding) {
|
||||
absl::Span<const int64> window_dimensions,
|
||||
absl::Span<const int64> window_strides, Padding padding) {
|
||||
return operand.builder()->ReduceWindow(operand, init_value, computation,
|
||||
window_dimensions, window_strides,
|
||||
padding);
|
||||
@ -2816,22 +2789,21 @@ XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
|
||||
XlaOp ReduceWindowWithGeneralPadding(
|
||||
const XlaOp& operand, const XlaOp& init_value,
|
||||
const XlaComputation& computation,
|
||||
tensorflow::gtl::ArraySlice<int64> window_dimensions,
|
||||
tensorflow::gtl::ArraySlice<int64> window_strides,
|
||||
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
|
||||
absl::Span<const int64> window_dimensions,
|
||||
absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding) {
|
||||
return operand.builder()->ReduceWindowWithGeneralPadding(
|
||||
operand, init_value, computation, window_dimensions, window_strides,
|
||||
padding);
|
||||
}
|
||||
|
||||
XlaOp CrossReplicaSum(
|
||||
const XlaOp& operand,
|
||||
tensorflow::gtl::ArraySlice<ReplicaGroup> replica_groups) {
|
||||
XlaOp CrossReplicaSum(const XlaOp& operand,
|
||||
absl::Span<const ReplicaGroup> replica_groups) {
|
||||
return operand.builder()->CrossReplicaSum(operand, replica_groups);
|
||||
}
|
||||
|
||||
XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation,
|
||||
tensorflow::gtl::ArraySlice<ReplicaGroup> replica_groups,
|
||||
absl::Span<const ReplicaGroup> replica_groups,
|
||||
const absl::optional<ChannelHandle>& channel_id) {
|
||||
return operand.builder()->CrossReplicaSum(operand, computation,
|
||||
replica_groups, channel_id);
|
||||
@ -2851,10 +2823,10 @@ XlaOp CollectivePermute(
|
||||
}
|
||||
|
||||
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
|
||||
tensorflow::gtl::ArraySlice<int64> window_dimensions,
|
||||
tensorflow::gtl::ArraySlice<int64> window_strides,
|
||||
Padding padding, const XlaOp& source,
|
||||
const XlaOp& init_value, const XlaComputation& scatter) {
|
||||
absl::Span<const int64> window_dimensions,
|
||||
absl::Span<const int64> window_strides, Padding padding,
|
||||
const XlaOp& source, const XlaOp& init_value,
|
||||
const XlaComputation& scatter) {
|
||||
return operand.builder()->SelectAndScatter(operand, select, window_dimensions,
|
||||
window_strides, padding, source,
|
||||
init_value, scatter);
|
||||
@ -2862,11 +2834,10 @@ XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
|
||||
|
||||
XlaOp SelectAndScatterWithGeneralPadding(
|
||||
const XlaOp& operand, const XlaComputation& select,
|
||||
tensorflow::gtl::ArraySlice<int64> window_dimensions,
|
||||
tensorflow::gtl::ArraySlice<int64> window_strides,
|
||||
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
|
||||
const XlaOp& source, const XlaOp& init_value,
|
||||
const XlaComputation& scatter) {
|
||||
absl::Span<const int64> window_dimensions,
|
||||
absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
|
||||
const XlaOp& init_value, const XlaComputation& scatter) {
|
||||
return operand.builder()->SelectAndScatterWithGeneralPadding(
|
||||
operand, select, window_dimensions, window_strides, padding, source,
|
||||
init_value, scatter);
|
||||
@ -2875,7 +2846,7 @@ XlaOp SelectAndScatterWithGeneralPadding(
|
||||
XlaOp Abs(const XlaOp& operand) { return operand.builder()->Abs(operand); }
|
||||
|
||||
XlaOp Atan2(const XlaOp& y, const XlaOp& x,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return y.builder()->Atan2(y, x, broadcast_dimensions);
|
||||
}
|
||||
|
||||
@ -2908,7 +2879,7 @@ XlaOp Real(const XlaOp& operand) { return operand.builder()->Real(operand); }
|
||||
XlaOp Imag(const XlaOp& operand) { return operand.builder()->Imag(operand); }
|
||||
|
||||
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return lhs.builder()->Pow(lhs, rhs, broadcast_dimensions);
|
||||
}
|
||||
|
||||
@ -2926,12 +2897,11 @@ XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) {
|
||||
|
||||
XlaOp Neg(const XlaOp& operand) { return operand.builder()->Neg(operand); }
|
||||
|
||||
XlaOp Transpose(const XlaOp& operand,
|
||||
tensorflow::gtl::ArraySlice<int64> permutation) {
|
||||
XlaOp Transpose(const XlaOp& operand, absl::Span<const int64> permutation) {
|
||||
return operand.builder()->Transpose(operand, permutation);
|
||||
}
|
||||
|
||||
XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
|
||||
XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions) {
|
||||
return operand.builder()->Rev(operand, dimensions);
|
||||
}
|
||||
|
||||
@ -2943,10 +2913,9 @@ XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) {
|
||||
return min.builder()->Clamp(min, operand, max);
|
||||
}
|
||||
|
||||
XlaOp Map(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> operands,
|
||||
const XlaComputation& computation,
|
||||
tensorflow::gtl::ArraySlice<int64> dimensions,
|
||||
tensorflow::gtl::ArraySlice<XlaOp> static_operands) {
|
||||
XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
|
||||
const XlaComputation& computation, absl::Span<const int64> dimensions,
|
||||
absl::Span<const XlaOp> static_operands) {
|
||||
return builder->Map(operands, computation, dimensions, static_operands);
|
||||
}
|
||||
|
||||
@ -2980,7 +2949,7 @@ XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
|
||||
|
||||
XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
|
||||
const GatherDimensionNumbers& dimension_numbers,
|
||||
tensorflow::gtl::ArraySlice<int64> slice_sizes) {
|
||||
absl::Span<const int64> slice_sizes) {
|
||||
return input.builder()->Gather(input, start_indices, dimension_numbers,
|
||||
slice_sizes);
|
||||
}
|
||||
@ -3036,7 +3005,7 @@ XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
|
||||
|
||||
XlaOp CreateToken(XlaBuilder* builder) { return builder->CreateToken(); }
|
||||
|
||||
XlaOp AfterAll(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> tokens) {
|
||||
XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens) {
|
||||
return builder->AfterAll(tokens);
|
||||
}
|
||||
|
||||
|
File diff suppressed because it is too large
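A note on why this rename is mechanical at call sites: absl::Span<const T> keeps the implicit conversions from containers and braced lists that gtl::ArraySlice had. A minimal standalone sketch under stated assumptions — plain int64_t stands in for TensorFlow's int64 alias, and ElementCount is a hypothetical helper, not part of XLA:

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

#include "absl/types/span.h"

// Reads the dimensions but never mutates them, so it takes Span<const T>,
// mirroring the signatures in the diff above.
int64_t ElementCount(absl::Span<const int64_t> dimensions) {
  return std::accumulate(dimensions.begin(), dimensions.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

int main() {
  std::vector<int64_t> dims = {2, 3, 4};
  // A vector converts implicitly, exactly as it did with ArraySlice.
  std::cout << ElementCount(dims) << "\n";  // 24
  // A braced list binds to the Span for the duration of the call, the same
  // pattern as absl::Span<const XlaOp>({operand}) in the Reduce overload.
  std::cout << ElementCount({5, 7}) << "\n";  // 35
  return 0;
}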
@ -27,7 +27,7 @@ limitations under the License.
namespace xla {

/* static */ int64 IndexUtil::MultidimensionalIndexToLinearIndex(
const Shape& shape, tensorflow::gtl::ArraySlice<int64> multi_index) {
const Shape& shape, absl::Span<const int64> multi_index) {
DCHECK_EQ(shape.dimensions_size(), multi_index.size());
// Padding and nested layouts not supported yet.
DCHECK_EQ(0, shape.layout().padded_dimensions_size());
@ -118,8 +118,8 @@ namespace xla {
return multi_index;
}

/* static */ bool IndexUtil::BumpIndices(
const Shape& shape, tensorflow::gtl::MutableArraySlice<int64> indices) {
/* static */ bool IndexUtil::BumpIndices(const Shape& shape,
absl::Span<int64> indices) {
for (int64 dimno = indices.size() - 1; dimno >= 0; --dimno) {
int64 limit = shape.dimensions(dimno);
if (indices[dimno] + 1 < limit) {
@ -149,8 +149,8 @@ namespace xla {
return stride;
}

/* static */ bool IndexUtil::IndexInBounds(
const Shape& shape, tensorflow::gtl::ArraySlice<int64> index) {
/* static */ bool IndexUtil::IndexInBounds(const Shape& shape,
absl::Span<const int64> index) {
int64 rank = ShapeUtil::Rank(shape);
if (rank != index.size()) {
return false;
@ -163,9 +163,8 @@ namespace xla {
return true;
}

/* static */ int IndexUtil::CompareIndices(
tensorflow::gtl::ArraySlice<int64> lhs,
tensorflow::gtl::ArraySlice<int64> rhs) {
/* static */ int IndexUtil::CompareIndices(absl::Span<const int64> lhs,
absl::Span<const int64> rhs) {
int64 rank = lhs.size();
CHECK_EQ(rhs.size(), rank);
for (int64 dim = 0; dim < rank; ++dim) {

@ -35,7 +35,7 @@ class IndexUtil {
// on the shape and its layout. The first index in the multi_index is
// dimension 0.
static int64 MultidimensionalIndexToLinearIndex(
const Shape& shape, tensorflow::gtl::ArraySlice<int64> multi_index);
const Shape& shape, absl::Span<const int64> multi_index);

// Converts a linear index into multidimensional index (eg {x, y, z}) based on
// the shape and its layout. The first index in the returned multidimensional
@ -58,8 +58,7 @@ class IndexUtil {
//
// Returns true iff the indices were successfully bumped; false if we've hit
// the limit where it can no longer be bumped in-bounds.
static bool BumpIndices(const Shape& shape,
tensorflow::gtl::MutableArraySlice<int64> indices);
static bool BumpIndices(const Shape& shape, absl::Span<int64> indices);

// Calculates the stride size (in number of elements, not byte size) of a
// given logical shape dimension (from 0 to rank-1). If available, padded
@ -71,15 +70,14 @@ class IndexUtil {

// Returns true iff the given multi-index is contained in the bounds for the
// shape.
static bool IndexInBounds(const Shape& shape,
tensorflow::gtl::ArraySlice<int64> index);
static bool IndexInBounds(const Shape& shape, absl::Span<const int64> index);

// Compares the given indices in lexicographic order. lhs[0] and rhs[0] are
// compared first, and lhs[rank-1] and rhs[rank-1] last. If lhs is larger,
// then -1 is returned. If rhs is larger, then 1 is returned. Otherwise, 0 is
// returned.
static int CompareIndices(tensorflow::gtl::ArraySlice<int64> lhs,
tensorflow::gtl::ArraySlice<int64> rhs);
static int CompareIndices(absl::Span<const int64> lhs,
absl::Span<const int64> rhs);

private:
TF_DISALLOW_COPY_AND_ASSIGN(IndexUtil);
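The BumpIndices contract documented above is an odometer over a multi-index: bump the minor-most dimension first and carry into more major ones. A self-contained sketch of that pattern, where a plain `bounds` span is an illustrative stand-in for shape.dimensions() (these names are not the XLA API):

#include <cstdint>
#include <iostream>
#include <vector>

#include "absl/types/span.h"

// Returns false once every index has wrapped, in the spirit of
// IndexUtil::BumpIndices above. `indices` is mutated in place through the
// mutable Span, which is what absl::Span<int64> buys over Span<const int64>.
bool BumpIndices(absl::Span<const int64_t> bounds,
                 absl::Span<int64_t> indices) {
  for (int64_t dimno = indices.size() - 1; dimno >= 0; --dimno) {
    if (indices[dimno] + 1 < bounds[dimno]) {
      ++indices[dimno];
      // Reset all more-minor dimensions to zero.
      for (size_t i = dimno + 1; i < indices.size(); ++i) indices[i] = 0;
      return true;
    }
  }
  return false;
}

int main() {
  std::vector<int64_t> bounds = {2, 3};
  std::vector<int64_t> idx = {0, 0};
  do {
    std::cout << idx[0] << "," << idx[1] << "\n";  // visits all 6 indices
  } while (BumpIndices(bounds, absl::MakeSpan(idx)));
  return 0;
}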
@ -56,7 +56,7 @@ void SetDefaultLayoutToContainer(
} // namespace

/* static */ Layout LayoutUtil::MakeLayout(
tensorflow::gtl::ArraySlice<int64> minor_to_major) {
absl::Span<const int64> minor_to_major) {
Layout layout;
layout.set_format(DENSE);
for (int64 dimension_number : minor_to_major) {
@ -66,7 +66,7 @@ void SetDefaultLayoutToContainer(
}

/* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor(
tensorflow::gtl::ArraySlice<int64> major_to_minor) {
absl::Span<const int64> major_to_minor) {
Layout layout;
layout.set_format(DENSE);
for (int i = major_to_minor.size() - 1; i >= 0; i--) {
@ -307,7 +307,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
return false;
}

/* static */ tensorflow::gtl::ArraySlice<int64> LayoutUtil::PaddedDimensions(
/* static */ absl::Span<const int64> LayoutUtil::PaddedDimensions(
const Shape& shape) {
CHECK(IsDenseArray(shape));
return AsInt64Slice(shape.layout().padded_dimensions());
@ -363,13 +363,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
return protobuf_util::ProtobufEquals(lhs, rhs);
}

/* static */ tensorflow::gtl::ArraySlice<int64> LayoutUtil::MinorToMajor(
/* static */ absl::Span<const int64> LayoutUtil::MinorToMajor(
const Shape& shape) {
CHECK(IsDenseArray(shape));
return AsInt64Slice(shape.layout().minor_to_major());
}

/* static */ tensorflow::gtl::ArraySlice<int64> LayoutUtil::MinorToMajor(
/* static */ absl::Span<const int64> LayoutUtil::MinorToMajor(
const Layout& layout) {
CHECK(layout.format() == DENSE);
return AsInt64Slice(layout.minor_to_major());
@ -472,7 +472,7 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
}

/* static */ bool LayoutUtil::AreDimensionsConsecutive(
const Layout& layout, tensorflow::gtl::ArraySlice<int64> dims) {
const Layout& layout, absl::Span<const int64> dims) {
CHECK(IsDense(layout));
std::vector<int64> positions_in_layout;
for (int64 dim : dims) {

@ -34,11 +34,11 @@ class LayoutUtil {
public:
// Creates a layout with the given minor-to-major dimension order. (This is a
// convenience function for protobuf construction.)
static Layout MakeLayout(tensorflow::gtl::ArraySlice<int64> minor_to_major);
static Layout MakeLayout(absl::Span<const int64> minor_to_major);

// Similar to MakeLayout, but take indices in reverse order.
static Layout MakeLayoutFromMajorToMinor(
tensorflow::gtl::ArraySlice<int64> major_to_minor);
absl::Span<const int64> major_to_minor);

// Creates a sparse layout with the given maximum number of elements. (This is
// a convenience function for protobuf construction.)
@ -104,8 +104,7 @@ class LayoutUtil {

// Returns the padded_dimensions array for the given Shape. Requires that the
// shape is an array and has a dense layout.
static tensorflow::gtl::ArraySlice<int64> PaddedDimensions(
const Shape& shape);
static absl::Span<const int64> PaddedDimensions(const Shape& shape);

// Returns the given index of the padded_dimensions array for the given Shape.
// Requires that the shape is an array and has a dense layout.
@ -138,8 +137,8 @@ class LayoutUtil {

// Returns the minor_to_major array for the given Shape. Requires that the
// shape is an array and has a dense layout.
static tensorflow::gtl::ArraySlice<int64> MinorToMajor(const Shape& shape);
static tensorflow::gtl::ArraySlice<int64> MinorToMajor(const Layout& layout);
static absl::Span<const int64> MinorToMajor(const Shape& shape);
static absl::Span<const int64> MinorToMajor(const Layout& layout);

// Major(0) is the most major logical dimension number, Major(1) is the
// second-most-major logical dimension number and so on.
@ -196,7 +195,7 @@ class LayoutUtil {
// Returns whether the given dimensions are consecutive in the given layout,
// not necessarily in the order given.
static bool AreDimensionsConsecutive(const Layout& layout,
tensorflow::gtl::ArraySlice<int64> dims);
absl::Span<const int64> dims);

// Compute a hash for `layout`.
static size_t Hash(const Layout& layout);
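MakeLayoutFromMajorToMinor above simply walks its argument back to front, so minor_to_major and major_to_minor are reverses of each other. An illustrative standalone sketch of just that reversal, with no XLA types involved (MajorToMinorAsMinorToMajor is a hypothetical name):

#include <cstdint>
#include <iostream>
#include <vector>

#include "absl/types/span.h"

// Mirrors the back-to-front loop in MakeLayoutFromMajorToMinor above, so the
// result stores dimension numbers minor-most first.
std::vector<int64_t> MajorToMinorAsMinorToMajor(
    absl::Span<const int64_t> major_to_minor) {
  std::vector<int64_t> minor_to_major;
  for (int i = static_cast<int>(major_to_minor.size()) - 1; i >= 0; i--) {
    minor_to_major.push_back(major_to_minor[i]);
  }
  return minor_to_major;
}

int main() {
  // A row-major rank-3 layout given major-to-minor as {0, 1, 2} is stored
  // minor-to-major as {2, 1, 0}.
  for (int64_t d : MajorToMinorAsMinorToMajor({0, 1, 2})) {
    std::cout << d << " ";
  }
  std::cout << "\n";
  return 0;
}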
@ -27,15 +27,15 @@ namespace {
class LayoutUtilTest : public ::testing::Test {
protected:
Shape MakeShapeWithLayout(PrimitiveType element_type,
tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<int64> minor_to_major) {
absl::Span<const int64> dimensions,
absl::Span<const int64> minor_to_major) {
Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
*shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
return shape;
}

Shape MakeShapeWithSparseLayout(PrimitiveType element_type,
tensorflow::gtl::ArraySlice<int64> dimensions,
absl::Span<const int64> dimensions,
int64 max_sparse_elements) {
Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
*shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements);

@ -73,7 +73,7 @@ std::ostream& operator<<(std::ostream& out, const Literal& literal) {

MutableLiteralBase::StrideConfig::StrideConfig(
const Shape& source_shape, const Shape& dest_shape,
tensorflow::gtl::ArraySlice<int64> dimensions)
absl::Span<const int64> dimensions)
: dimensions(dimensions),
base(dimensions.size(), 0),
step(dimensions.size(), 1) {
@ -197,14 +197,13 @@ SparseIndexArray* MutableLiteralBase::sparse_indices(

template <typename NativeT>
Status MutableLiteralBase::CopySliceFromInternal(
const LiteralBase& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
tensorflow::gtl::ArraySlice<int64> dest_base,
tensorflow::gtl::ArraySlice<int64> copy_size) {
const LiteralBase& src_literal, absl::Span<const int64> src_base,
absl::Span<const int64> dest_base, absl::Span<const int64> copy_size) {
TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size());
TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size());

auto linear_index = [](const Shape& shape,
tensorflow::gtl::ArraySlice<int64> multi_index) {
absl::Span<const int64> multi_index) {
return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index);
};

@ -232,7 +231,7 @@ Status MutableLiteralBase::CopySliceFromInternal(
MutableLiteralBase::StrideConfig stride_config(src_literal.shape(), shape(),
copy_size);

auto copy_proc = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
auto copy_proc = [&](absl::Span<const int64> indexes) {
// Map from multi-dimensional index, to source index.
std::transform(indexes.begin(), indexes.end(), src_base.begin(),
src_indexes.begin(), std::plus<int64>());
@ -257,10 +256,9 @@ Status MutableLiteralBase::CopySliceFromInternal(
return Status::OK();
}

Status MutableLiteralBase::CopyElementFrom(
const LiteralSlice& src_literal,
tensorflow::gtl::ArraySlice<int64> src_index,
tensorflow::gtl::ArraySlice<int64> dest_index) {
Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
absl::Span<const int64> src_index,
absl::Span<const int64> dest_index) {
DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex(
src_literal.shape(), src_index);
@ -355,9 +353,9 @@ namespace {
// Copies the elements in 'src' to 'dest'. The shape and layout of the data in
// the array slices are indicated by dest_shape and src_shape respectively.
template <typename NativeT>
void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest,
tensorflow::gtl::ArraySlice<NativeT> src,
const Shape& dest_shape, const Shape& src_shape) {
void CopyElementsBetween(absl::Span<NativeT> dest,
absl::Span<const NativeT> src, const Shape& dest_shape,
const Shape& src_shape) {
CHECK(ShapeUtil::Compatible(dest_shape, src_shape));
if (ShapeUtil::IsZeroElementArray(dest_shape)) {
return;
@ -487,11 +485,10 @@ Status Literal::MoveFrom(Literal&& src_literal,
return Status::OK();
}

Status MutableLiteralBase::CopySliceFrom(
const LiteralSlice& src_literal,
tensorflow::gtl::ArraySlice<int64> src_base,
tensorflow::gtl::ArraySlice<int64> dest_base,
tensorflow::gtl::ArraySlice<int64> copy_size) {
Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal,
absl::Span<const int64> src_base,
absl::Span<const int64> dest_base,
absl::Span<const int64> copy_size) {
TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape());
TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape()))
<< ShapeUtil::HumanString(src_literal.shape());
@ -591,8 +588,7 @@ std::unique_ptr<Literal> LiteralBase::Relayout(
}

StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
const Shape& result_shape,
tensorflow::gtl::ArraySlice<int64> dimensions) const {
const Shape& result_shape, absl::Span<const int64> dimensions) const {
if (!ShapeUtil::IsArray(shape())) {
return InvalidArgument("Broadcast only supports arrays.");
}
@ -615,7 +611,7 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());

ShapeUtil::ForEachIndex(
result_shape, [&](tensorflow::gtl::ArraySlice<int64> output_index) {
result_shape, [&](absl::Span<const int64> output_index) {
for (int64 i = 0; i < dimensions.size(); ++i) {
scratch_source_index[i] = output_index[dimensions[i]];
}
@ -632,7 +628,7 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
}

StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
tensorflow::gtl::ArraySlice<int64> dimensions) const {
absl::Span<const int64> dimensions) const {
if (!ShapeUtil::IsArray(shape())) {
return InvalidArgument("Reshape does not support tuples.");
}
@ -661,7 +657,7 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
}

std::unique_ptr<Literal> LiteralBase::Transpose(
tensorflow::gtl::ArraySlice<int64> permutation) const {
absl::Span<const int64> permutation) const {
CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape())))
<< "Given permutation is not a permutation of dimension numbers";
@ -700,12 +696,11 @@ std::unique_ptr<Literal> LiteralBase::Transpose(

template <typename NativeT>
std::unique_ptr<Literal> LiteralBase::SliceInternal(
const Shape& result_shape,
tensorflow::gtl::ArraySlice<int64> start_indices) const {
const Shape& result_shape, absl::Span<const int64> start_indices) const {
auto result_literal = absl::make_unique<Literal>(result_shape);
DimensionVector new_indices(ShapeUtil::Rank(result_shape));
result_literal->EachCell<NativeT>(
[&](tensorflow::gtl::ArraySlice<int64> indices, NativeT /*value*/) {
[&](absl::Span<const int64> indices, NativeT /*value*/) {
for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
new_indices[i] = indices[i] + start_indices[i];
}
@ -716,8 +711,8 @@ std::unique_ptr<Literal> LiteralBase::SliceInternal(
}

std::unique_ptr<Literal> LiteralBase::Slice(
tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices) const {
absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices) const {
CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice";

DimensionVector result_dimensions;
@ -761,7 +756,7 @@ std::unique_ptr<Literal> LiteralBase::CloneToUnique() const {
return result;
}

string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
string LiteralBase::GetAsString(absl::Span<const int64> multi_index,
const ShapeIndex& shape_index) const {
const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
CHECK(LayoutUtil::IsDenseArray(subshape));
@ -858,7 +853,7 @@ string LiteralBase::GetSparseElementAsString(
}

StatusOr<int64> LiteralBase::GetIntegralAsS64(
tensorflow::gtl::ArraySlice<int64> multi_index) const {
absl::Span<const int64> multi_index) const {
CHECK(LayoutUtil::IsDenseArray(shape()));
switch (shape().element_type()) {
case PRED:
@ -900,8 +895,8 @@ size_t LiteralBase::Hash() const {
return hash_value;
}

Status MutableLiteralBase::SetIntegralAsS64(
tensorflow::gtl::ArraySlice<int64> multi_index, int64 value) {
Status MutableLiteralBase::SetIntegralAsS64(absl::Span<const int64> multi_index,
int64 value) {
CHECK(LayoutUtil::IsDenseArray(shape()));
switch (shape().element_type()) {
case PRED:
@ -929,7 +924,7 @@ Status MutableLiteralBase::SetIntegralAsS64(
return Status::OK();
}

tensorflow::gtl::ArraySlice<int64> LiteralBase::GetSparseIndex(
absl::Span<const int64> LiteralBase::GetSparseIndex(
int64 sparse_element_number, const ShapeIndex& shape_index) const {
const Piece& p = piece(shape_index);
CHECK_GE(sparse_element_number, 0);
@ -998,7 +993,7 @@ void LiteralBase::Piece::SortSparseElementsInternal() {
auto values = data<NativeT>();
CHECK_LE(num_elements, values.size());
sparse_indices()->SortWithValues(
tensorflow::gtl::MutableArraySlice<NativeT>(values.data(), num_elements));
absl::Span<NativeT>(values.data(), num_elements));
}
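SortWithValues above receives absl::Span<NativeT>(values.data(), num_elements) — the pointer-plus-length constructor that replaces the old MutableArraySlice(ptr, len) spelling one for one. A standalone sketch of viewing and mutating just a prefix through such a span (the values here are arbitrary illustration):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

#include "absl/types/span.h"

int main() {
  std::vector<int32_t> values = {5, 3, 9, 1, 7, 0, 0};
  // Only the first five entries are live elements; the pointer-plus-length
  // constructor views exactly that prefix without copying.
  const int64_t num_elements = 5;
  absl::Span<int32_t> live(values.data(), num_elements);
  std::sort(live.begin(), live.end());  // mutates the underlying vector
  for (int32_t v : values) std::cout << v << " ";  // 1 3 5 7 9 0 0
  std::cout << "\n";
  return 0;
}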
namespace {
|
||||
@ -1064,8 +1059,7 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
|
||||
|
||||
CHECK(LayoutUtil::IsDenseArray(subshape));
|
||||
|
||||
auto element_to_string =
|
||||
[&](tensorflow::gtl::ArraySlice<int64> indices) -> string {
|
||||
auto element_to_string = [&](absl::Span<const int64> indices) -> string {
|
||||
PrimitiveType element_type = subshape.element_type();
|
||||
if (element_type == PRED) {
|
||||
// We display predicates in a densely packed form.
|
||||
@ -1160,7 +1154,7 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
|
||||
pieces->push_back(shape_to_string(subshape));
|
||||
pieces->push_back(" {");
|
||||
literal.EachCellAsString(
|
||||
[&](tensorflow::gtl::ArraySlice<int64> indices, const string& value) {
|
||||
[&](absl::Span<const int64> indices, const string& value) {
|
||||
pieces->push_back(" ");
|
||||
pieces->push_back(value);
|
||||
});
|
||||
@ -1183,7 +1177,7 @@ string LiteralBase::ToString(bool print_layout) const {
|
||||
}
|
||||
|
||||
void LiteralBase::EachCellAsString(
|
||||
const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
|
||||
const std::function<void(absl::Span<const int64> indices,
|
||||
const string& value)>& per_cell) const {
|
||||
if (ShapeUtil::IsZeroElementArray(shape())) {
|
||||
return;
|
||||
@ -1250,10 +1244,8 @@ std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
|
||||
ShapeUtil::ChangeElementType(src_literal.shape(), C64));
|
||||
using NativeSrcT =
|
||||
typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type;
|
||||
tensorflow::gtl::ArraySlice<NativeSrcT> src_data =
|
||||
src_literal.data<NativeSrcT>();
|
||||
tensorflow::gtl::MutableArraySlice<complex64> dest_data =
|
||||
result_literal->data<complex64>();
|
||||
absl::Span<const NativeSrcT> src_data = src_literal.data<NativeSrcT>();
|
||||
absl::Span<complex64> dest_data = result_literal->data<complex64>();
|
||||
int64 num_elements = src_literal.element_count();
|
||||
for (int64 i = 0; i < num_elements; ++i) {
|
||||
dest_data[i] = complex64(static_cast<float>(src_data[i]), 0);
|
||||
@ -1397,7 +1389,7 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
|
||||
}
|
||||
|
||||
/* static */ Literal MutableLiteralBase::MoveIntoTuple(
|
||||
tensorflow::gtl::MutableArraySlice<Literal> elements) {
|
||||
absl::Span<Literal> elements) {
|
||||
std::vector<Shape> element_shapes;
|
||||
for (const Literal& element : elements) {
|
||||
element_shapes.push_back(element.shape());
|
||||
@ -1488,7 +1480,7 @@ bool LiteralBase::operator==(const LiteralBase& other) const {
|
||||
namespace {
|
||||
|
||||
template <typename NativeT>
|
||||
static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice<NativeT> data,
|
||||
static bool AllElementsEqualValue(absl::Span<const NativeT> data,
|
||||
NativeT value) {
|
||||
for (int64 i = 0; i < data.size(); ++i) {
|
||||
if (data[i] != value) {
|
||||
@ -1742,7 +1734,7 @@ bool LiteralBase::IsR1Iota() const {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
|
||||
bool LiteralBase::IsZero(absl::Span<const int64> indices) const {
|
||||
CHECK(ShapeUtil::IsArray(shape()));
|
||||
switch (shape().element_type()) {
|
||||
case U8:
|
||||
@ -1778,7 +1770,7 @@ namespace {
|
||||
|
||||
template <typename RepeatedFieldT, typename NativeT>
|
||||
void CopyToRepeatedField(RepeatedFieldT* dest,
|
||||
const tensorflow::gtl::ArraySlice<NativeT> src) {
|
||||
const absl::Span<const NativeT> src) {
|
||||
*dest = RepeatedFieldT(src.begin(), src.end());
|
||||
}
|
||||
|
||||
@ -1856,7 +1848,7 @@ void* LiteralBase::Piece::untyped_data() {
|
||||
namespace {
|
||||
|
||||
template <typename RepeatedFieldT, typename NativeT>
|
||||
Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice<NativeT> dest,
|
||||
Status CopyFromRepeatedField(absl::Span<NativeT> dest,
|
||||
const RepeatedFieldT& src) {
|
||||
if (dest.size() != src.size()) {
|
||||
return InvalidArgument(
|
||||
@ -2126,8 +2118,8 @@ BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
|
||||
root_piece_.set_subshape(shape_.get());
|
||||
}
|
||||
|
||||
BorrowingLiteral::BorrowingLiteral(
|
||||
tensorflow::gtl::ArraySlice<const char*> src_buf_ptrs, const Shape& shape)
|
||||
BorrowingLiteral::BorrowingLiteral(absl::Span<const char* const> src_buf_ptrs,
|
||||
const Shape& shape)
|
||||
: LiteralBase(), shape_(absl::make_unique<Shape>(shape)) {
|
||||
CHECK(ShapeUtil::IsTuple(*shape_));
|
||||
CHECK(!ShapeUtil::IsNestedTuple(*shape_));
|
||||
|
@ -70,13 +70,12 @@ class LiteralBase {
|
||||
// Serialize to proto.
|
||||
LiteralProto ToProto() const;
|
||||
|
||||
// Returns an ArraySlice of the array for this literal for the given NativeT
|
||||
// Returns a Span of the array for this literal for the given NativeT
|
||||
// (e.g., float). CHECKs if the subshape of the literal at the given
|
||||
// ShapeIndex is not array. See primitive_util.h for the mapping from XLA type
|
||||
// to native type.
|
||||
template <typename NativeT>
|
||||
tensorflow::gtl::ArraySlice<NativeT> data(
|
||||
const ShapeIndex& shape_index = {}) const;
|
||||
absl::Span<const NativeT> data(const ShapeIndex& shape_index = {}) const;
|
||||
|
||||
// Returns a const pointer to the sparse index array. Returns nullptr if the
|
||||
// literal is not a sparse array.
|
||||
@ -100,12 +99,12 @@ class LiteralBase {
|
||||
// Gets an element in the literal at the given index. The multi_index is
|
||||
// CHECKed against the dimension sizes.
|
||||
template <typename NativeT>
|
||||
NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index,
|
||||
NativeT Get(absl::Span<const int64> multi_index,
|
||||
const ShapeIndex& shape_index) const;
|
||||
// Overloads of Get for array literals. CHECKs if the literal is not
|
||||
// array-shaped and dense.
|
||||
template <typename NativeT>
|
||||
NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index) const;
|
||||
NativeT Get(absl::Span<const int64> multi_index) const;
|
||||
|
||||
// Returns the element value at index (0, ..., 0), however many zeroes are
|
||||
// required for that index.
|
||||
@ -114,7 +113,7 @@ class LiteralBase {
|
||||
|
||||
// As Get(), but determines the correct type and converts the value
|
||||
// into text.
|
||||
string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
|
||||
string GetAsString(absl::Span<const int64> multi_index,
|
||||
const ShapeIndex& shape_index = {}) const;
|
||||
// As GetSparseElement(), but determines the correct type and converts the
|
||||
// value into text.
|
||||
@ -122,14 +121,13 @@ class LiteralBase {
|
||||
const ShapeIndex& shape_index = {}) const;
|
||||
// As Get(), but determines the correct type and converts the value into
|
||||
// int64. This literal must be an array.
|
||||
StatusOr<int64> GetIntegralAsS64(
|
||||
tensorflow::gtl::ArraySlice<int64> multi_index) const;
|
||||
StatusOr<int64> GetIntegralAsS64(absl::Span<const int64> multi_index) const;
|
||||
|
||||
// Returns the multi-index of the element in a sparse literal at the given
|
||||
// sparse element number. The sparse element number is the position with in
|
||||
// the sparse array's list of (index, value) pairs, and is checked against the
|
||||
// total number of (index, value) pairs in the sparse array.
|
||||
tensorflow::gtl::ArraySlice<int64> GetSparseIndex(
|
||||
absl::Span<const int64> GetSparseIndex(
|
||||
int64 sparse_element_number, const ShapeIndex& shape_index = {}) const;
|
||||
|
||||
// Returns the value of the element in a sparse literal at the given sparse
|
||||
@ -150,11 +148,11 @@ class LiteralBase {
|
||||
//
|
||||
// This literal must have a dense layout.
|
||||
void EachCellAsString(
|
||||
const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
|
||||
const std::function<void(absl::Span<const int64> indices,
|
||||
const string& value)>& per_cell) const;
|
||||
template <typename NativeT>
|
||||
void EachCell(std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
|
||||
NativeT value)>
|
||||
void EachCell(
|
||||
std::function<void(absl::Span<const int64> indices, NativeT value)>
|
||||
per_cell) const;
|
||||
|
||||
// Returns whether every element in this literal is equal to value.
|
||||
@ -200,7 +198,7 @@ class LiteralBase {
|
||||
|
||||
// Returns whether this literal is zero at the specified index. This literal
|
||||
// must be an array with a dense layout.
|
||||
bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const;
|
||||
bool IsZero(absl::Span<const int64> indices) const;
|
||||
|
||||
// Returns the count of the elements in the array at the given shape index in
|
||||
// this literal.
|
||||
@ -273,13 +271,12 @@ class LiteralBase {
|
||||
// implementation currently only supports monotonic dim0-major layouts.
|
||||
// This literal must be an array.
|
||||
StatusOr<std::unique_ptr<Literal>> Reshape(
|
||||
tensorflow::gtl::ArraySlice<int64> dimensions) const;
|
||||
absl::Span<const int64> dimensions) const;
|
||||
|
||||
// Creates a new literal by broadcasting this literal with `dimensions` to
|
||||
// yield a literal of shape `result_shape`.
|
||||
StatusOr<std::unique_ptr<Literal>> Broadcast(
|
||||
const Shape& result_shape,
|
||||
tensorflow::gtl::ArraySlice<int64> dimensions) const;
|
||||
const Shape& result_shape, absl::Span<const int64> dimensions) const;
|
||||
|
||||
// Creates a new literal by reordering the dimensions of this literal.
|
||||
// The given `permutation` must be a permutation of the dimension numbers
|
||||
@ -288,8 +285,7 @@ class LiteralBase {
|
||||
// For example, a transpose call on a literal of shape [3 x 8 x 4] and
|
||||
// `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
|
||||
// This literal must be an array.
|
||||
std::unique_ptr<Literal> Transpose(
|
||||
tensorflow::gtl::ArraySlice<int64> permutation) const;
|
||||
std::unique_ptr<Literal> Transpose(absl::Span<const int64> permutation) const;
|
||||
|
||||
// Creates a sub-array from this literal by extracting the indices
|
||||
// [start_index, limit_index) of each dimension. The result literal has the
|
||||
@ -297,9 +293,8 @@ class LiteralBase {
|
||||
// start_indices and limit_indices must be the rank of the literal, and the
|
||||
// indices follow the order of the dimensions.
|
||||
// This literal must be an array.
|
||||
std::unique_ptr<Literal> Slice(
|
||||
tensorflow::gtl::ArraySlice<int64> start_indices,
|
||||
tensorflow::gtl::ArraySlice<int64> limit_indices) const;
|
||||
std::unique_ptr<Literal> Slice(absl::Span<const int64> start_indices,
|
||||
absl::Span<const int64> limit_indices) const;
|
||||
|
||||
// Creates a literal with a prepended dimension with bound "times"; e.g. a
|
||||
// f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
|
||||
@ -328,9 +323,9 @@ class LiteralBase {
|
||||
// Returns the buffer holding the array data for this piece as an array
|
||||
// slice. This piece must be array-shaped.
|
||||
template <typename NativeT>
|
||||
tensorflow::gtl::ArraySlice<NativeT> data() const;
|
||||
absl::Span<const NativeT> data() const;
|
||||
template <typename NativeT>
|
||||
tensorflow::gtl::MutableArraySlice<NativeT> data();
|
||||
absl::Span<NativeT> data();
|
||||
|
||||
// Returns the buffer holding the array data for this piece as a void*. This
|
||||
// piece must be array-shaped.
|
||||
@ -341,9 +336,9 @@ class LiteralBase {
|
||||
// is CHECKed against the dimension sizes of the array. This piece must be
|
||||
// array-shaped.
|
||||
template <typename NativeT>
|
||||
NativeT Get(tensorflow::gtl::ArraySlice<int64> index) const;
|
||||
NativeT Get(absl::Span<const int64> index) const;
|
||||
template <typename NativeT>
|
||||
void Set(tensorflow::gtl::ArraySlice<int64> index, NativeT value);
|
||||
void Set(absl::Span<const int64> index, NativeT value);
|
||||

// Gets/sets the buffer holding the array data.
char* buffer() const { return buffer_; }
@ -545,8 +540,7 @@ class LiteralBase {
private:
template <typename NativeT>
std::unique_ptr<Literal> SliceInternal(
const Shape& result_shape,
tensorflow::gtl::ArraySlice<int64> start_indices) const;
const Shape& result_shape, absl::Span<const int64> start_indices) const;
};

// Abstract base class representing a mutable literal in XLA.
@ -554,13 +548,12 @@ class MutableLiteralBase : public LiteralBase {
public:
virtual ~MutableLiteralBase() = 0;

// Returns a MutableArraySlice view of the array for this literal for the
// Returns a Span view of the array for this literal for the
// given NativeT (e.g., float). CHECKs if the subshape of the literal at the
// given ShapeIndex is not array. See primitive_util.h for the mapping from
// XLA type to native type.
template <typename NativeT>
tensorflow::gtl::MutableArraySlice<NativeT> data(
const ShapeIndex& shape_index = {});
absl::Span<NativeT> data(const ShapeIndex& shape_index = {});
// Unhide const method from parent class.
using LiteralBase::data;

@ -587,8 +580,7 @@ class MutableLiteralBase : public LiteralBase {
// are populated.
template <typename NativeT>
void PopulateSparse(SparseIndexArray indices,
tensorflow::gtl::ArraySlice<NativeT> values,
bool sort = true);
absl::Span<const NativeT> values, bool sort = true);

// Copy values from 'src_literal' rooted at 'src_shape_index' into this
// literal rooted at 'dest_shape_index'. The subshape of this literal rooted
@ -609,39 +601,38 @@ class MutableLiteralBase : public LiteralBase {
// corresponding base indices being 0.
// This literal and 'src_literal' must be arrays.
Status CopySliceFrom(const LiteralSlice& src_literal,
tensorflow::gtl::ArraySlice<int64> src_base,
tensorflow::gtl::ArraySlice<int64> dest_base,
tensorflow::gtl::ArraySlice<int64> copy_size);
absl::Span<const int64> src_base,
absl::Span<const int64> dest_base,
absl::Span<const int64> copy_size);
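A minimal sketch of the renamed CopySliceFrom (hypothetical `dst` and `src` array literals of matching element type):

// Sketch only: copies a 2x2 window starting at src[1, 1] into dst at [0, 0].
// All three index arguments bind to absl::Span<const int64>.
Status s = dst.CopySliceFrom(src, /*src_base=*/{1, 1}, /*dest_base=*/{0, 0},
                             /*copy_size=*/{2, 2});
CHECK(s.ok());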

// Copies one element from src_literal[src_index] to (*this)[dest_index].
Status CopyElementFrom(const LiteralSlice& src_literal,
tensorflow::gtl::ArraySlice<int64> src_index,
tensorflow::gtl::ArraySlice<int64> dest_index);
absl::Span<const int64> src_index,
absl::Span<const int64> dest_index);

// Sets an element in the literal at the given index. The multi_index is
// CHECKed against the dimension sizes.
template <typename NativeT>
void Set(tensorflow::gtl::ArraySlice<int64> multi_index,
const ShapeIndex& shape_index, NativeT value);
void Set(absl::Span<const int64> multi_index, const ShapeIndex& shape_index,
NativeT value);
// Overloads of Set for array literals. CHECKs if the literal is not
// array-shaped and dense.
template <typename NativeT>
void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value);
void Set(absl::Span<const int64> multi_index, NativeT value);

// Appends the given element to the literal. If the elements are not appended
// in sorted order, then SortSparseElements should be called before calling
// other methods. This literal must have a sparse layout.
template <typename NativeT>
void AppendSparseElement(tensorflow::gtl::ArraySlice<int64> multi_index,
NativeT value, const ShapeIndex& shape_index = {});
void AppendSparseElement(absl::Span<const int64> multi_index, NativeT value,
const ShapeIndex& shape_index = {});

// Sorts the elements in a sparse array.
void SortSparseElements(const ShapeIndex& shape_index = {});

// As Set(), but truncates `value` to the literal element type before storing.
// This literal must be an array.
Status SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index,
int64 value);
Status SetIntegralAsS64(absl::Span<const int64> multi_index, int64 value);

// Populate this literal with the given values. Examples:
//
@ -656,7 +647,7 @@ class MutableLiteralBase : public LiteralBase {
// example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
// array of S32.
template <typename NativeT>
void PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values);
void PopulateR1(absl::Span<const NativeT> values);
void PopulateR1(const tensorflow::core::Bitmap& values);
template <typename NativeT>
void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
@ -673,7 +664,7 @@ class MutableLiteralBase : public LiteralBase {
// in this literal object.
//
// generator must be a callable of the type
// NativeT(tensorflow::gtl::ArraySlice<int64> indexes) or compatible.
// NativeT(absl::Span<const int64> indexes) or compatible.
//
// This literal must have a dense layout.
template <typename NativeT, typename FnType>
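A sketch of the generator contract after the rename (hypothetical rank-2 S32 literal `literal`):

// Sketch only: the lambda's absl::Span<const int64> parameter matches the
// generator signature documented above; each element becomes the sum of
// its indices.
Status s = literal.Populate<int32>([](absl::Span<const int64> indexes) {
  return static_cast<int32>(indexes[0] + indexes[1]);
});
CHECK(s.ok());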
@ -693,8 +684,7 @@ class MutableLiteralBase : public LiteralBase {
// moved into the tuple elements of a new tuple-shaped Literal which is
// returned. Upon return, each of the Literals in 'elements' is set to a nil
// shape (empty tuple).
static Literal MoveIntoTuple(
tensorflow::gtl::MutableArraySlice<Literal> elements);
static Literal MoveIntoTuple(absl::Span<Literal> elements);

// Serialize from a proto.
static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
@ -712,20 +702,20 @@ class MutableLiteralBase : public LiteralBase {
// arguments one by one.
template <typename NativeT>
Status CopySliceFromInternal(const LiteralBase& src_literal,
tensorflow::gtl::ArraySlice<int64> src_base,
tensorflow::gtl::ArraySlice<int64> dest_base,
tensorflow::gtl::ArraySlice<int64> copy_size);
absl::Span<const int64> src_base,
absl::Span<const int64> dest_base,
absl::Span<const int64> copy_size);

// Utility structure which is used to create the optimal configuration for
// a ShapeUtil::ForEachIndex() scan across two literals.
struct StrideConfig {
StrideConfig(const Shape& source_shape, const Shape& dest_shape,
tensorflow::gtl::ArraySlice<int64> dimensions);
absl::Span<const int64> dimensions);

// The dimensions of the stride operation. Essentially every dimension
// will be iterated from base[i] to base[i]+dimensions[i], in step[i]
// steps.
tensorflow::gtl::ArraySlice<int64> dimensions;
absl::Span<const int64> dimensions;
DimensionVector base;
DimensionVector step;
int64 minor_dimension = 0;
@ -854,7 +844,7 @@ class BorrowingLiteral : public LiteralBase {
// This constructor is only used for array shapes.
BorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
// Similar as above, except to be used for constructing non-nested tuples.
BorrowingLiteral(tensorflow::gtl::ArraySlice<const char*> src_buf_ptrs,
BorrowingLiteral(absl::Span<const char* const> src_buf_ptrs,
const Shape& shape);
// TODO(b/79707221): adding constructors for nested tuples as well.

@ -874,7 +864,7 @@ class BorrowingLiteral : public LiteralBase {
};

template <typename NativeT>
tensorflow::gtl::ArraySlice<NativeT> LiteralBase::Piece::data() const {
absl::Span<const NativeT> LiteralBase::Piece::data() const {
CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
CHECK_EQ(subshape().element_type(),
primitive_util::NativeToPrimitiveType<NativeT>())
@ -882,12 +872,12 @@ tensorflow::gtl::ArraySlice<NativeT> LiteralBase::Piece::data() const {
<< PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
<< " type, but literal element type is "
<< PrimitiveType_Name(subshape().element_type());
return tensorflow::gtl::ArraySlice<NativeT>(
reinterpret_cast<const NativeT*>(buffer()), element_count());
return absl::Span<const NativeT>(reinterpret_cast<const NativeT*>(buffer()),
element_count());
}
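The pointer-plus-count constructor used here is the general form; a standalone sketch (hypothetical buffer, independent of this commit):

// Sketch only: absl::Span is a non-owning view, so `buffer` must outlive it.
float buffer[4] = {1.f, 2.f, 3.f, 4.f};
absl::Span<const float> view(buffer, 4);  // pointer + element count
CHECK_EQ(view.size(), 4);
CHECK_EQ(view[2], 3.f);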

template <typename NativeT>
tensorflow::gtl::MutableArraySlice<NativeT> LiteralBase::Piece::data() {
absl::Span<NativeT> LiteralBase::Piece::data() {
CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
CHECK_EQ(subshape().element_type(),
primitive_util::NativeToPrimitiveType<NativeT>())
@ -895,20 +885,19 @@ tensorflow::gtl::MutableArraySlice<NativeT> LiteralBase::Piece::data() {
<< PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
<< " type, but literal element type is "
<< PrimitiveType_Name(subshape().element_type());
return tensorflow::gtl::MutableArraySlice<NativeT>(
reinterpret_cast<NativeT*>(buffer()), element_count());
return absl::Span<NativeT>(reinterpret_cast<NativeT*>(buffer()),
element_count());
}

template <typename NativeT>
NativeT LiteralBase::Piece::Get(
tensorflow::gtl::ArraySlice<int64> multi_index) const {
NativeT LiteralBase::Piece::Get(absl::Span<const int64> multi_index) const {
CHECK(LayoutUtil::IsDenseArray(subshape()));
return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
subshape(), multi_index)];
}

template <typename NativeT>
void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
void LiteralBase::Piece::Set(absl::Span<const int64> multi_index,
NativeT value) {
CHECK(LayoutUtil::IsDenseArray(subshape()));
data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
@ -916,39 +905,37 @@ void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
}

template <typename NativeT>
tensorflow::gtl::ArraySlice<NativeT> LiteralBase::data(
absl::Span<const NativeT> LiteralBase::data(
const ShapeIndex& shape_index) const {
return piece(shape_index).data<NativeT>();
}

template <typename NativeT>
tensorflow::gtl::MutableArraySlice<NativeT> MutableLiteralBase::data(
const ShapeIndex& shape_index) {
absl::Span<NativeT> MutableLiteralBase::data(const ShapeIndex& shape_index) {
return piece(shape_index).data<NativeT>();
}

template <typename NativeT>
inline NativeT LiteralBase::Get(tensorflow::gtl::ArraySlice<int64> multi_index,
inline NativeT LiteralBase::Get(absl::Span<const int64> multi_index,
const ShapeIndex& shape_index) const {
return piece(shape_index).Get<NativeT>(multi_index);
}

template <typename NativeT>
inline NativeT LiteralBase::Get(
tensorflow::gtl::ArraySlice<int64> multi_index) const {
inline NativeT LiteralBase::Get(absl::Span<const int64> multi_index) const {
return root_piece().Get<NativeT>(multi_index);
}

template <typename NativeT>
inline void MutableLiteralBase::Set(
tensorflow::gtl::ArraySlice<int64> multi_index,
const ShapeIndex& shape_index, NativeT value) {
inline void MutableLiteralBase::Set(absl::Span<const int64> multi_index,
const ShapeIndex& shape_index,
NativeT value) {
return piece(shape_index).Set<NativeT>(multi_index, value);
}

template <typename NativeT>
inline void MutableLiteralBase::Set(
tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value) {
inline void MutableLiteralBase::Set(absl::Span<const int64> multi_index,
NativeT value) {
return root_piece().Set<NativeT>(multi_index, value);
}

@ -967,7 +954,7 @@ NativeT LiteralBase::GetSparseElement(int64 sparse_element_number,

template <typename NativeT>
void MutableLiteralBase::AppendSparseElement(
tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value,
absl::Span<const int64> multi_index, NativeT value,
const ShapeIndex& shape_index) {
Piece& p = piece(shape_index);
const Shape& subshape = p.subshape();
@ -983,8 +970,7 @@ void MutableLiteralBase::AppendSparseElement(

template <typename NativeT>
void LiteralBase::EachCell(
std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
NativeT value)>
std::function<void(absl::Span<const int64> indices, NativeT value)>
per_cell) const {
if (ShapeUtil::IsZeroElementArray(shape())) {
return;
@ -996,8 +982,7 @@ void LiteralBase::EachCell(
}

template <typename NativeT>
inline void MutableLiteralBase::PopulateR1(
tensorflow::gtl::ArraySlice<NativeT> values) {
inline void MutableLiteralBase::PopulateR1(absl::Span<const NativeT> values) {
CHECK(ShapeUtil::IsArray(shape()));
CHECK_EQ(ShapeUtil::Rank(shape()), 1);
CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
@ -1042,8 +1027,9 @@ void MutableLiteralBase::PopulateFromArray(const Array<NativeT>& values) {
for (int dim = 0; dim < values.num_dimensions(); ++dim) {
CHECK_EQ(values.dim(dim), shape().dimensions(dim));
}
values.Each([this](tensorflow::gtl::ArraySlice<int64> indices,
NativeT value) { this->Set(indices, value); });
values.Each([this](absl::Span<const int64> indices, NativeT value) {
this->Set(indices, value);
});
}

template <typename NativeT>
@ -1062,8 +1048,8 @@ void MutableLiteralBase::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
}

template <typename NativeT>
void MutableLiteralBase::PopulateSparse(
SparseIndexArray indices, tensorflow::gtl::ArraySlice<NativeT> values,
void MutableLiteralBase::PopulateSparse(SparseIndexArray indices,
absl::Span<const NativeT> values,
bool sort) {
CHECK(LayoutUtil::IsSparseArray(shape()));
int rank = ShapeUtil::Rank(shape());
@ -1074,7 +1060,7 @@ void MutableLiteralBase::PopulateSparse(
CHECK_LE(num_elements, max_elements);
CHECK_EQ(num_elements, indices.index_count());
auto root_data = root_piece().data<NativeT>();
// Piece::data() returns an ArraySlice of size equal to the number of indices
// Piece::data() returns a Span of size equal to the number of indices
// in the SparseIndexArray. So there is no need to adjust the size of the data
// here. It is enough to just copy the incoming values into the data buffer.
std::copy(values.begin(), values.end(), root_data.begin());
@ -1094,14 +1080,14 @@ Status MutableLiteralBase::PopulateInternal(const FnType& generator,
TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
TF_RET_CHECK(this_shape.element_type() ==
primitive_util::NativeToPrimitiveType<NativeT>());
tensorflow::gtl::MutableArraySlice<NativeT> literal_data = data<NativeT>();
absl::Span<NativeT> literal_data = data<NativeT>();
if (rank > 0) {
StrideConfig stride_config(this_shape, this_shape,
AsInt64Slice(this_shape.dimensions()));
int64 minor_dimension_size =
ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);

auto init_function = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
auto init_function = [&](absl::Span<const int64> indexes) {
DimensionVector minor_scan_indexes(rank, 0);
const int64 index =
IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
@ -1119,7 +1105,7 @@ Status MutableLiteralBase::PopulateInternal(const FnType& generator,
ShapeUtil::ForEachIndex(
this_shape, stride_config.base, stride_config.dimensions,
stride_config.step,
[&init_function](tensorflow::gtl::ArraySlice<int64> indexes) {
[&init_function](absl::Span<const int64> indexes) {
init_function(indexes);
return true;
});
@ -1165,7 +1151,7 @@ std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
}

DimensionVector output_indices(bounds.size(), 0);
tensorflow::gtl::ArraySlice<int64> input_indices = output_indices;
absl::Span<const int64> input_indices = output_indices;
input_indices.remove_prefix(1);

bool done = false;
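A sketch of the per-cell callback shape after the rename (hypothetical array-shaped f32 `literal`):

// Sketch only: sums every element; the index span is read-only.
float total = 0.0f;
literal.EachCell<float>(
    [&total](absl::Span<const int64> indices, float value) { total += value; });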
@ -38,8 +38,8 @@ namespace {
// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
// -- on miscompare, a nice error message is given in the AssertionFailure.
template <typename FloatT, typename UnsignedT>
Status CompareFloatsBitwiseEqual(
FloatT lhs, FloatT rhs, tensorflow::gtl::ArraySlice<int64> multi_index) {
Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs,
absl::Span<const int64> multi_index) {
auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs);
auto urhs = tensorflow::bit_cast<UnsignedT>(rhs);
auto lhs_double = static_cast<double>(lhs);
@ -60,7 +60,7 @@ Status CompareFloatsBitwiseEqual(
// default gunit implementation).
template <typename NativeT>
Status CompareEqual(NativeT lhs, NativeT rhs,
tensorflow::gtl::ArraySlice<int64> multi_index) {
absl::Span<const int64> multi_index) {
if (lhs == rhs) {
return Status::OK();
}
@ -74,28 +74,27 @@ Status CompareEqual(NativeT lhs, NativeT rhs,
// comparison is requested.
template <>
Status CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs,
tensorflow::gtl::ArraySlice<int64> multi_index) {
absl::Span<const int64> multi_index) {
return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs, multi_index);
}
template <>
Status CompareEqual<Eigen::half>(
Eigen::half lhs, Eigen::half rhs,
tensorflow::gtl::ArraySlice<int64> multi_index) {
Status CompareEqual<Eigen::half>(Eigen::half lhs, Eigen::half rhs,
absl::Span<const int64> multi_index) {
return CompareFloatsBitwiseEqual<Eigen::half, uint16>(lhs, rhs, multi_index);
}
template <>
Status CompareEqual<float>(float lhs, float rhs,
tensorflow::gtl::ArraySlice<int64> multi_index) {
absl::Span<const int64> multi_index) {
return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs, multi_index);
}
template <>
Status CompareEqual<double>(double lhs, double rhs,
tensorflow::gtl::ArraySlice<int64> multi_index) {
absl::Span<const int64> multi_index) {
return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs, multi_index);
}
template <>
Status CompareEqual<complex64>(complex64 lhs, complex64 rhs,
tensorflow::gtl::ArraySlice<int64> multi_index) {
absl::Span<const int64> multi_index) {
auto res = CompareEqual<float>(lhs.real(), rhs.real(), multi_index);
if (!res.ok()) {
return res;
@ -108,8 +107,7 @@ Status CompareEqual<complex64>(complex64 lhs, complex64 rhs,
// elements are equal.
template <typename NativeT>
Status Equal(LiteralSlice expected, LiteralSlice actual,
tensorflow::gtl::MutableArraySlice<int64> multi_index,
int64 dimension) {
absl::Span<int64> multi_index, int64 dimension) {
if (dimension == expected.shape().dimensions_size()) {
NativeT expected_value = expected.Get<NativeT>(multi_index);
NativeT actual_value = actual.Get<NativeT>(multi_index);
@ -305,8 +303,7 @@ class NearComparator {
}

// Insert the given error into the given error bucket vector.
void UpdateErrorBucket(
float error, tensorflow::gtl::MutableArraySlice<int64> error_buckets) {
void UpdateErrorBucket(float error, absl::Span<int64> error_buckets) {
CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size());
for (int i = 0; i < error_buckets.size(); ++i) {
if (error >= kErrorBucketBounds[i]) {
@ -410,10 +407,8 @@ class NearComparator {
// Fast path optimization for the case where layouts match.
if (LayoutUtil::Equal(actual_.shape().layout(),
expected_.shape().layout())) {
tensorflow::gtl::ArraySlice<const NativeT> expected_data =
expected_.data<NativeT>();
tensorflow::gtl::ArraySlice<const NativeT> actual_data =
actual_.data<NativeT>();
absl::Span<const NativeT> expected_data = expected_.data<NativeT>();
absl::Span<const NativeT> actual_data = actual_.data<NativeT>();
const int64 len = expected_data.size();
for (int64 i = 0; i < len; ++i) {
CompareValues(expected_data[i], actual_data[i], i);
@ -488,7 +483,7 @@ class NearComparator {
}

auto print_accum_buckets = [&](const string& header, int64 total,
tensorflow::gtl::ArraySlice<int64> buckets) {
absl::Span<const int64> buckets) {
StrAppend(&out, header, ":\n");
StrAppendFormat(&out, " < %-6g : %7d (%s)\n", kErrorBucketBounds[0],
total - buckets[0],
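The bit-casting comparison above is easy to demonstrate in isolation; a sketch in plain C++ (not the XLA helpers):

// Sketch only: identical NaN bit patterns compare equal after bit-casting,
// even though lhs == rhs is false for NaN under IEEE rules.
float lhs = std::numeric_limits<float>::quiet_NaN();
float rhs = lhs;
CHECK(absl::bit_cast<uint32_t>(lhs) == absl::bit_cast<uint32_t>(rhs));
CHECK(!(lhs == rhs));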
@ -36,7 +36,6 @@ limitations under the License.
namespace xla {
namespace {

using tensorflow::gtl::ArraySlice;
using ::testing::ElementsAre;
using ::testing::HasSubstr;

@ -222,9 +221,9 @@ TEST_F(LiteralUtilTest, CreateSparse) {
std::vector<int64> expected_values = {8, 9, 7, 10};

EXPECT_EQ(literal->sparse_indices()->data(),
ArraySlice<int64>(expected_indices.data(),
absl::Span<const int64>(expected_indices.data(),
expected_indices.num_elements()));
EXPECT_EQ(literal->data<int64>(), ArraySlice<int64>(expected_values));
EXPECT_EQ(literal->data<int64>(), absl::Span<const int64>(expected_values));
}

TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
@ -296,7 +295,7 @@ TEST_F(LiteralUtilTest, EachCellR2F32) {
// clang-format on
std::vector<std::tuple<int64, int64, string>> seen;
literal->EachCellAsString(
[&seen](ArraySlice<int64> indices, const string& value) {
[&seen](absl::Span<const int64> indices, const string& value) {
seen.emplace_back(indices[0], indices[1], value);
});

@ -649,7 +648,7 @@ TEST_F(LiteralUtilTest, TransposeR4) {
// clang-format on
auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1});

reshape->EachCell<float>([&](ArraySlice<int64> indices, float value) {
reshape->EachCell<float>([&](absl::Span<const int64> indices, float value) {
EXPECT_EQ(value, original->Get<float>(
{indices[2], indices[3], indices[0], indices[1]}));
});
@ -889,7 +888,7 @@ TEST_F(LiteralUtilTest, CopySliceFrom) {
const int64 zero_base[] = {0, 0, 0, 0};
const int64 step[] = {1, 1, 1, 1};
uint32 seqnr = 0;
auto init_proc = [&](ArraySlice<int64> indexes) {
auto init_proc = [&](absl::Span<const int64> indexes) {
source->Set(indexes, ++seqnr);
return true;
};
@ -905,7 +904,7 @@ TEST_F(LiteralUtilTest, CopySliceFrom) {
std::vector<int64> source_indexes(TF_ARRAYSIZE(dimensions), 0);
std::vector<int64> blank_indexes(TF_ARRAYSIZE(dimensions), 0);
bool matched = true;
auto check_proc = [&](ArraySlice<int64> indexes) {
auto check_proc = [&](absl::Span<const int64> indexes) {
std::copy(indexes.begin(), indexes.end(), source_indexes.begin());
std::transform(source_indexes.begin(), source_indexes.end(), src_base,
source_indexes.begin(), std::plus<int64>());
@ -1093,7 +1092,7 @@ TEST_F(LiteralUtilTest, Populate) {
primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
data.layout);
auto literal = absl::make_unique<Literal>(shape);
auto generator = [&](ArraySlice<int64> indexes) -> uint32 {
auto generator = [&](absl::Span<const int64> indexes) -> uint32 {
// Offsets from linear index just to avoid R0 literals to be initialized
// with zero.
return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(),
@ -1105,7 +1104,7 @@ TEST_F(LiteralUtilTest, Populate) {
std::vector<int64> zero_base(data.dimensions.size(), 0);
std::vector<int64> step(data.dimensions.size(), 1);
bool matched = true;
auto check_function = [&](ArraySlice<int64> indexes) {
auto check_function = [&](absl::Span<const int64> indexes) {
auto value = literal->Get<uint32>(indexes);
matched = matched && (value == generator(indexes));
return matched;
@ -1135,7 +1134,7 @@ TEST_F(LiteralUtilTest, PopulateParallel) {
primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
data.layout);
auto literal = absl::make_unique<Literal>(shape);
auto generator = [&](ArraySlice<int64> indexes) -> uint32 {
auto generator = [&](absl::Span<const int64> indexes) -> uint32 {
// Offsets from linear index just to avoid R0 literals to be initialized
// with zero.
return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(),
@ -1147,7 +1146,7 @@ TEST_F(LiteralUtilTest, PopulateParallel) {
std::vector<int64> zero_base(data.dimensions.size(), 0);
std::vector<int64> step(data.dimensions.size(), 1);
bool matched = true;
auto check_function = [&](ArraySlice<int64> indexes) {
auto check_function = [&](absl::Span<const int64> indexes) {
auto value = literal->Get<uint32>(indexes);
matched = matched && (value == generator(indexes));
return matched;
@ -84,8 +84,7 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
} // namespace

/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromDimensions(
PrimitiveType primitive_type,
tensorflow::gtl::ArraySlice<int64> dimensions) {
PrimitiveType primitive_type, absl::Span<const int64> dimensions) {
return Literal::CreateFromShape(
ShapeUtil::MakeShape(primitive_type, dimensions));
}
@ -301,9 +300,8 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
}

/* static */ std::unique_ptr<Literal> LiteralUtil::ReshapeSlice(
tensorflow::gtl::ArraySlice<int64> new_dimensions,
tensorflow::gtl::ArraySlice<int64> minor_to_major,
const LiteralSlice& literal) {
absl::Span<const int64> new_dimensions,
absl::Span<const int64> minor_to_major, const LiteralSlice& literal) {
int64 new_num_elements = 1;
for (int64 i = 0; i < new_dimensions.size(); ++i) {
new_num_elements *= new_dimensions[i];
@ -430,7 +428,7 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
}

/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTuple(
tensorflow::gtl::ArraySlice<const Literal*> elements) {
absl::Span<const Literal* const> elements) {
std::vector<Shape> element_shapes;
for (const auto* element : elements) {
element_shapes.push_back(element->shape());
@ -444,7 +442,7 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
}

/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTupleFromSlices(
tensorflow::gtl::ArraySlice<LiteralSlice> elements) {
absl::Span<const LiteralSlice> elements) {
std::vector<Shape> element_shapes;
for (const auto& element : elements) {
element_shapes.push_back(element.shape());
@ -474,7 +472,7 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
}

/* static */ string LiteralUtil::MultiIndexAsString(
tensorflow::gtl::ArraySlice<int64> multi_index) {
absl::Span<const int64> multi_index) {
return StrCat("{", absl::StrJoin(multi_index, ","), "}");
}
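A sketch of the tuple builders with the new pointer spans (hypothetical element literals):

// Sketch only: the braced pointer list binds to
// absl::Span<const Literal* const>.
auto scalar = LiteralUtil::CreateR0<float>(1.0f);
auto vec = LiteralUtil::CreateR1<float>({2.0f, 3.0f});
auto tuple = LiteralUtil::MakeTuple({scalar.get(), vec.get()});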

@ -71,8 +71,7 @@ class LiteralUtil {
template <typename NativeT>
static std::unique_ptr<Literal> CreateR0(NativeT value);
template <typename NativeT>
static std::unique_ptr<Literal> CreateR1(
tensorflow::gtl::ArraySlice<NativeT> values);
static std::unique_ptr<Literal> CreateR1(absl::Span<const NativeT> values);
static std::unique_ptr<Literal> CreateR1(
const tensorflow::core::Bitmap& values);
template <typename NativeT>
@ -141,8 +140,8 @@ class LiteralUtil {
//
template <typename NativeT>
static std::unique_ptr<Literal> CreateSparse(
tensorflow::gtl::ArraySlice<int64> dimensions, SparseIndexArray indices,
tensorflow::gtl::ArraySlice<NativeT> values, bool sort = true);
absl::Span<const int64> dimensions, SparseIndexArray indices,
absl::Span<const NativeT> values, bool sort = true);

// Creates a scalar literal value zero of the given primitive type.
static Literal Zero(PrimitiveType primitive_type);
@ -157,7 +156,7 @@ class LiteralUtil {
// Creates a literal of the given shape where each element is `value`.
template <typename NativeT>
static std::unique_ptr<Literal> CreateFullWithDescendingLayout(
tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value);
absl::Span<const int64> dimensions, NativeT value);

// Creates a new literal from an Array type. The variants not ending with
// WithLayout use the default XLA layout for the literal's linear
@ -215,10 +214,10 @@ class LiteralUtil {
// Returns a tuple literal composed of given literals. Data is copied from the
// given elements into the returned literal.
static std::unique_ptr<Literal> MakeTuple(
tensorflow::gtl::ArraySlice<const Literal*> elements);
absl::Span<const Literal* const> elements);

static std::unique_ptr<Literal> MakeTupleFromSlices(
tensorflow::gtl::ArraySlice<LiteralSlice> elements);
absl::Span<const LiteralSlice> elements);

// As above, but intended to be invoked with move semantics; i.e.
//
@ -259,8 +258,7 @@ class LiteralUtil {
// The content of the literal values is the default value of the primitive
// type of literal itself (0 for numeric types, and false for predicates).
static std::unique_ptr<Literal> CreateFromDimensions(
PrimitiveType primitive_type,
tensorflow::gtl::ArraySlice<int64> dimensions);
PrimitiveType primitive_type, absl::Span<const int64> dimensions);

// If the given literal's data type is bfloat16, converts it to a float
// literal; otherwise, returns a copy of it. If the literal is a tuple,
@ -279,9 +277,8 @@ class LiteralUtil {
// buffer of the input literal is assumed to have the given minor_to_major
// layout order.
static std::unique_ptr<Literal> ReshapeSlice(
tensorflow::gtl::ArraySlice<int64> new_dimensions,
tensorflow::gtl::ArraySlice<int64> minor_to_major,
const LiteralSlice& literal);
absl::Span<const int64> new_dimensions,
absl::Span<const int64> minor_to_major, const LiteralSlice& literal);

// Creates a literal with the supplied shape, and uses the provided value
// generator to populate the literal's values.
@ -291,7 +288,7 @@ class LiteralUtil {
typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
const Shape& shape,
const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator);
const std::function<T(absl::Span<const int64>)>& generator);

// Creates a literal with the supplied shape, and initializes the literal
// values using a normal distribution with given mean and stddev standard
@ -319,8 +316,7 @@ class LiteralUtil {
// Returns a multi-dimensional index as a string. For example: '{7, 8}' will
// be returned for a 2-dimensional index with dimension 0 index equal to 7,
// dimension 1 equal to 8.
static string MultiIndexAsString(
tensorflow::gtl::ArraySlice<int64> multi_index);
static string MultiIndexAsString(absl::Span<const int64> multi_index);
};

std::ostream& operator<<(std::ostream& out, const Literal& literal);
@ -335,7 +331,7 @@ template <typename NativeT>

template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1(
tensorflow::gtl::ArraySlice<NativeT> values) {
absl::Span<const NativeT> values) {
auto literal = absl::make_unique<Literal>(
ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
{static_cast<int64>(values.size())}));
@ -427,8 +423,8 @@ template <typename NativeT>

template <typename NativeT>
/* static */ std::unique_ptr<Literal> LiteralUtil::CreateSparse(
tensorflow::gtl::ArraySlice<int64> dimensions, SparseIndexArray indices,
tensorflow::gtl::ArraySlice<NativeT> values, bool sort) {
absl::Span<const int64> dimensions, SparseIndexArray indices,
absl::Span<const NativeT> values, bool sort) {
int64 num_elements = values.size();
int64 rank = dimensions.size();
CHECK_EQ(num_elements, indices.index_count());
@ -570,8 +566,8 @@ template <typename NativeT>

template <typename NativeT>
/* static */ std::unique_ptr<Literal>
LiteralUtil::CreateFullWithDescendingLayout(
tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value) {
LiteralUtil::CreateFullWithDescendingLayout(absl::Span<const int64> dimensions,
NativeT value) {
auto literal =
absl::make_unique<Literal>(ShapeUtil::MakeShapeWithDescendingLayout(
primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
@ -583,14 +579,12 @@ template <PrimitiveType type, typename T>
/* static */ StatusOr<std::unique_ptr<Literal>>
LiteralUtil::CreateRandomLiteral(
const Shape& shape,
const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator) {
const std::function<T(absl::Span<const int64>)>& generator) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
TF_RET_CHECK(shape.element_type() == type);
auto literal = absl::make_unique<Literal>(shape);
TF_RETURN_IF_ERROR(literal.get()->Populate<NativeT>(
[&](tensorflow::gtl::ArraySlice<int64> indexes) {
return generator(indexes);
}));
[&](absl::Span<const int64> indexes) { return generator(indexes); }));
return std::move(literal);
}

@ -601,9 +595,8 @@ LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
std::normal_distribution<NativeT> generator(mean, stddev);
return CreateRandomLiteral<type, NativeT>(
shape, [&](tensorflow::gtl::ArraySlice<int64> /*indexes*/) {
return generator(*engine);
});
shape,
[&](absl::Span<const int64> /*indexes*/) { return generator(*engine); });
}

template <PrimitiveType type, typename T>
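A sketch of the random-literal helper declared above (hypothetical engine and shape; the mean/stddev overload forwards to the generator form):

// Sketch only: fills an f32[2x3] literal from a seeded normal distribution.
std::minstd_rand0 engine(42);
auto literal_or = LiteralUtil::CreateRandomLiteral<F32>(
    ShapeUtil::MakeShape(F32, {2, 3}), &engine, /*mean=*/0.0f, /*stddev=*/1.0f);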
@ -61,7 +61,7 @@ StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
result->PopulateWithValue(std::numeric_limits<float>::quiet_NaN());

int64 elements = ShapeUtil::ElementsIn(shape);
tensorflow::gtl::ArraySlice<float> field = result->data<float>();
absl::Span<const float> field = result->data<float>();
char* data = tensorflow::bit_cast<char*>(field.data());
uint64 bytes = elements * sizeof(float);
tensorflow::StringPiece sp; // non-absl OK
@ -259,7 +259,7 @@ StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
}

LocalShapedBuffer* CompiledLocalComputation::ExecuteWithShapedBuffers(
tensorflow::gtl::ArraySlice<LocalShapedBuffer*> argument_handles) {
absl::Span<LocalShapedBuffer* const> argument_handles) {
LocalClient* client = GetOrCreateLocalClient();

std::vector<const ShapedBuffer*> argument_buffers;
@ -369,8 +369,7 @@ LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) {
}

LocalOp LocalComputationBuilder::Broadcast(
const LocalOp& operand,
tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
const LocalOp& operand, absl::Span<const int64> broadcast_sizes) {
return xla::Broadcast(operand.op(), broadcast_sizes);
}

@ -380,14 +379,14 @@ LocalOp LocalComputationBuilder::Pad(const LocalOp& operand,
return xla::Pad(operand.op(), padding_value.op(), padding_config);
}

LocalOp LocalComputationBuilder::Reshape(
const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<int64> new_sizes) {
LocalOp LocalComputationBuilder::Reshape(const LocalOp& operand,
absl::Span<const int64> dimensions,
absl::Span<const int64> new_sizes) {
return xla::Reshape(operand.op(), dimensions, new_sizes);
}

LocalOp LocalComputationBuilder::Collapse(
const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
LocalOp LocalComputationBuilder::Collapse(const LocalOp& operand,
absl::Span<const int64> dimensions) {
return xla::Collapse(operand.op(), dimensions);
}

@ -395,10 +394,10 @@ LocalOp LocalComputationBuilder::CrossReplicaSum(const LocalOp& operand) {
return xla::CrossReplicaSum(operand.op());
}

LocalOp LocalComputationBuilder::Slice(
const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices,
tensorflow::gtl::ArraySlice<int64> strides) {
LocalOp LocalComputationBuilder::Slice(const LocalOp& operand,
absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices,
absl::Span<const int64> strides) {
return xla::Slice(operand.op(), start_indices, limit_indices, strides);
}

@ -411,7 +410,7 @@ LocalOp LocalComputationBuilder::SliceInDim(const LocalOp& operand,

LocalOp LocalComputationBuilder::DynamicSlice(
const LocalOp& operand, const LocalOp& start_indices,
tensorflow::gtl::ArraySlice<int64> slice_sizes) {
absl::Span<const int64> slice_sizes) {
return xla::DynamicSlice(operand.op(), start_indices.op(), slice_sizes);
}

@ -421,8 +420,8 @@ LocalOp LocalComputationBuilder::DynamicUpdateSlice(
return xla::DynamicUpdateSlice(operand.op(), update.op(), start_indices.op());
}

LocalOp LocalComputationBuilder::ConcatInDim(
tensorflow::gtl::ArraySlice<LocalOp> operands, int64 dimension) {
LocalOp LocalComputationBuilder::ConcatInDim(absl::Span<const LocalOp> operands,
int64 dimension) {
std::vector<XlaOp> xla_ops;
xla_ops.reserve(operands.size());
for (const auto& op : operands) {
@ -433,18 +432,16 @@ LocalOp LocalComputationBuilder::ConcatInDim(

LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding(
const LocalOp& operand, const LocalComputation& select,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
const LocalOp& source, const LocalOp& init_value,
const LocalComputation& scatter) {
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding, const LocalOp& source,
const LocalOp& init_value, const LocalComputation& scatter) {
return xla::SelectAndScatterWithGeneralPadding(
operand.op(), select.computation(), window_dimensions, window_strides,
padding, source.op(), init_value.op(), scatter.computation());
}

LocalOp LocalComputationBuilder::Tuple(
tensorflow::gtl::ArraySlice<LocalOp> elements) {
LocalOp LocalComputationBuilder::Tuple(absl::Span<const LocalOp> elements) {
std::vector<XlaOp> xla_ops;
xla_ops.reserve(elements.size());
for (const auto& op : elements) {
@ -471,10 +468,9 @@ LocalOp LocalComputationBuilder::DotGeneral(

LocalOp LocalComputationBuilder::ConvGeneralDilated(
const LocalOp& lhs, const LocalOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers) {
return xla::ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, padding,
lhs_dilation, rhs_dilation, dimension_numbers);
@ -490,9 +486,8 @@ LocalOp LocalComputationBuilder::BitcastConvertType(
return xla::BitcastConvertType(operand.op(), new_element_type);
}

LocalOp LocalComputationBuilder::Call(
const LocalComputation& local_computation,
tensorflow::gtl::ArraySlice<LocalOp> operands) {
LocalOp LocalComputationBuilder::Call(const LocalComputation& local_computation,
absl::Span<const LocalOp> operands) {
std::vector<XlaOp> xla_ops;
xla_ops.reserve(operands.size());
for (const auto& op : operands) {
@ -502,19 +497,18 @@ LocalOp LocalComputationBuilder::Call(
}

LocalOp LocalComputationBuilder::Transpose(
const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> permutation) {
const LocalOp& operand, absl::Span<const int64> permutation) {
return xla::Transpose(operand.op(), permutation);
}

LocalOp LocalComputationBuilder::Rev(
const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
LocalOp LocalComputationBuilder::Rev(const LocalOp& operand,
absl::Span<const int64> dimensions) {
return xla::Rev(operand.op(), dimensions);
}

LocalOp LocalComputationBuilder::Map(
tensorflow::gtl::ArraySlice<LocalOp> operands,
LocalOp LocalComputationBuilder::Map(absl::Span<const LocalOp> operands,
const LocalComputation& local_computation,
tensorflow::gtl::ArraySlice<int64> dimensions) {
absl::Span<const int64> dimensions) {
std::vector<XlaOp> xla_ops;
xla_ops.reserve(operands.size());
for (const auto& op : operands) {
@ -528,7 +522,7 @@ LocalOp LocalComputationBuilder::Map(
LocalOp LocalComputationBuilder::Reduce(
const LocalOp& operand, const LocalOp& init_value,
const LocalComputation& local_computation,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
absl::Span<const int64> dimensions_to_reduce) {
return xla::Reduce(operand.op(), init_value.op(),
local_computation.computation(), dimensions_to_reduce);
}
@ -536,9 +530,9 @@ LocalOp LocalComputationBuilder::Reduce(
LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding(
const LocalOp& operand, const LocalOp& init_value,
const LocalComputation& local_computation,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding) {
return xla::ReduceWindowWithGeneralPadding(
operand.op(), init_value.op(), local_computation.computation(),
window_dimensions, window_strides, padding);
@ -602,7 +596,7 @@ StatusOr<LocalComputation*> LocalComputationBuilder::BuildConstantSubGraph(
#define _FORWARD_BINOP(method_name) \
_FORWARD(method_name, LocalOp, \
(const LocalOp& lhs, const LocalOp& rhs, \
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions), \
absl::Span<const int64> broadcast_dimensions), \
(lhs.op(), rhs.op(), broadcast_dimensions))

#define _FORWARD_TRIOP(method_name) \
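Callers of these wrappers need no changes, since std::vector<int64> and braced lists both convert to absl::Span<const int64>; a sketch (hypothetical `builder` and `operand`):

// Sketch only: both argument styles bind to the new Span parameters.
std::vector<int64> sizes = {2, 3};
LocalOp broadcasted = builder.Broadcast(operand, sizes);
LocalOp sliced = builder.Slice(operand, {0}, {2}, {1});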
@ -122,7 +122,7 @@ class CompiledLocalComputation {
const std::vector<absl::optional<Shape> >& shapes_with_layout);

LocalShapedBuffer* ExecuteWithShapedBuffers(
tensorflow::gtl::ArraySlice<LocalShapedBuffer*> argument_handles);
absl::Span<LocalShapedBuffer* const> argument_handles);

private:
std::unique_ptr<LocalExecutable> executable_;
@ -199,46 +199,41 @@ class LocalComputationBuilder {
LocalOp ConstantLiteral(const Literal& literal);

LocalOp Broadcast(const LocalOp& operand,
tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
absl::Span<const int64> broadcast_sizes);

LocalOp Pad(const LocalOp& operand, const LocalOp& padding_value,
const PaddingConfig& padding_config);

LocalOp Reshape(const LocalOp& operand,
tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<int64> new_sizes);
LocalOp Reshape(const LocalOp& operand, absl::Span<const int64> dimensions,
absl::Span<const int64> new_sizes);

LocalOp Collapse(const LocalOp& operand,
tensorflow::gtl::ArraySlice<int64> dimensions);
LocalOp Collapse(const LocalOp& operand, absl::Span<const int64> dimensions);

LocalOp CrossReplicaSum(const LocalOp& operand);

LocalOp Slice(const LocalOp& operand,
tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices,
tensorflow::gtl::ArraySlice<int64> strides);
LocalOp Slice(const LocalOp& operand, absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices,
absl::Span<const int64> strides);

LocalOp SliceInDim(const LocalOp& operand, int64 start_index,
int64 limit_index, int64 stride, int64 dimno);

LocalOp DynamicSlice(const LocalOp& operand, const LocalOp& start_indices,
tensorflow::gtl::ArraySlice<int64> slice_sizes);
absl::Span<const int64> slice_sizes);

LocalOp DynamicUpdateSlice(const LocalOp& operand, const LocalOp& update,
const LocalOp& start_indices);

LocalOp ConcatInDim(tensorflow::gtl::ArraySlice<LocalOp> operands,
int64 dimension);
LocalOp ConcatInDim(absl::Span<const LocalOp> operands, int64 dimension);

LocalOp SelectAndScatterWithGeneralPadding(
const LocalOp& operand, const LocalComputation& select,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64> > padding,
const LocalOp& source, const LocalOp& init_value,
const LocalComputation& scatter);
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64> > padding, const LocalOp& source,
const LocalOp& init_value, const LocalComputation& scatter);

LocalOp Tuple(tensorflow::gtl::ArraySlice<LocalOp> elements);
LocalOp Tuple(absl::Span<const LocalOp> elements);

LocalOp GetTupleElement(const LocalOp& tuple_data, int64 index);

@ -249,10 +244,10 @@ class LocalComputationBuilder {

LocalOp ConvGeneralDilated(
const LocalOp& lhs, const LocalOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64> > padding,
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64> > padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers);

LocalOp ConvertElementType(const LocalOp& operand,
@ -262,28 +257,27 @@ class LocalComputationBuilder {
PrimitiveType new_element_type);

LocalOp Call(const LocalComputation& local_computation,
tensorflow::gtl::ArraySlice<LocalOp> operands);
absl::Span<const LocalOp> operands);

LocalOp Transpose(const LocalOp& operand,
tensorflow::gtl::ArraySlice<int64> permutation);
absl::Span<const int64> permutation);

LocalOp Rev(const LocalOp& operand,
tensorflow::gtl::ArraySlice<int64> dimensions);
LocalOp Rev(const LocalOp& operand, absl::Span<const int64> dimensions);

LocalOp Map(tensorflow::gtl::ArraySlice<LocalOp> operands,
LocalOp Map(absl::Span<const LocalOp> operands,
const LocalComputation& local_computation,
tensorflow::gtl::ArraySlice<int64> dimensions);
absl::Span<const int64> dimensions);

LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value,
const LocalComputation& local_computation,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
absl::Span<const int64> dimensions_to_reduce);

LocalOp ReduceWindowWithGeneralPadding(
const LocalOp& operand, const LocalOp& init_value,
const LocalComputation& local_computation,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64> > padding);
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64> > padding);

LocalOp RngNormal(const LocalOp& mu, const LocalOp& sigma,
const Shape& shape);
@ -316,7 +310,7 @@ class LocalComputationBuilder {
#define _FORWARD_BINOP(method_name) \
_FORWARD(method_name, LocalOp, \
(const LocalOp& lhs, const LocalOp& rhs, \
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions))
absl::Span<const int64> broadcast_dimensions))

#define _FORWARD_TRIOP(method_name) \
_FORWARD(method_name, LocalOp, \
@ -22,15 +22,15 @@ limitations under the License.
//
// C++                                  Python
// -------------------------------------+---------------------------------------
// ArraySlice<int64>                    <- sequence of int
// ArraySlice<LocalOp>                  <- sequence of LocalOp
// Span<int64>                          <- sequence of int
// Span<LocalOp>                        <- sequence of LocalOp
// Literal                              <-> (nested tuple of) numpy ndarray
// std::vector<Literal>                 <- sequence of (nested tuple of) ndarray
// Shape                                -> pair holding (dtype, dimensions)
//                                      <- object duck-typed as xla_client.Shape
// std::vector<Shape>                   <- sequence of xla_client.Shape objects
// PrimitiveType                        <- int
// ArraySlice<pair<int64, int64>>       <- sequence of int pairs
// Span<pair<int64, int64>>             <- sequence of int pairs
// PaddingConfig proto                  <- corresponding Python proto
// ConvolutionDimensionNumbers proto    <- corresponding Python proto
// DotDimensionNumbers proto            <- corresponding Python proto
@ -267,9 +267,9 @@ tensorflow::ImportNumpy();
$result = Py_None;
}

// ArraySlice<int64>
// Span<int64>

%typemap(in) tensorflow::gtl::ArraySlice<int64>
%typemap(in) absl::Span<const int64>
(std::vector<int64> temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
@ -299,9 +299,9 @@ tensorflow::ImportNumpy();
$1 = temps;
}

// ArraySlice<LocalOp>
// Span<LocalOp>

%typemap(in) tensorflow::gtl::ArraySlice<xla::swig::LocalOp>(
%typemap(in) absl::Span<const xla::swig::LocalOp>(
std::vector<LocalOp> temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
@ -323,7 +323,7 @@ tensorflow::ImportNumpy();

// LocalShapedBuffer*

%typemap(in) tensorflow::gtl::ArraySlice<xla::swig::LocalShapedBuffer*>
%typemap(in) absl::Span<xla::swig::LocalShapedBuffer* const>
(std::vector<LocalShapedBuffer*> temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
@ -496,9 +496,9 @@ tensorflow::ImportNumpy();
$1 = static_cast<PrimitiveType>(value);
}

// ArraySlice<pair<int64, int64>>
// Span<pair<int64, int64>>

%typemap(in) tensorflow::gtl::ArraySlice<std::pair<int64, int64> >
%typemap(in) absl::Span<const std::pair<int64, int64> >
(std::vector<std::pair<int64, int64> > temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
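Every typemap above follows the same pattern: materialize a std::vector from the Python sequence, then assign it to the span argument. The C++ core of that handoff, outside SWIG (hypothetical values):

// Sketch only: mirrors `$1 = temps;` -- the span is a non-owning view, so
// `temps` must outlive the call it is passed to.
std::vector<int64> temps = {1, 2, 3};   // filled from the Python sequence
absl::Span<const int64> arg = temps;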
@ -108,14 +108,12 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated(
|
||||
// array by adding a fourth dummy dimension of size 1 without stride, padding
|
||||
// and dilation.
|
||||
Array4D<float> a4dlhs(lhs.n1(), lhs.n2(), lhs.n3(), 1);
|
||||
a4dlhs.Each(
|
||||
[&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) {
|
||||
a4dlhs.Each([&](absl::Span<const int64> indices, float* value_ptr) {
|
||||
CHECK_EQ(indices[3], 0);
|
||||
*value_ptr = lhs.operator()(indices[0], indices[1], indices[2]);
|
||||
});
|
||||
Array4D<float> a4drhs(rhs.n1(), rhs.n2(), rhs.n3(), 1);
|
||||
a4drhs.Each(
|
||||
[&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) {
|
||||
a4drhs.Each([&](absl::Span<const int64> indices, float* value_ptr) {
|
||||
CHECK_EQ(indices[3], 0);
|
||||
*value_ptr = rhs.operator()(indices[0], indices[1], indices[2]);
|
||||
});
|
||||
@ -130,8 +128,7 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated(
|
||||
|
||||
auto convr3 = absl::make_unique<Array3D<float>>(
|
||||
convr4->planes(), convr4->depth(), convr4->height());
|
||||
convr4->Each(
|
||||
[&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) {
|
||||
convr4->Each([&](absl::Span<const int64> indices, float* value_ptr) {
|
||||
CHECK_EQ(indices[3], 0);
|
||||
convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr;
|
||||
});
|
||||
@ -189,11 +186,11 @@ ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
|
||||
|
||||
/* static */ std::unique_ptr<std::vector<float>>
|
||||
ReferenceUtil::ReduceWindow1DGeneric(
|
||||
const tensorflow::gtl::ArraySlice<float>& operand, float init,
|
||||
const absl::Span<const float>& operand, float init,
|
||||
const std::function<float(float, float)>& reduce_func,
|
||||
const tensorflow::gtl::ArraySlice<int64>& window,
|
||||
const tensorflow::gtl::ArraySlice<int64>& stride,
|
||||
const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding) {
|
||||
const absl::Span<const int64>& window,
|
||||
const absl::Span<const int64>& stride,
|
||||
const absl::Span<const std::pair<int64, int64>>& padding) {
|
||||
std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
|
||||
std::vector<int64> window_counts(window.size(), 0);
|
||||
std::vector<int64> pad_low(window.size(), 0);
@ -221,10 +218,11 @@ ReferenceUtil::ReduceWindow1DGeneric(
}

/* static */ std::unique_ptr<std::vector<float>>
ReferenceUtil::ReduceWindow1DAdd(
    const tensorflow::gtl::ArraySlice<float>& operand, float init,
    const tensorflow::gtl::ArraySlice<int64>& window,
    const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
ReferenceUtil::ReduceWindow1DAdd(const absl::Span<const float>& operand,
                                 float init,
                                 const absl::Span<const int64>& window,
                                 const absl::Span<const int64>& stride,
                                 Padding padding) {
  const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
  std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
  return ReduceWindow1DGeneric(

@ -236,9 +234,9 @@ ReferenceUtil::ReduceWindow1DAdd(
ReferenceUtil::ReduceWindow2DGeneric(
    const Array2D<float>& operand, float init,
    const std::function<float(float, float)>& reduce_func,
    const tensorflow::gtl::ArraySlice<int64>& window,
    const tensorflow::gtl::ArraySlice<int64>& stride,
    const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding) {
    const absl::Span<const int64>& window,
    const absl::Span<const int64>& stride,
    const absl::Span<const std::pair<int64, int64>>& padding) {
  std::vector<int64> dim_lengths{operand.height(), operand.width()};

  std::vector<int64> window_counts(window.size(), 0);

@ -276,8 +274,8 @@ ReferenceUtil::ReduceWindow2DGeneric(

/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::ReduceWindow2DAdd(
    const Array2D<float>& operand, float init,
    const tensorflow::gtl::ArraySlice<int64>& window,
    const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
    const absl::Span<const int64>& window,
    const absl::Span<const int64>& stride, Padding padding) {
  const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
  std::vector<int64> dim_lengths{operand.height(), operand.width()};
  return ReduceWindow2DGeneric(

@ -287,8 +285,8 @@ ReferenceUtil::ReduceWindow2DGeneric(

/* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ReduceWindow3DAdd(
    const Array3D<float>& operand, float init,
    const tensorflow::gtl::ArraySlice<int64>& window,
    const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
    const absl::Span<const int64>& window,
    const absl::Span<const int64>& stride, Padding padding) {
  std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3()};
  auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);

@ -334,8 +332,8 @@ ReferenceUtil::ReduceWindow2DGeneric(
ReferenceUtil::ReduceWindow4DGeneric(
    const Array4D<float>& operand, float init,
    const std::function<float(float, float)>& reduce_func,
    const tensorflow::gtl::ArraySlice<int64>& window,
    const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
    const absl::Span<const int64>& window,
    const absl::Span<const int64>& stride, Padding padding) {
  std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
                                 operand.n4()};
  return ReduceWindow4DGeneric(

@ -347,9 +345,9 @@ ReferenceUtil::ReduceWindow4DGeneric(
ReferenceUtil::ReduceWindow4DGeneric(
    const Array4D<float>& operand, float init,
    const std::function<float(float, float)>& reduce_func,
    const tensorflow::gtl::ArraySlice<int64>& window,
    const tensorflow::gtl::ArraySlice<int64>& stride,
    const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding) {
    const absl::Span<const int64>& window,
    const absl::Span<const int64>& stride,
    const absl::Span<const std::pair<int64, int64>>& padding) {
  std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
                                 operand.n4()};

@ -402,8 +400,8 @@ ReferenceUtil::ReduceWindow4DGeneric(

/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
    const Array4D<float>& operand, float init,
    const tensorflow::gtl::ArraySlice<int64>& window,
    const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
    const absl::Span<const int64>& window,
    const absl::Span<const int64>& stride, Padding padding) {
  const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
  return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride,
                               padding);

@ -424,10 +422,12 @@ ReferenceUtil::ReduceWindow4DGeneric(
}

/* static */ std::unique_ptr<Array4D<float>>
ReferenceUtil::SelectAndScatter4DGePlus(
    const Array4D<float>& operand, const Array4D<float>& source, float init,
    const tensorflow::gtl::ArraySlice<int64>& window,
    const tensorflow::gtl::ArraySlice<int64>& stride, bool same_padding) {
ReferenceUtil::SelectAndScatter4DGePlus(const Array4D<float>& operand,
                                        const Array4D<float>& source,
                                        float init,
                                        const absl::Span<const int64>& window,
                                        const absl::Span<const int64>& stride,
                                        bool same_padding) {
  Padding padding = same_padding ? Padding::kSame : Padding::kValid;
  auto result = absl::make_unique<Array4D<float>>(operand.n1(), operand.n2(),
                                                  operand.n3(), operand.n4());

@ -591,7 +591,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
      result_literal->shape().dimensions(2),
      result_literal->shape().dimensions(3));

  result->Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) {
  result->Each([&](absl::Span<const int64> indices, float* value) {
    *value = result_literal->Get<float>(indices);
  });

@ -633,8 +633,7 @@ ReferenceUtil::ReduceToRowArray2D(
}

/*static*/ std::vector<float> ReferenceUtil::Reduce4DTo1D(
    const Array4D<float>& array, float init,
    tensorflow::gtl::ArraySlice<int64> dims,
    const Array4D<float>& array, float init, absl::Span<const int64> dims,
    const std::function<float(float, float)>& reduce_function) {
  std::vector<float> result;
  CHECK_EQ(dims.size(), 3);

@ -707,8 +706,7 @@ ReferenceUtil::ReduceToRowArray2D(
}

/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::Reduce3DTo2D(
    const Array3D<float>& array, float init,
    tensorflow::gtl::ArraySlice<int64> dims,
    const Array3D<float>& array, float init, absl::Span<const int64> dims,
    const std::function<float(float, float)>& reduce_function) {
  CHECK_EQ(dims.size(), 1);
  int64 rows = dims[0] == 0 ? array.n2() : array.n1();
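
With dims passed by value as absl::Span<const int64>, call sites can also hand in a braced list directly, since Span<const T> is constructible from an initializer list. A hedged call-site sketch (assuming an Array4D<float> named array, as in the code above):

// Call-site sketch: the braced list binds to the Span for the duration of
// the call, so no std::vector is needed for `dims`.
std::vector<float> reduced = ReferenceUtil::Reduce4DTo1D(
    array, /*init=*/0.0f, /*dims=*/{1, 2, 3},
    [](float a, float b) { return a + b; });
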
@ -144,8 +144,7 @@ class ReferenceUtil {
  // Returns the result of reducing the 4D array to a vector, reducing away
  // the dimensions specified in dims.
  static std::vector<float> Reduce4DTo1D(
      const Array4D<float>& array, float init,
      tensorflow::gtl::ArraySlice<int64> dims,
      const Array4D<float>& array, float init, absl::Span<const int64> dims,
      const std::function<float(float, float)>& reduce_function);

  // Broadcast 1D dimension to 4D, from the dimension `broadcast_from_dim`.

@ -156,8 +155,7 @@ class ReferenceUtil {
  // Returns the result of reducing the 3D array to a 2D array, reducing away
  // the dimensions specified in dims.
  static std::unique_ptr<Array2D<float>> Reduce3DTo2D(
      const Array3D<float>& array, float init,
      tensorflow::gtl::ArraySlice<int64> dims,
      const Array3D<float>& array, float init, absl::Span<const int64> dims,
      const std::function<float(float, float)>& reduce_function);

  // Applies map_function to each element in the input (2D array) and returns

@ -179,47 +177,47 @@ class ReferenceUtil {

  // Windowed reductions with Add as the function to apply.
  static std::unique_ptr<std::vector<float>> ReduceWindow1DAdd(
      const tensorflow::gtl::ArraySlice<float>& operand, float init,
      const tensorflow::gtl::ArraySlice<int64>& window,
      const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding);
      const absl::Span<const float>& operand, float init,
      const absl::Span<const int64>& window,
      const absl::Span<const int64>& stride, Padding padding);
  static std::unique_ptr<Array2D<float>> ReduceWindow2DAdd(
      const Array2D<float>& operand, float init,
      const tensorflow::gtl::ArraySlice<int64>& window,
      const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding);
      const absl::Span<const int64>& window,
      const absl::Span<const int64>& stride, Padding padding);
  static std::unique_ptr<Array3D<float>> ReduceWindow3DAdd(
      const Array3D<float>& operand, float init,
      const tensorflow::gtl::ArraySlice<int64>& window,
      const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding);
      const absl::Span<const int64>& window,
      const absl::Span<const int64>& stride, Padding padding);
  static std::unique_ptr<Array4D<float>> ReduceWindow4DAdd(
      const Array4D<float>& operand, float init,
      const tensorflow::gtl::ArraySlice<int64>& window,
      const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding);
      const absl::Span<const int64>& window,
      const absl::Span<const int64>& stride, Padding padding);

  // Windowed reductions with a generic reduce function.
  static std::unique_ptr<std::vector<float>> ReduceWindow1DGeneric(
      const tensorflow::gtl::ArraySlice<float>& operand, float init,
      const absl::Span<const float>& operand, float init,
      const std::function<float(float, float)>& reduce_func,
      const tensorflow::gtl::ArraySlice<int64>& window,
      const tensorflow::gtl::ArraySlice<int64>& stride,
      const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding);
      const absl::Span<const int64>& window,
      const absl::Span<const int64>& stride,
      const absl::Span<const std::pair<int64, int64>>& padding);
  static std::unique_ptr<Array2D<float>> ReduceWindow2DGeneric(
      const Array2D<float>& operand, float init,
      const std::function<float(float, float)>& reduce_func,
      const tensorflow::gtl::ArraySlice<int64>& window,
      const tensorflow::gtl::ArraySlice<int64>& stride,
      const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding);
      const absl::Span<const int64>& window,
      const absl::Span<const int64>& stride,
      const absl::Span<const std::pair<int64, int64>>& padding);
  static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric(
      const Array4D<float>& operand, float init,
      const std::function<float(float, float)>& reduce_func,
      const tensorflow::gtl::ArraySlice<int64>& window,
      const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding);
      const absl::Span<const int64>& window,
      const absl::Span<const int64>& stride, Padding padding);
  // With arbitrary padding.
  static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric(
      const Array4D<float>& operand, float init,
      const std::function<float(float, float)>& reduce_func,
      const tensorflow::gtl::ArraySlice<int64>& window,
      const tensorflow::gtl::ArraySlice<int64>& stride,
      const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding);
      const absl::Span<const int64>& window,
      const absl::Span<const int64>& stride,
      const absl::Span<const std::pair<int64, int64>>& padding);

  // Batch normalize data.
  static std::unique_ptr<Array4D<float>> BatchNorm4D(

@ -232,8 +230,8 @@ class ReferenceUtil {
  // TODO(b/74533103) Switch tests to evaluator and remove this implementation.
  static std::unique_ptr<Array4D<float>> SelectAndScatter4DGePlus(
      const Array4D<float>& operand, const Array4D<float>& source, float init,
      const tensorflow::gtl::ArraySlice<int64>& window,
      const tensorflow::gtl::ArraySlice<int64>& stride, bool same_padding);
      const absl::Span<const int64>& window,
      const absl::Span<const int64>& stride, bool same_padding);

  // Concatenates the lhs and rhs arrays along the concatenate_dimension.
  // E.g. if concatenate_dimension is 0, the "n1"/height dimension is

@ -334,8 +332,8 @@ class ReferenceUtil {

  // Slices with index clamping
  template <typename T>
  static std::vector<T> ClampSlice1D(
      const tensorflow::gtl::ArraySlice<T>& input, int64 start, int64 size) {
  static std::vector<T> ClampSlice1D(const absl::Span<const T>& input,
                                     int64 start, int64 size) {
    start = std::min<int64>(std::max<int64>(0, start), input.size() - size);
    std::vector<T> result;
    for (int64 i = 0; i < size; ++i) {

@ -633,7 +631,7 @@ class ReferenceUtil {
    Array4D<NativeT> result(output_bounds[0], output_bounds[1],
                            output_bounds[2], output_bounds[3]);
    result.Each(
        [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT* value) {
        [&](absl::Span<const int64> indices, NativeT* value) {
          for (int i = 0; i < 4; ++i) {
            bool in_low_padding = indices[i] < pad_low[i];
            bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i];

@ -449,8 +449,7 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) {

Status AlgebraicSimplifierVisitor::HandleConcatenate(
    HloInstruction* concatenate) {
  tensorflow::gtl::ArraySlice<HloInstruction*> operands(
      concatenate->operands());
  absl::Span<HloInstruction* const> operands(concatenate->operands());
  if (operands.size() == 1) {
    // Unary concatenates are useless.
    ReplaceInstructionIfSameShape(concatenate, operands[0]);
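
Note where the const lands in the new element type: gtl::ArraySlice<HloInstruction*> becomes absl::Span<HloInstruction* const>, an immutable sequence of pointers whose pointees remain mutable (absl::Span<const HloInstruction*> would instead freeze the pointees). A self-contained sketch of the distinction, using a hypothetical type:

#include "absl/types/span.h"

struct Node { int value = 0; };

// The elements (the pointers) are const; the Nodes they point to are not.
void BumpAll(absl::Span<Node* const> nodes) {
  // nodes[0] = nullptr;             // would not compile: elements are const
  for (Node* n : nodes) ++n->value;  // fine: pointees stay mutable
}
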
@ -588,7 +587,7 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) {
namespace {
template <typename T>
Status InvertConstant(const HloInstruction& constant, Literal* result) {
  return result->Populate<T>([&](tensorflow::gtl::ArraySlice<int64> indices) {
  return result->Populate<T>([&](absl::Span<const int64> indices) {
    return T{1.0} / constant.literal().Get<T>(indices);
  });
}

@ -1249,8 +1248,7 @@ namespace {
//
// Precondition: input_dim_indices is sorted.
absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
    const HloInstruction* hlo,
    tensorflow::gtl::ArraySlice<int64> input_dim_indices) {
    const HloInstruction* hlo, absl::Span<const int64> input_dim_indices) {
  CHECK_EQ(HloOpcode::kReshape, hlo->opcode());
  CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end()));
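
The std::is_sorted precondition check carries over unchanged because absl::Span exposes begin()/end() and models a contiguous range, so standard algorithms apply to it directly. A small self-contained sketch:

#include <algorithm>
#include <cstdint>
#include "absl/types/span.h"

// Sketch: Span interoperates with <algorithm> like any standard container.
bool IsStrictlyIncreasing(absl::Span<const int64_t> dims) {
  return std::adjacent_find(dims.begin(), dims.end(),
                            [](int64_t a, int64_t b) { return a >= b; }) ==
         dims.end();
}
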
@ -1853,7 +1851,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {

  auto arg = reduce->mutable_operand(0);
  auto init_value = reduce->mutable_operand(1);
  tensorflow::gtl::ArraySlice<int64> dimensions(reduce->dimensions());
  absl::Span<const int64> dimensions(reduce->dimensions());
  HloComputation* function = reduce->to_apply();
  if (ShapeUtil::IsZeroElementArray(arg->shape()) ||
      ShapeUtil::IsZeroElementArray(reduce->shape())) {

@ -2226,7 +2226,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
  auto out_dims = in_dims;
  out_dims[in_channel_idx] = options.f_output_channels;

  auto make_shape = [](tensorflow::gtl::ArraySlice<int64> dims,
  auto make_shape = [](absl::Span<const int64> dims,
                       bool minor_to_major_layout) {
    if (minor_to_major_layout) {
      return ShapeUtil::MakeShapeWithLayout(F32, dims, {0, 1, 2, 3});

@ -2838,8 +2838,8 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) {

  // a and b are parallel bounds we can either turn into a B F S0 S1 or
  // `B S0 S1 F` kind of pattern.
  auto decorate_spatials = [&param](tensorflow::gtl::ArraySlice<int64> spatials,
                                    int64 a, int64 b) {
  auto decorate_spatials = [&param](absl::Span<const int64> spatials, int64 a,
                                    int64 b) {
    std::vector<int64> result;
    if (param.prepend_a) {
      result.push_back(a);

@ -112,10 +112,10 @@ StatusOr<StreamPool::Ptr> Backend::BorrowStream(se::StreamExecutor* executor) {
  return stream_pools_.at(executor).BorrowStream(executor);
}

Backend::Backend(
    se::Platform* platform, Compiler* compiler,
    tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors,
    TransferManager* transfer_manager, ComputationPlacer* computation_placer,
Backend::Backend(se::Platform* platform, Compiler* compiler,
                 absl::Span<se::StreamExecutor* const> stream_executors,
                 TransferManager* transfer_manager,
                 ComputationPlacer* computation_placer,
                 int intra_op_parallelism_threads)
    : platform_(platform),
      compiler_(compiler),

@ -149,7 +149,7 @@ class Backend {
 private:
  struct EigenThreadPoolWrapper;
  Backend(se::Platform* platform, Compiler* compiler,
          tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors,
          absl::Span<se::StreamExecutor* const> stream_executors,
          TransferManager* transfer_manager,
          ComputationPlacer* computation_placer,
          int intra_op_parallelism_threads);

@ -69,8 +69,7 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
  // Inserts conversion HLOs to replace the called computations' BF16
  // operands/outputs to F32.
  Status ConvertCalledComputations(
      HloInstruction* hlo,
      tensorflow::gtl::ArraySlice<HloComputation*> bf16_called_comps);
      HloInstruction* hlo, absl::Span<HloComputation* const> bf16_called_comps);

  HloComputation* computation_;
  const BFloat16Support* bfloat16_support_;

@ -114,8 +113,7 @@ Status BFloat16NormalizationVisitor::InsertConvertBeforeOperand(
}

Status BFloat16NormalizationVisitor::ConvertCalledComputations(
    HloInstruction* hlo,
    tensorflow::gtl::ArraySlice<HloComputation*> bf16_called_comps) {
    HloInstruction* hlo, absl::Span<HloComputation* const> bf16_called_comps) {
  std::map<HloComputation*, HloComputation*> cloned_computations;
  for (auto& comp : bf16_called_comps) {
    auto cloned = comp->parent()->AddEmbeddedComputation(comp->Clone());

@ -407,7 +407,7 @@ void BFloat16Propagation::AdjustCalledComputationParameters(
    HloInstruction* hlo) {
  auto adjust_computation =
      [this, hlo](HloComputation* computation,
                  tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
                  absl::Span<HloInstruction* const> operands) {
        // Adjust parameters.
        CHECK_EQ(operands.size(), computation->num_parameters());
        for (int64 i = 0; i < operands.size(); ++i) {

@ -118,7 +118,7 @@ class BufferAssignmentTest : public HloVerifiedTestBase {

  std::unique_ptr<BufferAssignment> RunBufferAssignmentWithInstructionSequence(
      HloModule* module,
      tensorflow::gtl::ArraySlice<const HloInstruction*> instruction_sequence,
      absl::Span<const HloInstruction* const> instruction_sequence,
      int64 alignment = 1) {
    SequentialHloOrdering::HloModuleSequence module_sequence;
    module_sequence[module->entry_computation()] =

@ -62,7 +62,7 @@ CompileOnlyService::CompileOnlyService(const ServiceOptions& options,

StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileOnlyService::CompileAheadOfTime(
    const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
    const absl::Span<const AotXlaComputationInstance> computations,
    const AotCompilationOptions& options,
    std::unique_ptr<AotCompilationMetadata>* metadata) {
  std::vector<std::unique_ptr<HloModule>> hlo_modules;

@ -50,12 +50,12 @@ class CompileOnlyService : public Service {
  // |CompileOnlyClient::CompileAheadOfTime| for additional details.
  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
  CompileAheadOfTime(
      const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
      const absl::Span<const AotXlaComputationInstance> computations,
      const AotCompilationOptions& options);

  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
  CompileAheadOfTime(
      const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
      const absl::Span<const AotXlaComputationInstance> computations,
      const AotCompilationOptions& options,
      std::unique_ptr<AotCompilationMetadata>* metadata);

@ -479,7 +479,7 @@ class CopyRemover {
  // 'values' an entry is created in value_to_node which indicates the
  // respective ValueNode representing that value.
  void AddValueList(
      tensorflow::gtl::ArraySlice<const HloValue*> values,
      absl::Span<const HloValue* const> values,
      tensorflow::gtl::FlatMap<const HloValue*, ValueNode*>* value_to_node) {
    ValueNode* tail = nullptr;
    ValueNode* head = nullptr;

@ -40,7 +40,7 @@ std::vector<BufferInfo> CreateBufferInfosFromBufferAssignment(
}

std::vector<int32> CreateArgIndexTableFromBufferInfos(
    tensorflow::gtl::ArraySlice<BufferInfo> buffer_infos) {
    absl::Span<const BufferInfo> buffer_infos) {
  std::vector<int32> result;
  for (int64 i = 0; i < buffer_infos.size(); i++) {
    if (buffer_infos[i].is_entry_parameter()) {

@ -34,7 +34,7 @@ CreateBufferInfosFromBufferAssignment(
// If this function returns V then entry parameter i has buffer allocation index
// V[i].
std::vector<int32> CreateArgIndexTableFromBufferInfos(
    tensorflow::gtl::ArraySlice<::tensorflow::cpu_function_runtime::BufferInfo>
    absl::Span<const ::tensorflow::cpu_function_runtime::BufferInfo>
        buffer_infos);
}  // namespace cpu
}  // namespace xla

@ -77,7 +77,7 @@ StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
                   std::vector<OwningDeviceMemory>>>
CpuExecutable::CreateTempArray(
    DeviceMemoryAllocator* memory_allocator, int device_ordinal,
    tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
    absl::Span<const ShapedBuffer* const> arguments) {
  std::vector<se::DeviceMemoryBase> unowning_buffers(
      assignment_->Allocations().size());
  std::vector<OwningDeviceMemory> owning_buffers(

@ -136,7 +136,7 @@ CpuExecutable::CreateTempArray(

Status CpuExecutable::ExecuteComputeFunction(
    const ExecutableRunOptions* run_options,
    tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
    absl::Span<const se::DeviceMemoryBase> buffers,
    HloExecutionProfile* hlo_execution_profile) {
  // The calling convention for JITed functions is:
  //

@ -207,7 +207,7 @@ Status CpuExecutable::ExecuteComputeFunction(

StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
    const ServiceExecutableRunOptions* run_options,
    tensorflow::gtl::MutableArraySlice<OwningDeviceMemory> buffers) {
    absl::Span<OwningDeviceMemory> buffers) {
  se::Stream* stream = run_options->stream();
  ScopedShapedBuffer result_buffer(
      /*on_host_shape=*/result_shape(),
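
This is the one MutableArraySlice in the file, and it shows the mutable half of the mapping: MutableArraySlice<T> becomes absl::Span<T>, while read-only ArraySlice<T> becomes absl::Span<const T>. A self-contained sketch of the two directions (function names hypothetical):

#include "absl/types/span.h"

// Writable view: elements can be assigned through the span.
void ZeroAll(absl::Span<float> out) {
  for (float& v : out) v = 0.0f;
}

// Read-only view: elements can only be inspected.
float Sum(absl::Span<const float> in) {
  float s = 0.0f;
  for (float v : in) s += v;
  return s;
}
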

@ -245,7 +245,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(

StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteOnStream(
    const ServiceExecutableRunOptions* run_options,
    tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
    absl::Span<const ShapedBuffer* const> arguments,
    HloExecutionProfile* hlo_execution_profile) {
  TF_ASSIGN_OR_RETURN(
      auto result,

@ -256,7 +256,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteOnStream(

StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
    const ServiceExecutableRunOptions* run_options,
    tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
    absl::Span<const ShapedBuffer* const> arguments) {
  if (hlo_profiling_enabled()) {
    return Unimplemented(
        "Asynchronous execution on stream with hlo profiling is not yet "

@ -267,7 +267,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(

StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl(
    const ServiceExecutableRunOptions* run_options,
    tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
    absl::Span<const ShapedBuffer* const> arguments,
    HloExecutionProfile* hlo_execution_profile) {
  if (GetRootPointsToSet().IsAmbiguous()) {
    return Unimplemented("Points-to set of root instruction is ambiguous");

@ -299,7 +299,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl(
  //
  // We also need to change the types of some of the variables we capture:
  // run_options needs to change from a pointer to a value type, and arguments
  // needs to change from an ArraySlice into a vector. We use a struct instead
  // needs to change from a Span into a vector. We use a struct instead
  // of a lambda to make this explicit.
  struct AsyncRunTask {
    CpuExecutable* executable;
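
The comment above is the key caveat of this migration: a Span does not own its elements, so a task that may outlive the caller must first copy the viewed data into owning storage. The usual idiom, sketched with a generic element type:

#include <vector>
#include "absl/types/span.h"

// Sketch: snapshot a non-owning view into an owning vector before the
// backing storage can go away (e.g. before launching asynchronous work).
std::vector<float> Snapshot(absl::Span<const float> values) {
  return std::vector<float>(values.begin(), values.end());
}
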
@ -57,12 +57,12 @@ class CpuExecutable : public Executable {

  StatusOr<ScopedShapedBuffer> ExecuteOnStream(
      const ServiceExecutableRunOptions* run_options,
      tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
      absl::Span<const ShapedBuffer* const> arguments,
      HloExecutionProfile* hlo_execution_profile) override;

  StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
      const ServiceExecutableRunOptions* run_options,
      tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) override;
      absl::Span<const ShapedBuffer* const> arguments) override;

  // This should be called after set_ir_module_string.
  const string& ir_module_string() const { return ir_module_string_; }

@ -92,7 +92,7 @@ class CpuExecutable : public Executable {
  // exists) must out-live the task.
  StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStreamImpl(
      const ServiceExecutableRunOptions* run_options,
      tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
      absl::Span<const ShapedBuffer* const> arguments,
      HloExecutionProfile* hlo_execution_profile);

  // Creates an array suitable for passing as the "temps" argument to the JIT

@ -112,13 +112,12 @@ class CpuExecutable : public Executable {
  StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
                     std::vector<OwningDeviceMemory>>>
  CreateTempArray(DeviceMemoryAllocator* memory_allocator, int device_ordinal,
                  tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
                  absl::Span<const ShapedBuffer* const> arguments);

  // Calls the generated function performing the computation with the given
  // arguments using the supplied buffers.
  Status ExecuteComputeFunction(
      const ExecutableRunOptions* run_options,
      tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
  Status ExecuteComputeFunction(const ExecutableRunOptions* run_options,
                                absl::Span<const se::DeviceMemoryBase> buffers,
                                HloExecutionProfile* hlo_execution_profile);

  // Creates a ScopedShapedBuffer for holding the result of the computation,

@ -126,7 +125,7 @@ class CpuExecutable : public Executable {
  // The addresses are set according to buffer assignment.
  StatusOr<ScopedShapedBuffer> CreateResultShapedBuffer(
      const ServiceExecutableRunOptions* run_options,
      tensorflow::gtl::MutableArraySlice<OwningDeviceMemory> buffers);
      absl::Span<OwningDeviceMemory> buffers);

  // Returns the points-to set of the root instruction of the entry
  // computation. Uses points-to analysis from buffer assignment.

@ -179,7 +179,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
  int64 size = GetByteSizeRequirement(literal_shape);
  // Note: OSS build didn't like implicit conversion from
  // literal_shape.dimensions() to the array slice on 2017-07-10.
  tensorflow::gtl::ArraySlice<int64> dimensions(
  absl::Span<const int64> dimensions(
      tensorflow::bit_cast<const int64*>(literal_shape.dimensions().data()),
      literal_shape.dimensions().size());
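
The conversion above leans on Span's pointer-plus-length constructor, which wraps an arbitrary contiguous buffer without copying. A tiny self-contained sketch:

#include <cstddef>
#include <cstdint>
#include "absl/types/span.h"

// Sketch: view a raw buffer read-only, given a pointer and element count.
int64_t FirstDim(const int64_t* data, size_t count) {
  absl::Span<const int64_t> dims(data, count);
  return dims.empty() ? -1 : dims.front();
}
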
  TF_ASSIGN_OR_RETURN(

@ -225,7 +225,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(

StatusOr<Shape> CpuTransferManager::TransferTupleBuffersFromOutfeed(
    se::StreamExecutor* executor,
    tensorflow::gtl::ArraySlice<std::pair<void*, int64>> buffer_data) {
    absl::Span<const std::pair<void*, int64>> buffer_data) {
  return TransferBuffersFromOutfeedInternal(executor, buffer_data,
                                            /*is_tuple=*/true);
}

@ -238,8 +238,7 @@ StatusOr<Shape> CpuTransferManager::TransferArrayBufferFromOutfeed(

StatusOr<Shape> CpuTransferManager::TransferBuffersFromOutfeedInternal(
    se::StreamExecutor* executor,
    tensorflow::gtl::ArraySlice<std::pair<void*, int64>> buffer_data,
    bool is_tuple) {
    absl::Span<const std::pair<void*, int64>> buffer_data, bool is_tuple) {
  std::vector<std::unique_ptr<CpuOutfeedBuffer>> buffers;
  for (auto b : buffer_data) {
    int64 size = b.second;

@ -56,7 +56,7 @@ class CpuTransferManager : public GenericTransferManager {
  // Helper that transfers a tuple of element buffers from the device's outfeed.
  StatusOr<Shape> TransferTupleBuffersFromOutfeed(
      se::StreamExecutor* executor,
      tensorflow::gtl::ArraySlice<std::pair<void*, int64>> buffer_data);
      absl::Span<const std::pair<void*, int64>> buffer_data);

  // Helper that transfers an array buffer from the device's outfeed.
  StatusOr<Shape> TransferArrayBufferFromOutfeed(se::StreamExecutor* executor,

@ -68,8 +68,7 @@ class CpuTransferManager : public GenericTransferManager {
  // for the given buffers.
  StatusOr<Shape> TransferBuffersFromOutfeedInternal(
      se::StreamExecutor* executor,
      tensorflow::gtl::ArraySlice<std::pair<void*, int64>> buffer_data,
      bool is_tuple);
      absl::Span<const std::pair<void*, int64>> buffer_data, bool is_tuple);

  TF_DISALLOW_COPY_AND_ASSIGN(CpuTransferManager);
};

@ -80,7 +80,7 @@ class MemoryTile {
  // `minor_dim_offset`}.
  //
  // Note: `major_dim_offset` is a parameter to the constructor.
  void StoreTile(tensorflow::gtl::ArraySlice<llvm::Value*> tile,
  void StoreTile(absl::Span<llvm::Value* const> tile,
                 llvm::Value* minor_dim_offset) const {
    CHECK_EQ(tile.size(), pointers_.size());
    for (int64 i = 0; i < pointers_.size(); i++) {

@ -506,8 +506,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {

llvm::Value* IrEmitter::EmitElementalMap(
    const HloMapInstruction& map_instr,
    tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands,
    absl::string_view name) {
    absl::Span<llvm::Value* const> elemental_operands, absl::string_view name) {
  return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name);
}

@ -1455,7 +1454,7 @@ IrEmitter::EmitInnerLoopForVectorizedReduction(
    const ReductionGenerator& reduction_generator,
    const llvm_ir::IrArray::Index& output_index,
    const ShardedVectorType& accumulator_type, HloInstruction* init_value,
    HloInstruction* arg, gtl::ArraySlice<int64> dimensions,
    HloInstruction* arg, absl::Span<const int64> dimensions,
    unsigned element_alignment) {
  ShardedVector accumulator;
  accumulator.reserve(accumulator_type.size());

@ -1551,7 +1550,7 @@ void IrEmitter::EmitShardedVectorStore(

StatusOr<bool> IrEmitter::EmitVectorizedReduce(
    HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
    gtl::ArraySlice<int64> dimensions, HloComputation* function,
    absl::Span<const int64> dimensions, HloComputation* function,
    string* failure_reason) {
  if (!ReductionPreservesLayout(*reduce)) {
    return false;

@ -1701,7 +1700,7 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
    HloReduceInstruction* reduce, const llvm_ir::IrArray::Index& index) {
  const HloInstruction* arg = reduce->mutable_operand(0);
  const HloInstruction* init_value = reduce->mutable_operand(1);
  gtl::ArraySlice<int64> dimensions(reduce->dimensions());
  absl::Span<const int64> dimensions(reduce->dimensions());

  // Initialize an accumulator with init_value.
  PrimitiveType accumulator_type = reduce->shape().element_type();

@ -1758,7 +1757,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
  }
  auto arg = reduce->mutable_operand(0);
  auto init_value = reduce->mutable_operand(1);
  gtl::ArraySlice<int64> dimensions(reduce->dimensions());
  absl::Span<const int64> dimensions(reduce->dimensions());
  HloComputation* function = reduce->to_apply();
  if (!options::VectorizedReduceDisabled(hlo_module_config_)) {
    string vectorization_failure_reason;

@ -2113,7 +2112,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
}

Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
  gtl::ArraySlice<HloInstruction*> operands(custom_call->operands());
  absl::Span<HloInstruction* const> operands(custom_call->operands());
  absl::string_view custom_call_target(custom_call->custom_call_target());
  llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
  llvm::AllocaInst* operands_alloca =

@ -2233,7 +2232,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
}

StatusOr<bool> IrEmitter::EmitFastConcatenate(
    HloInstruction* concatenate, gtl::ArraySlice<HloInstruction*> operands,
    HloInstruction* concatenate, absl::Span<HloInstruction* const> operands,
    string* failure_reason) {
  if (ShouldEmitParallelLoopFor(*concatenate)) {
    *failure_reason =

@ -2369,7 +2368,7 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
}

Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) {
  gtl::ArraySlice<HloInstruction*> operands(concatenate->operands());
  absl::Span<HloInstruction* const> operands(concatenate->operands());
  string failure_reason;
  TF_ASSIGN_OR_RETURN(
      bool successful,

@ -2800,8 +2799,8 @@ Status IrEmitter::EmitMemcpy(const HloInstruction& source,

Status IrEmitter::ElementTypesSameAndSupported(
    const HloInstruction& instruction,
    gtl::ArraySlice<const HloInstruction*> operands,
    gtl::ArraySlice<PrimitiveType> supported_types) {
    absl::Span<const HloInstruction* const> operands,
    absl::Span<const PrimitiveType> supported_types) {
  for (auto operand : operands) {
    TF_RET_CHECK(
        ShapeUtil::SameElementType(operands[0]->shape(), operand->shape()));

@ -2831,8 +2830,7 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
}

llvm::Value* IrEmitter::EmitThreadLocalCall(
    const HloComputation& callee,
    tensorflow::gtl::ArraySlice<llvm::Value*> parameters,
    const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
    absl::string_view name) {
  const Shape& return_shape = callee.root_instruction()->shape();

@ -111,7 +111,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
  // Emit code to map one element according to `map_instr`.
  llvm::Value* EmitElementalMap(
      const HloMapInstruction& map_instr,
      tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands,
      absl::Span<llvm::Value* const> elemental_operands,
      absl::string_view name);

 protected:

@ -252,9 +252,8 @@ class IrEmitter : public DfsHloVisitorWithDefault,
  //
  // `parameters` holds the *scalar values* that need to be passed to the
  // callee. The return value is the scalar returned by the callee.
  llvm::Value* EmitThreadLocalCall(
      const HloComputation& callee,
      tensorflow::gtl::ArraySlice<llvm::Value*> parameters,
  llvm::Value* EmitThreadLocalCall(const HloComputation& callee,
                                   absl::Span<llvm::Value* const> parameters,
                                   absl::string_view name);

  // Emits a call to a "global" function (e.g. to the computation nested within

@ -271,8 +270,8 @@ class IrEmitter : public DfsHloVisitorWithDefault,
  // match and are of one of the given supported types.
  Status ElementTypesSameAndSupported(
      const HloInstruction& instruction,
      tensorflow::gtl::ArraySlice<const HloInstruction*> operands,
      tensorflow::gtl::ArraySlice<PrimitiveType> supported_types);
      absl::Span<const HloInstruction* const> operands,
      absl::Span<const PrimitiveType> supported_types);

  // Emit IR to perform a computation for every element in the given target op.
  // This produces a series of nested loops (one for each dimension of the op's

@ -319,9 +318,11 @@ class IrEmitter : public DfsHloVisitorWithDefault,
  // concepts that generalize over other vectorizable operations. We should
  // consider pulling out these abstractions into a VectorizingIrEmitter or
  // something similar.
  StatusOr<bool> EmitVectorizedReduce(
      HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
      tensorflow::gtl::ArraySlice<int64> dimensions, HloComputation* function,
  StatusOr<bool> EmitVectorizedReduce(HloInstruction* reduce,
                                      HloInstruction* arg,
                                      HloInstruction* init_value,
                                      absl::Span<const int64> dimensions,
                                      HloComputation* function,
                                      string* failure_reason);

  // We'd like to keep one or two one cache-line's worth of data in registers

@ -372,15 +373,14 @@ class IrEmitter : public DfsHloVisitorWithDefault,
      const ReductionGenerator& reduction_generator,
      const llvm_ir::IrArray::Index& output_index,
      const ShardedVectorType& accumulator_type, HloInstruction* init_value,
      HloInstruction* arg, tensorflow::gtl::ArraySlice<int64> dimensions,
      HloInstruction* arg, absl::Span<const int64> dimensions,
      unsigned element_alignment);

  // Tries to emit a fast concatenate operation using memcpy. Returns true if
  // successful, and false on failure. On failure, sets "failure_reason" to a
  // string describing why it could not emit a fast concatenate.
  StatusOr<bool> EmitFastConcatenate(
      HloInstruction* concatenate,
      tensorflow::gtl::ArraySlice<HloInstruction*> operands,
  StatusOr<bool> EmitFastConcatenate(HloInstruction* concatenate,
                                     absl::Span<HloInstruction* const> operands,
                                     string* failure_reason);

  // Emits LLVM IR to transfer "element_count" elements of type "primitive_type"

@ -200,10 +200,10 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) {
// Returns an array of compute function call arguments (including parameter
// address buffer).
std::vector<llvm::Value*> GetArrayFunctionCallArguments(
    tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
    llvm::IRBuilder<>* b, absl::string_view name,
    llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg,
    llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) {
    absl::Span<llvm::Value* const> parameter_addresses, llvm::IRBuilder<>* b,
    absl::string_view name, llvm::Value* return_value_buffer,
    llvm::Value* exec_run_options_arg, llvm::Value* temp_buffers_arg,
    llvm::Value* profile_counters_arg) {
  llvm::Value* parameter_addresses_buffer;

  if (parameter_addresses.empty()) {

@ -115,10 +115,10 @@ class IrFunction {

  // Returns an array of compute function call argument ir values.
  std::vector<llvm::Value*> GetArrayFunctionCallArguments(
      tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
      llvm::IRBuilder<>* b, absl::string_view name,
      llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg,
      llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg);
      absl::Span<llvm::Value* const> parameter_addresses, llvm::IRBuilder<>* b,
      absl::string_view name, llvm::Value* return_value_buffer,
      llvm::Value* exec_run_options_arg, llvm::Value* temp_buffers_arg,
      llvm::Value* profile_counters_arg);

  // Emits a call to a runtime fork/join function which dispatches parallel
  // calls to 'parallel_function' (and joins threads before returning).

@ -428,7 +428,7 @@ std::vector<llvm::Value*> TileVariable::Get() const {
  return result;
}

void TileVariable::Set(tensorflow::gtl::ArraySlice<llvm::Value*> value) {
void TileVariable::Set(absl::Span<llvm::Value* const> value) {
  CHECK_EQ(value.size(), storage_.size());
  for (int64 i = 0, e = value.size(); i < e; i++) {
    storage_[i].Set(value[i]);

@ -324,7 +324,7 @@ class TileVariable {
                 std::vector<llvm::Value*> initial_value);

  std::vector<llvm::Value*> Get() const;
  void Set(tensorflow::gtl::ArraySlice<llvm::Value*> value);
  void Set(absl::Span<llvm::Value* const> value);

 private:
  std::vector<VectorVariable> storage_;

@ -37,7 +37,7 @@ void XfeedQueueManager::Reset() {
}

void XfeedQueueManager::EnqueueBuffersAtomically(
    tensorflow::gtl::ArraySlice<XfeedBuffer*> buffers) {
    absl::Span<XfeedBuffer* const> buffers) {
  tensorflow::mutex_lock l(mu_);
  bool was_empty = enqueued_buffers_.empty();
  for (XfeedBuffer* b : buffers) {

@ -63,8 +63,7 @@ class XfeedQueueManager {
  // called when the buffer will no longer be accessed by the XfeedManager,
  // either as a result of a call to Reset or because the runtime has dequeued
  // and used the buffer.
  void EnqueueBuffersAtomically(
      tensorflow::gtl::ArraySlice<XfeedBuffer*> buffers);
  void EnqueueBuffersAtomically(absl::Span<XfeedBuffer* const> buffers);

  // Blocks until the queue is non-empty, then returns the buffer at the head of
  // the queue. Sets the current buffer to be the returned buffer. It is an

@ -25,7 +25,7 @@ namespace xla {

StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
    const se::Platform* platform,
    tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors)
    absl::Span<se::StreamExecutor* const> stream_executors)
    : DeviceMemoryAllocator(platform),
      stream_executors_(stream_executors.begin(), stream_executors.end()) {}

@ -80,7 +80,7 @@ class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator {
 public:
  StreamExecutorMemoryAllocator(
      const se::Platform* platform,
      tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors);
      absl::Span<se::StreamExecutor* const> stream_executors);

  StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size,
                                        bool retry_on_failure) override;

@ -856,7 +856,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
  auto getFloat = [&](const float f) {
    return llvm::ConstantFP::get(b_->getFloatTy(), f);
  };
  auto multiply_add = [&](tensorflow::gtl::ArraySlice<float> coefficients,
  auto multiply_add = [&](absl::Span<const float> coefficients,
                          llvm::Value* w) {
    llvm::Value* p = getFloat(coefficients.front());
    coefficients.remove_prefix(1);
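
multiply_add peels off the leading coefficient with Span::remove_prefix, which shrinks the local copy of the view without touching the underlying array. A self-contained sketch of the same Horner-style loop over plain floats:

#include "absl/types/span.h"

// Sketch: remove_prefix(1) advances this copy of the view; the caller's
// data is unchanged. Assumes at least one coefficient.
float HornerEval(absl::Span<const float> coefficients, float w) {
  float p = coefficients.front();
  coefficients.remove_prefix(1);
  for (float c : coefficients) p = p * w + c;
  return p;
}
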
@ -893,7 +893,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
  SetToFirstInsertPoint(if_data.true_block, b_);
  {
    llvm::Value* lw = FSub(w, getFloat(2.5f));
    tensorflow::gtl::ArraySlice<float> lq{
    absl::Span<const float> lq{
        2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
        -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
        -0.00417768164f, 0.246640727f, 1.50140941f};

@ -908,7 +908,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
        module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()});

    llvm::Value* gw = FSub(Call(sqrtf_fn, w), getFloat(3.0f));
    tensorflow::gtl::ArraySlice<float> gq{
    absl::Span<const float> gq{
        -0.000200214257f, 0.000100950558f, 0.00134934322f,
        -0.00367342844f, 0.00573950773f, -0.0076224613f,
        0.00943887047f, 1.00167406f, 2.83297682f};
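
A closing caution that applies to every hunk in this commit: absl::Span never owns its elements, so the storage it views must stay alive for as long as the span is used. When coefficients need a stable lifetime, pinning them in named storage is the safe pattern, sketched below:

#include <array>
#include "absl/types/span.h"

// Sketch: the span is valid exactly as long as `coeffs` is in scope.
std::array<float, 3> coeffs = {0.5f, 1.0f, 2.0f};
absl::Span<const float> view = coeffs;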