From 6f879f891abe2e267c5cf512d034d7c3641cfdb0 Mon Sep 17 00:00:00 2001 From: Tim Shen Date: Thu, 30 Aug 2018 16:03:10 -0700 Subject: [PATCH] [XLA] Rename all (Mutable)ArraySlice to absl::Span. PiperOrigin-RevId: 210998142 --- .../compiler/aot/embedded_protocol_buffers.cc | 2 +- .../compiler/aot/embedded_protocol_buffers.h | 2 +- tensorflow/compiler/jit/deadness_analysis.cc | 50 +- .../compiler/jit/deadness_analysis_internal.h | 2 +- .../jit/encapsulate_subgraphs_pass_test.cc | 145 +++-- .../compiler/jit/partially_decluster_pass.cc | 2 +- tensorflow/compiler/tests/randomized_tests.cc | 18 +- .../tf2xla/functionalize_control_flow_test.cc | 24 +- .../tf2xla/kernels/batchtospace_op.cc | 4 +- .../compiler/tf2xla/kernels/binary_ops.cc | 30 +- .../compiler/tf2xla/kernels/cwise_ops.h | 4 +- tensorflow/compiler/tf2xla/kernels/diag_op.cc | 4 +- .../tf2xla/kernels/image_resize_ops.cc | 6 +- .../compiler/tf2xla/kernels/random_ops.cc | 3 +- .../compiler/tf2xla/kernels/select_op.cc | 2 +- .../tf2xla/kernels/spacetobatch_op.cc | 4 +- .../tf2xla/kernels/tensor_array_ops.cc | 2 +- tensorflow/compiler/tf2xla/lib/cholesky.cc | 2 +- tensorflow/compiler/tf2xla/lib/qr.cc | 12 +- tensorflow/compiler/tf2xla/lib/scatter.cc | 6 +- tensorflow/compiler/tf2xla/lib/util.cc | 20 +- tensorflow/compiler/tf2xla/lib/util.h | 22 +- tensorflow/compiler/tf2xla/lib/while_loop.cc | 8 +- tensorflow/compiler/tf2xla/lib/while_loop.h | 10 +- tensorflow/compiler/tf2xla/literal_util.cc | 5 +- tensorflow/compiler/tf2xla/literal_util.h | 5 +- .../compiler/tf2xla/literal_util_test.cc | 4 +- tensorflow/compiler/tf2xla/xla_compiler.cc | 12 +- tensorflow/compiler/tf2xla/xla_compiler.h | 8 +- tensorflow/compiler/tf2xla/xla_helpers.cc | 2 +- tensorflow/compiler/tf2xla/xla_helpers.h | 2 +- tensorflow/compiler/tf2xla/xla_op_kernel.cc | 2 +- tensorflow/compiler/tf2xla/xla_op_kernel.h | 2 +- tensorflow/compiler/tf2xla/xla_op_registry.cc | 8 +- tensorflow/compiler/tf2xla/xla_op_registry.h | 8 +- tensorflow/compiler/xla/array.h | 28 +- tensorflow/compiler/xla/array4d_test.cc | 22 +- tensorflow/compiler/xla/array_test.cc | 2 +- tensorflow/compiler/xla/client/client.cc | 8 +- tensorflow/compiler/xla/client/client.h | 6 +- .../xla/client/compile_only_client.cc | 2 +- .../compiler/xla/client/compile_only_client.h | 2 +- .../xla/client/executable_build_options.h | 2 +- tensorflow/compiler/xla/client/lib/math.cc | 3 +- tensorflow/compiler/xla/client/lib/math.h | 3 +- tensorflow/compiler/xla/client/lib/numeric.cc | 4 +- tensorflow/compiler/xla/client/lib/pooling.cc | 56 +- tensorflow/compiler/xla/client/lib/pooling.h | 29 +- .../compiler/xla/client/lib/pooling_test.cc | 6 +- .../compiler/xla/client/local_client.cc | 10 +- tensorflow/compiler/xla/client/local_client.h | 13 +- tensorflow/compiler/xla/client/padding.cc | 13 +- tensorflow/compiler/xla/client/padding.h | 13 +- tensorflow/compiler/xla/client/xla_builder.cc | 401 ++++++------ tensorflow/compiler/xla/client/xla_builder.h | 572 ++++++++---------- tensorflow/compiler/xla/index_util.cc | 15 +- tensorflow/compiler/xla/index_util.h | 12 +- tensorflow/compiler/xla/layout_util.cc | 12 +- tensorflow/compiler/xla/layout_util.h | 13 +- tensorflow/compiler/xla/layout_util_test.cc | 6 +- tensorflow/compiler/xla/literal.cc | 90 ++- tensorflow/compiler/xla/literal.h | 166 +++-- tensorflow/compiler/xla/literal_comparison.cc | 33 +- tensorflow/compiler/xla/literal_test.cc | 23 +- tensorflow/compiler/xla/literal_util.cc | 14 +- tensorflow/compiler/xla/literal_util.h | 47 +- 
.../compiler/xla/packed_literal_reader.cc | 2 +- .../xla/python/local_computation_builder.cc | 82 ++- .../xla/python/local_computation_builder.h | 64 +- .../xla/python/local_computation_builder.i | 20 +- tensorflow/compiler/xla/reference_util.cc | 90 ++- tensorflow/compiler/xla/reference_util.h | 58 +- .../xla/service/algebraic_simplifier.cc | 10 +- .../xla/service/algebraic_simplifier_test.cc | 6 +- tensorflow/compiler/xla/service/backend.cc | 10 +- tensorflow/compiler/xla/service/backend.h | 2 +- .../xla/service/bfloat16_normalization.cc | 6 +- .../xla/service/bfloat16_propagation.cc | 2 +- .../xla/service/buffer_assignment_test.cc | 2 +- .../xla/service/compile_only_service.cc | 2 +- .../xla/service/compile_only_service.h | 4 +- .../compiler/xla/service/copy_insertion.cc | 2 +- .../xla/service/cpu/buffer_info_util.cc | 2 +- .../xla/service/cpu/buffer_info_util.h | 2 +- .../xla/service/cpu/cpu_executable.cc | 14 +- .../compiler/xla/service/cpu/cpu_executable.h | 17 +- .../xla/service/cpu/cpu_transfer_manager.cc | 7 +- .../xla/service/cpu/cpu_transfer_manager.h | 5 +- .../xla/service/cpu/dot_op_emitter.cc | 2 +- .../compiler/xla/service/cpu/ir_emitter.cc | 24 +- .../compiler/xla/service/cpu/ir_emitter.h | 32 +- .../compiler/xla/service/cpu/ir_function.cc | 8 +- .../compiler/xla/service/cpu/ir_function.h | 8 +- .../xla/service/cpu/vector_support_library.cc | 2 +- .../xla/service/cpu/vector_support_library.h | 2 +- .../compiler/xla/service/cpu/xfeed_manager.cc | 2 +- .../compiler/xla/service/cpu/xfeed_manager.h | 3 +- .../xla/service/device_memory_allocator.cc | 2 +- .../xla/service/device_memory_allocator.h | 2 +- .../xla/service/elemental_ir_emitter.cc | 6 +- .../xla/service/elemental_ir_emitter_test.cc | 3 +- tensorflow/compiler/xla/service/executable.cc | 7 +- tensorflow/compiler/xla/service/executable.h | 13 +- .../compiler/xla/service/gather_expander.cc | 5 +- .../xla/service/generic_transfer_manager.cc | 5 +- .../xla/service/generic_transfer_manager.h | 6 +- .../gpu/cudnn_convolution_algorithm_picker.cc | 5 +- .../xla/service/gpu/elemental_ir_emitter.cc | 26 +- .../xla/service/gpu/elemental_ir_emitter.h | 30 +- .../compiler/xla/service/gpu/fft_thunk.cc | 3 +- .../compiler/xla/service/gpu/fft_thunk.h | 2 +- .../xla/service/gpu/gpu_executable.cc | 4 +- .../compiler/xla/service/gpu/gpu_executable.h | 4 +- .../xla/service/gpu/hlo_to_ir_bindings.cc | 4 +- .../xla/service/gpu/hlo_to_ir_bindings.h | 4 +- .../xla/service/gpu/ir_emission_utils.cc | 2 +- .../xla/service/gpu/ir_emission_utils.h | 2 +- .../compiler/xla/service/gpu/ir_emitter.cc | 6 +- .../compiler/xla/service/gpu/ir_emitter.h | 8 +- .../xla/service/gpu/ir_emitter_unnested.cc | 68 +-- .../xla/service/gpu/ir_emitter_unnested.h | 62 +- .../compiler/xla/service/gpu/kernel_thunk.cc | 8 +- .../compiler/xla/service/gpu/kernel_thunk.h | 2 +- .../xla/service/gpu/pad_for_tensor_cores.cc | 3 +- .../xla/service/gpu/parallel_loop_emitter.cc | 2 +- .../xla/service/gpu/parallel_loop_emitter.h | 9 +- .../compiler/xla/service/gpu/tuple_thunk.h | 3 +- tensorflow/compiler/xla/service/hlo_buffer.h | 2 +- .../compiler/xla/service/hlo_computation.cc | 4 +- .../compiler/xla/service/hlo_computation.h | 4 +- .../xla/service/hlo_constant_folding_test.cc | 6 +- .../xla/service/hlo_creation_utils.cc | 42 +- .../compiler/xla/service/hlo_creation_utils.h | 39 +- .../xla/service/hlo_creation_utils_test.cc | 10 +- .../xla/service/hlo_dataflow_analysis.cc | 5 +- .../xla/service/hlo_dataflow_analysis.h | 2 +- .../compiler/xla/service/hlo_evaluator.cc | 86 +-- 
.../compiler/xla/service/hlo_evaluator.h | 12 +- .../xla/service/hlo_evaluator_test.cc | 4 +- .../xla/service/hlo_evaluator_typed_visitor.h | 100 ++- .../compiler/xla/service/hlo_instruction.cc | 60 +- .../compiler/xla/service/hlo_instruction.h | 60 +- .../compiler/xla/service/hlo_instructions.cc | 178 +++--- .../compiler/xla/service/hlo_instructions.h | 205 +++---- tensorflow/compiler/xla/service/hlo_module.cc | 2 +- tensorflow/compiler/xla/service/hlo_module.h | 2 +- .../xla/service/hlo_module_group_util.cc | 6 +- .../xla/service/hlo_module_group_util.h | 7 +- .../compiler/xla/service/hlo_module_test.cc | 2 +- tensorflow/compiler/xla/service/hlo_parser.cc | 6 +- .../compiler/xla/service/hlo_reachability.cc | 8 +- .../compiler/xla/service/hlo_reachability.h | 11 +- .../xla/service/hlo_rematerialization.cc | 4 +- tensorflow/compiler/xla/service/hlo_runner.cc | 21 +- tensorflow/compiler/xla/service/hlo_runner.h | 12 +- .../xla/service/hlo_scheduling_test.cc | 2 +- .../compiler/xla/service/hlo_sharding.cc | 12 +- .../compiler/xla/service/hlo_sharding.h | 4 +- .../xla/service/hlo_sharding_metadata.cc | 2 +- .../compiler/xla/service/hlo_sharding_test.cc | 4 +- tensorflow/compiler/xla/service/hlo_value.cc | 7 +- tensorflow/compiler/xla/service/hlo_value.h | 10 +- .../compiler/xla/service/hlo_verifier.cc | 3 +- .../xla/service/indexed_array_analysis.cc | 36 +- .../xla/service/indexed_array_analysis.h | 12 +- .../xla/service/instruction_fusion.cc | 2 +- .../compiler/xla/service/instruction_fusion.h | 2 +- .../xla/service/interpreter/executable.cc | 4 +- .../xla/service/interpreter/executable.h | 4 +- .../xla/service/interpreter/executor.h | 2 +- .../llvm_ir/dynamic_update_slice_util.cc | 16 +- .../llvm_ir/dynamic_update_slice_util.h | 13 +- .../xla/service/llvm_ir/fused_ir_emitter.cc | 2 +- .../xla/service/llvm_ir/fused_ir_emitter.h | 4 +- .../compiler/xla/service/llvm_ir/ir_array.cc | 20 +- .../compiler/xla/service/llvm_ir/ir_array.h | 31 +- .../service/llvm_ir/kernel_support_library.h | 2 +- .../xla/service/llvm_ir/kernel_tiling.cc | 5 +- .../xla/service/llvm_ir/kernel_tiling.h | 4 +- .../compiler/xla/service/llvm_ir/llvm_loop.cc | 2 +- .../compiler/xla/service/llvm_ir/llvm_loop.h | 2 +- .../compiler/xla/service/llvm_ir/llvm_util.cc | 9 +- .../compiler/xla/service/llvm_ir/llvm_util.h | 11 +- .../xla/service/llvm_ir/loop_emitter.cc | 2 +- .../xla/service/llvm_ir/loop_emitter.h | 3 +- .../compiler/xla/service/llvm_ir/tuple_ops.cc | 3 +- .../compiler/xla/service/llvm_ir/tuple_ops.h | 3 +- .../compiler/xla/service/local_service.cc | 2 +- .../compiler/xla/service/local_service.h | 2 +- .../xla/service/multi_output_fusion.cc | 2 +- .../xla/service/multi_output_fusion.h | 2 +- .../compiler/xla/service/scatter_expander.cc | 3 +- tensorflow/compiler/xla/service/service.cc | 33 +- tensorflow/compiler/xla/service/service.h | 24 +- .../compiler/xla/service/shape_inference.cc | 89 ++- .../compiler/xla/service/shape_inference.h | 64 +- .../xla/service/shape_inference_test.cc | 9 +- .../compiler/xla/service/transfer_manager.h | 5 +- .../xla/service/tuple_points_to_analysis.cc | 2 +- .../service/tuple_points_to_analysis_test.cc | 12 +- tensorflow/compiler/xla/service/tuple_util.cc | 2 +- tensorflow/compiler/xla/service/tuple_util.h | 2 +- .../xla/service/while_loop_analysis.cc | 3 +- tensorflow/compiler/xla/service/while_util.cc | 2 +- tensorflow/compiler/xla/service/while_util.h | 2 +- tensorflow/compiler/xla/shape_util.cc | 31 +- tensorflow/compiler/xla/shape_util.h | 72 ++- 
tensorflow/compiler/xla/shape_util_test.cc | 14 +- tensorflow/compiler/xla/sparse_index_array.cc | 19 +- tensorflow/compiler/xla/sparse_index_array.h | 19 +- .../xla/tests/array_elementwise_ops_test.cc | 6 +- .../xla/tests/broadcast_simple_test.cc | 20 +- .../xla/tests/client_library_test_base.cc | 49 +- .../xla/tests/client_library_test_base.h | 114 ++-- .../xla/tests/compilation_cache_test.cc | 9 +- tensorflow/compiler/xla/tests/copy_test.cc | 8 +- .../compiler/xla/tests/deallocation_test.cc | 2 +- .../xla/tests/deconstruct_tuple_test.cc | 2 +- .../compiler/xla/tests/dynamic_ops_test.cc | 14 +- .../compiler/xla/tests/floor_ceil_test.cc | 4 +- tensorflow/compiler/xla/tests/fusion_test.cc | 9 +- .../xla/tests/gather_operation_test.cc | 3 +- tensorflow/compiler/xla/tests/half_test.cc | 5 +- .../compiler/xla/tests/hlo_test_base.cc | 19 +- tensorflow/compiler/xla/tests/hlo_test_base.h | 15 +- .../compiler/xla/tests/literal_test_util.h | 8 +- .../xla/tests/local_client_test_base.cc | 8 +- .../xla/tests/local_client_test_base.h | 8 +- .../xla/tests/multioutput_fusion_test.cc | 1 - tensorflow/compiler/xla/tests/pred_test.cc | 2 +- tensorflow/compiler/xla/tests/prng_test.cc | 16 +- tensorflow/compiler/xla/tests/reduce_test.cc | 15 +- .../compiler/xla/tests/reduce_window_test.cc | 16 +- tensorflow/compiler/xla/tests/reshape_test.cc | 44 +- tensorflow/compiler/xla/tests/reverse_test.cc | 25 +- .../tests/round_trip_packed_literal_test.cc | 9 +- .../xla/tests/scalar_computations_test.cc | 10 +- tensorflow/compiler/xla/tests/scatter_test.cc | 3 +- .../xla/tests/select_and_scatter_test.cc | 4 +- tensorflow/compiler/xla/tests/slice_test.cc | 2 +- tensorflow/compiler/xla/tests/test_utils.cc | 10 +- tensorflow/compiler/xla/tests/tuple_test.cc | 2 +- .../xla/tests/xla_hlo_profile_test.cc | 2 +- .../compiler/xla/text_literal_writer.cc | 3 +- .../tools/dumped_computation_to_graphviz.cc | 4 +- .../dumped_computation_to_operation_list.cc | 4 +- .../xla/tools/dumped_computation_to_text.cc | 4 +- .../dumped_computation_to_tf_graphdef.cc | 4 +- .../compiler/xla/tools/replay_computation.cc | 4 +- .../compiler/xla/tools/show_signature.cc | 4 +- tensorflow/compiler/xla/util.cc | 19 +- tensorflow/compiler/xla/util.h | 105 ++-- tensorflow/compiler/xla/window_util.cc | 4 +- tensorflow/compiler/xla/window_util.h | 4 +- 254 files changed, 2330 insertions(+), 2712 deletions(-) diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc index 1401aae7586..f1e8e5c0848 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc @@ -111,7 +111,7 @@ GetTargetMachineFromTriple(StringPiece target_triple) { StatusOr CreateEmbeddedProtocolBuffers( StringPiece target_triple, - gtl::ArraySlice protobufs_to_embed) { + absl::Span protobufs_to_embed) { TF_ASSIGN_OR_RETURN(std::unique_ptr target_machine, GetTargetMachineFromTriple(target_triple)); diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h index 4e194a6aba9..1dbdfa33e4a 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.h +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h @@ -84,7 +84,7 @@ struct ProtobufToEmbed { // EmbeddedProtocolBuffers instance. 
StatusOr CreateEmbeddedProtocolBuffers( StringPiece target_triple, - gtl::ArraySlice protobufs_to_embed); + absl::Span protobufs_to_embed); } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index fe28502f69d..82aa03810bc 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -108,7 +108,7 @@ class Predicate { virtual string ToString() const = 0; int64 hash() const { return hash_; } - virtual gtl::ArraySlice GetOperands() const = 0; + virtual absl::Span GetOperands() const = 0; virtual Kind kind() const = 0; virtual ~Predicate() {} @@ -129,7 +129,7 @@ class Predicate { }; int64 HashPredicateSequence(Predicate::Kind kind, - gtl::ArraySlice preds) { + absl::Span preds) { int64 hash = ::tensorflow::hash()(kind); for (Predicate* pred : preds) { hash = Hash64Combine(hash, pred->hash()); @@ -159,8 +159,10 @@ class AndPredicate : public Predicate { Kind kind() const override { return Kind::kAnd; } - gtl::ArraySlice GetOperands() const override { return operands_; } - gtl::ArraySlice operands() const { return operands_; } + absl::Span GetOperands() const override { + return operands_; + } + absl::Span operands() const { return operands_; } private: std::vector operands_; @@ -187,8 +189,10 @@ class OrPredicate : public Predicate { } Kind kind() const override { return Kind::kOr; } - gtl::ArraySlice GetOperands() const override { return operands_; } - gtl::ArraySlice operands() const { return operands_; } + absl::Span GetOperands() const override { + return operands_; + } + absl::Span operands() const { return operands_; } private: std::vector operands_; @@ -207,7 +211,9 @@ class NotPredicate : public Predicate { Kind kind() const override { return Kind::kNot; } Predicate* operand() const { return operands_[0]; } - gtl::ArraySlice GetOperands() const override { return operands_; } + absl::Span GetOperands() const override { + return operands_; + } private: std::array operands_; @@ -240,7 +246,9 @@ class AndRecurrencePredicate : public Predicate { Kind kind() const override { return Kind::kAndRecurrence; } - gtl::ArraySlice GetOperands() const override { return operands_; } + absl::Span GetOperands() const override { + return operands_; + } private: std::array operands_; @@ -264,7 +272,7 @@ class SymbolPredicate : public Predicate { } Kind kind() const override { return Kind::kSymbol; } - gtl::ArraySlice GetOperands() const override { return {}; } + absl::Span GetOperands() const override { return {}; } // If `must_be_true()` is true this SymbolPredicate represents the proposition // "tensor_id() is live and evaluates to true". @@ -313,11 +321,11 @@ template // them. class PredicateFactory { public: - Predicate* MakeAndPredicate(gtl::ArraySlice operands) { + Predicate* MakeAndPredicate(absl::Span operands) { return MakeAndOrImpl(operands, /*is_and=*/true); } - Predicate* MakeOrPredicate(gtl::ArraySlice operands) { + Predicate* MakeOrPredicate(absl::Span operands) { return MakeAndOrImpl(operands, /*is_and=*/false); } @@ -374,7 +382,7 @@ class PredicateFactory { new PredicateT(std::forward(args)...)); } - Predicate* MakeAndOrImpl(gtl::ArraySlice operands, bool is_and); + Predicate* MakeAndOrImpl(absl::Span operands, bool is_and); // Predicate instances are interned, meaning that there is only a single // instance of a Predicate object with a given content. 
This makes checking @@ -387,7 +395,7 @@ class PredicateFactory { // for the owning pointers to predicate instances. using SignatureForAndOr = - std::pair>; + std::pair>; using SignatureForNot = Predicate*; using SignatureForAndRec = std::pair; using SignatureForSymbol = std::pair; @@ -422,8 +430,8 @@ class PredicateFactory { }; // Common code to create AndPredicate or OrPredicate instances. -Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice operands, - bool is_and) { +Predicate* PredicateFactory::MakeAndOrImpl( + absl::Span operands, bool is_and) { Predicate::Kind pred_kind = is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr; gtl::FlatSet simplified_ops_set; @@ -474,7 +482,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice operands, // NB! Because we'll use a non-owning reference to simplified_ops in the // key for interned_and_or_instances_ we need to be careful to std::move() // it all the way through. - gtl::ArraySlice operands_slice = simplified_ops; + absl::Span operands_slice = simplified_ops; std::unique_ptr new_pred = is_and ? Make(std::move(simplified_ops)) : Make(std::move(simplified_ops)); @@ -496,7 +504,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { : graph_(*graph), vlog_(VLOG_IS_ON(2)) {} Status Populate(); - Status PopulateWithReversePostOrder(gtl::ArraySlice rpo); + Status PopulateWithReversePostOrder(absl::Span rpo); bool HasInputsWithMismatchingDeadness(const Node& node) override; void Print() const override; gtl::FlatMap PredicateMapAsString() const; @@ -527,7 +535,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { } } - void SetPredicate(Node* n, gtl::ArraySlice output_idxs, Predicate* pred, + void SetPredicate(Node* n, absl::Span output_idxs, Predicate* pred, std::vector* should_revisit) { for (int output_idx : output_idxs) { SetPredicate(n, output_idx, pred, should_revisit); @@ -625,7 +633,7 @@ Predicate* DeduceStepPredicate(PredicateFactory* predicate_factory, } std::vector and_ops; - gtl::ArraySlice recurrent_pred_ops = + absl::Span recurrent_pred_ops = backedge_predicate->GetOperands(); bool found_sym = false; @@ -784,7 +792,7 @@ Status DeadnessAnalysisImpl::Populate() { } Status DeadnessAnalysisImpl::PopulateWithReversePostOrder( - gtl::ArraySlice rpo) { + absl::Span rpo) { // This an abstract interpretation over the deadness propagation semantics of // the graph executor. // @@ -924,7 +932,7 @@ Status ComputePredicates(const Graph& graph, } Status ComputePredicates(const Graph& graph, - gtl::ArraySlice reverse_post_order, + absl::Span reverse_post_order, PredicateMapTy* out_predicate_map) { DeadnessAnalysisImpl impl(&graph); TF_RETURN_IF_ERROR(impl.PopulateWithReversePostOrder(reverse_post_order)); diff --git a/tensorflow/compiler/jit/deadness_analysis_internal.h b/tensorflow/compiler/jit/deadness_analysis_internal.h index 401d6e406ab..3df2679c629 100644 --- a/tensorflow/compiler/jit/deadness_analysis_internal.h +++ b/tensorflow/compiler/jit/deadness_analysis_internal.h @@ -32,7 +32,7 @@ Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map); // specified in `reverse_post_order` which must be a valid RPO for the graph // minus NextIteration->Merge edges. 
Status ComputePredicates(const Graph& graph, - gtl::ArraySlice reverse_post_order, + absl::Span reverse_post_order, PredicateMapTy* out_predicate_map); } // namespace deadness_analysis_internal } // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index b3600fc48b9..7bc0ef03030 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -379,7 +379,7 @@ Node* InputShaped(const GraphDefBuilder::Options& opts) { return ops::SourceOp("InputTestShaped", opts); } -Node* KnownShapeBase(DataType dtype, const gtl::ArraySlice& shape, +Node* KnownShapeBase(DataType dtype, absl::Span shape, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; NodeBuilder node_builder(opts.GetNameForOp("Const"), "Const", @@ -394,7 +394,7 @@ Node* KnownShapeBase(DataType dtype, const gtl::ArraySlice& shape, .FinalizeBuilder(&node_builder); } -Node* KnownShape(const gtl::ArraySlice& shape, +Node* KnownShape(absl::Span shape, const GraphDefBuilder::Options& opts) { return KnownShapeBase(DT_FLOAT, shape, opts); } @@ -417,8 +417,7 @@ Node* KeyPlaceholder(const string& call_node, } Node* RecvAtHost(ops::NodeOut key_input, const string& cluster, - const string& oc_cluster, - const gtl::ArraySlice& dtypes, + const string& oc_cluster, absl::Span dtypes, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; string key = @@ -892,13 +891,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"C:o:0", "c:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"c"}}, }, @@ -1038,26 +1037,26 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", {"F:o:0", "D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", - gtl::ArraySlice({"outside_compilation_O1_host_compute"})}, + absl::Span({"outside_compilation_O1_host_compute"})}, {"key", "host_compute_channel_F1_O2"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O2"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}}, {"F", "outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"C:o:0", "D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, @@ -1190,13 +1189,13 @@ 
TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"C:o:0", "D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, @@ -1213,13 +1212,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"G:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F2_O1"}, {"shape_inference_graph", ""}, {"shapes", - gtl::ArraySlice({shape_proto_expected})}, + absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}}, }, {{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}}); @@ -1364,13 +1363,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"C:o:0", "D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, @@ -1386,13 +1385,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"G:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F2_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F2_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, {{"i_0_retval", "I:o:0"}}); @@ -1495,13 +1494,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {}, - {{"Tinputs", gtl::ArraySlice({})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, {"shapes", - gtl::ArraySlice({shape_proto_expected})}, + absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1579,13 +1578,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {}, - {{"Tinputs", gtl::ArraySlice({})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({})}, + {"Toutputs", 
absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, {"shapes", - gtl::ArraySlice({shape_proto_expected})}, + absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, @@ -1661,12 +1660,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1742,12 +1741,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1846,13 +1845,13 @@ TEST(EncapsulateSubgraphsTest, {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", {"F:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O2"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O2"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}}}, }, {{"h_0_retval", "H:o:0"}}); @@ -1955,13 +1954,13 @@ TEST(EncapsulateSubgraphsTest, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, {{"h_0_retval", "H:o:0"}}); @@ -2066,37 +2065,37 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", {"D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({})}, 
{"ancestors", - gtl::ArraySlice({"outside_compilation_O1_host_compute"})}, + absl::Span({"outside_compilation_O1_host_compute"})}, {"key", "host_compute_channel_F1_O2"}, {"shape_inference_graph", ""}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}}, {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O3_host_compute"}, "XlaHostCompute", {"D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({})}, {"ancestors", - gtl::ArraySlice({"outside_compilation_O1_host_compute", - "outside_compilation_O2_host_compute"})}, + absl::Span({"outside_compilation_O1_host_compute", + "outside_compilation_O2_host_compute"})}, {"key", "host_compute_channel_F1_O3"}, {"shape_inference_graph", ""}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O3"}}, {"outside_compilation_O1_host_compute", "outside_compilation_O2_host_compute"}}}, @@ -2272,13 +2271,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"c:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"c"}}, }, diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index 3a9a8c4988a..a8f09bfa503 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -22,7 +22,7 @@ limitations under the License. namespace tensorflow { namespace { Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet* result, - gtl::ArraySlice post_order) { + absl::Span post_order) { // Find nodes that have at least one user outside their cluster that expects // hostmem output. These nodes should be cloned to outside the cluster to // avoid the device-host copy we'd otherwise need. diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index 1b8198dba8e..0faf0fd8edf 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -275,13 +275,13 @@ class OpTest : public ::testing::Test { // Select a random element from 'candidates'. template - T Choose(gtl::ArraySlice candidates); + T Choose(absl::Span candidates); static constexpr int kDefaultMaxRank = 5; static constexpr int64 kDefaultMaxDimensionSize = 256LL; // Returns true if 'dims' have a size less than tf_xla_max_tensor_size. - bool TensorSizeIsOk(gtl::ArraySlice dims); + bool TensorSizeIsOk(absl::Span dims); // Returns a random dimension size, in the range [min, max). int64 RandomDim(int64 min = 0, int64 max = kDefaultMaxDimensionSize); @@ -307,11 +307,11 @@ class OpTest : public ::testing::Test { // of the type's range. If the shape is omitted, a random shape is used. // TODO(phawkins): generalize this code to a caller-supplied distribution. 
Tensor RandomTensor(DataType dtype, bool needs_unique_values, - gtl::ArraySlice shape); + absl::Span shape); Tensor RandomTensor(DataType dtype); // Like RandomTensor, but uses values >= 0. - Tensor RandomNonNegativeTensor(DataType dtype, gtl::ArraySlice shape); + Tensor RandomNonNegativeTensor(DataType dtype, absl::Span shape); Tensor RandomNonNegativeTensor(DataType dtype); // Returns a random subset of the integers in the range [0, rank), suitable @@ -415,7 +415,7 @@ void OpTest::Repeatedly(const std::function& fn) { } template -T OpTest::Choose(gtl::ArraySlice candidates) { +T OpTest::Choose(absl::Span candidates) { std::uniform_int_distribution d(0, candidates.size() - 1); return candidates[d(generator())]; } @@ -425,7 +425,7 @@ int64 OpTest::RandomDim(int64 min, int64 max) { return size_distribution(generator()); } -bool OpTest::TensorSizeIsOk(gtl::ArraySlice dims) { +bool OpTest::TensorSizeIsOk(absl::Span dims) { int64 size = 1LL; for (int64 dim : dims) { size *= dim; @@ -451,7 +451,7 @@ std::vector OpTest::RandomDims(int min_rank, int max_rank, } Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, - gtl::ArraySlice shape) { + absl::Span shape) { Tensor tensor(dtype, TensorShape(shape)); switch (dtype) { case DT_FLOAT: { @@ -548,7 +548,7 @@ Tensor OpTest::RandomTensor(DataType dtype) { } Tensor OpTest::RandomNonNegativeTensor(DataType dtype, - gtl::ArraySlice shape) { + absl::Span shape) { Tensor tensor(dtype, TensorShape(shape)); switch (dtype) { case DT_FLOAT: { @@ -1884,7 +1884,7 @@ TEST_F(OpTest, DynamicStitch) { for (int i = 0; i < n; ++i) { TensorShape shape(index_dims[i]); Tensor t = test::AsTensor( - gtl::ArraySlice(indices).subspan(pos, shape.num_elements()), + absl::Span(indices).subspan(pos, shape.num_elements()), shape); builder.Input(t); pos += t.NumElements(); diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index cc52057f214..c068a4110c0 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -805,11 +805,11 @@ TEST(FunctionalizeControlFlow, Complex) { auto assign = ops::AssignAddVariableOp( scope.WithOpName("outer/inner/assign_add"), enter_var, add_jkx); - auto one = - ops::Const(scope.WithOpName("outer/inner/One") - .WithControlDependencies( - gtl::ArraySlice{assign.operation}), - 1); + auto one = ops::Const( + scope.WithOpName("outer/inner/One") + .WithControlDependencies( + absl::Span{assign.operation}), + 1); auto add_j = ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); @@ -823,7 +823,7 @@ TEST(FunctionalizeControlFlow, Complex) { scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1); auto add_i = ops::Add(scope.WithOpName("outer/add") - .WithControlDependencies(gtl::ArraySlice{ + .WithControlDependencies(absl::Span{ exit_j.output.op(), exit_k.output.op()}), identity_i, one_outer); auto next_iteration_i = @@ -929,7 +929,7 @@ TEST(FunctionalizeControlFlow, Complex) { scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1); auto add_i = ops::Add(scope.WithOpName("outer/add") - .WithControlDependencies(gtl::ArraySlice{ + .WithControlDependencies(absl::Span{ while_op[0].op(), while_op[1].op()}), identity_i, one_outer); @@ -991,11 +991,11 @@ TEST(FunctionalizeControlFlow, Complex) { auto assign = ops::AssignAddVariableOp( scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx); - auto one = - 
ops::Const(scope.WithOpName("outer/inner/One") - .WithControlDependencies( - gtl::ArraySlice{assign.operation}), - 1); + auto one = ops::Const( + scope.WithOpName("outer/inner/One") + .WithControlDependencies( + absl::Span{assign.operation}), + 1); auto add_j = ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 48f2a005ab1..edced6bc0e5 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -23,7 +23,7 @@ namespace { void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, DataType input_dtype, const TensorShape& input_tensor_shape, - gtl::ArraySlice block_shape, + absl::Span block_shape, const xla::Literal& crops) { const int input_rank = input_tensor_shape.dims(); const gtl::InlinedVector input_shape = @@ -34,7 +34,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, ctx, input_rank >= 1 + block_rank, errors::InvalidArgument("input rank should be >= ", 1 + block_rank, " instead of ", input_rank)); - gtl::ArraySlice remainder_shape(input_shape); + absl::Span remainder_shape(input_shape); remainder_shape.remove_prefix(1 + block_rank); OP_REQUIRES( diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 2c328102e0b..df17da4c1ca 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -30,21 +30,21 @@ namespace { // A subclass of a XlaBinaryOp must build the computation that // describes the (tensor,tensor)->tensor function to apply to each element of // the input. -#define XLA_MAKE_BINARY(NAME, HLO) \ - class NAME##Op : public XlaBinaryOp { \ - public: \ - explicit NAME##Op(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {} \ - xla::XlaOp Computation( \ - XlaOpKernelContext* ctx, const xla::XlaOp& lhs, \ - const gtl::ArraySlice& lhs_shape, const xla::XlaOp& rhs, \ - const gtl::ArraySlice& rhs_shape, \ - const BCast& broadcast_helper, \ - const std::vector& extend_dimensions) override { \ - xla::XlaBuilder* b = ctx->builder(); \ - (void)b; \ - return HLO; \ - } \ - }; \ +#define XLA_MAKE_BINARY(NAME, HLO) \ + class NAME##Op : public XlaBinaryOp { \ + public: \ + explicit NAME##Op(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {} \ + xla::XlaOp Computation( \ + XlaOpKernelContext* ctx, const xla::XlaOp& lhs, \ + const absl::Span& lhs_shape, const xla::XlaOp& rhs, \ + const absl::Span& rhs_shape, \ + const BCast& broadcast_helper, \ + const std::vector& extend_dimensions) override { \ + xla::XlaBuilder* b = ctx->builder(); \ + (void)b; \ + return HLO; \ + } \ + }; \ REGISTER_XLA_OP(Name(#NAME), NAME##Op) XLA_MAKE_BINARY(Add, xla::Add(lhs, rhs, extend_dimensions)); diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h index a5b870f8dbf..6653944a911 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h @@ -57,8 +57,8 @@ class XlaBinaryOp : public XlaOpKernel { // in the XLA documentation. 
virtual xla::XlaOp Computation( XlaOpKernelContext* ctx, const xla::XlaOp& lhs, - const gtl::ArraySlice& lhs_shape, const xla::XlaOp& rhs, - const gtl::ArraySlice& rhs_shape, const BCast& broadcast_helper, + const absl::Span& lhs_shape, const xla::XlaOp& rhs, + const absl::Span& rhs_shape, const BCast& broadcast_helper, const std::vector& extend_dimensions) = 0; void Compile(XlaOpKernelContext* ctx) override; diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index 70c3eaf66bb..49c12fc2320 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -29,7 +29,7 @@ namespace { // Create a diagonal / batch diagonal matrix with 'input' on the diagonal. xla::XlaOp CreateDiagonal(xla::XlaOp input, int64 last_dim_size, - gtl::ArraySlice other_dims, + absl::Span other_dims, xla::PrimitiveType element_type) { xla::XlaBuilder* builder = input.builder(); // Create two matrices that have the following forms, and compare them: @@ -177,7 +177,7 @@ class MatrixDiagOp : public XlaOpKernel { int last_dim = dims.size() - 1; int64 last_dim_size = input_shape.dim_size(last_dim); - tensorflow::gtl::ArraySlice other_dims(dims); + absl::Span other_dims(dims); other_dims.remove_suffix(1); xla::XlaOp input = ctx->Input(0); diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index 8e071bf0b7a..d9a0257b70b 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -78,7 +78,7 @@ struct ResizeConvolutionDims { std::vector stride; }; ResizeConvolutionDims ComputeResizeConvolutionParameters( - gtl::ArraySlice in_size, gtl::ArraySlice out_size, + absl::Span in_size, absl::Span out_size, bool align_corners) { CHECK_EQ(in_size.size(), out_size.size()); int num_spatial_dims = in_size.size(); @@ -147,7 +147,7 @@ std::vector Make1DKernel(int64 n) { const int64 kMax2DKernelSize = 16; xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, - gtl::ArraySlice kernel_size, + absl::Span kernel_size, int64 channels) { xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels); @@ -165,7 +165,7 @@ xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, } xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder, - gtl::ArraySlice kernel_size, + absl::Span kernel_size, int64 channels, int64 dim) { xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels); diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 2da9340625d..afd59868467 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -155,7 +155,8 @@ class RandomShuffleOp : public XlaOpKernel { xla::XlaOp indices = xla::Iota(builder, xla::S32, n); // Swap the indices at i and swaps[i]. - auto swap_body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, + auto swap_body_fn = [&](xla::XlaOp i, + absl::Span loop_vars, xla::XlaBuilder* builder) -> xla::StatusOr> { auto swaps = loop_vars[0]; diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index d9578eca5bf..9e4c57c9bf7 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -66,7 +66,7 @@ class SelectOp : public XlaOpKernel { // XLA. 
It seems we have to broadcast on the left and then Reshape // to get the dimensions in the right order. const auto dim_sizes = then_shape.dim_sizes(); - gtl::ArraySlice bdims = dim_sizes; + absl::Span bdims = dim_sizes; bdims.remove_prefix(1); cond_handle = xla::Broadcast(cond_handle, bdims); diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index 7327258c31f..b7b4f3a5465 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -23,7 +23,7 @@ namespace { void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, DataType input_dtype, const TensorShape& input_tensor_shape, - gtl::ArraySlice block_shape, + absl::Span block_shape, const xla::Literal& paddings) { const int input_rank = input_tensor_shape.dims(); const gtl::InlinedVector input_shape = @@ -34,7 +34,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, ctx, input_rank >= 1 + block_rank, errors::InvalidArgument("input rank should be >= ", 1 + block_rank, " instead of ", input_rank)); - gtl::ArraySlice remainder_shape(input_shape); + absl::Span remainder_shape(input_shape); remainder_shape.remove_prefix(1 + block_rank); OP_REQUIRES( diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index be1814d8e3a..bb114d1aedd 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -122,7 +122,7 @@ Status GetTensorArrayShape(const XlaResource* resource, // relevant slice of 'operand'. xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand, const xla::XlaOp& update, - const gtl::ArraySlice& update_dims, + absl::Span update_dims, const xla::XlaOp& start_indices) { xla::XlaOp current = xla::DynamicSlice(operand, start_indices, update_dims); xla::XlaOp sum = xla::Add(current, update); diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index ff3de75ad2e..c50a8de33e9 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -64,7 +64,7 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, xla::XlaOp l = xla::ZerosLike(a); // Construct the for loop body to iterate over rows. - auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, + auto body_fn = [&](xla::XlaOp i, absl::Span loop_vars, xla::XlaBuilder* body_builder) -> xla::StatusOr> { xla::Shape col_shape; diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc index df2504a0f97..0a140fa93ca 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.cc +++ b/tensorflow/compiler/tf2xla/lib/qr.cc @@ -65,9 +65,9 @@ namespace { // return (v, tau, beta) // TODO(phawkins): LAPACK's xLARFG implementation has code for handling // overflows in the norm/beta calculations. Perhaps do the same here. 
-xla::Status House(xla::XlaOp x, xla::XlaOp k, gtl::ArraySlice batch_dims, - const int64 m, xla::XlaOp* v, xla::XlaOp* tau, - xla::XlaOp* beta) { +xla::Status House(xla::XlaOp x, xla::XlaOp k, + absl::Span batch_dims, const int64 m, + xla::XlaOp* v, xla::XlaOp* tau, xla::XlaOp* beta) { xla::XlaBuilder* const builder = x.builder(); TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); const xla::PrimitiveType type = x_shape.element_type(); @@ -173,7 +173,7 @@ xla::StatusOr QRBlock( std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); auto qr_body_fn = - [&](xla::XlaOp j, gtl::ArraySlice values, + [&](xla::XlaOp j, absl::Span values, xla::XlaBuilder* builder) -> xla::StatusOr> { auto a = values[0]; auto vs = values[1]; @@ -255,7 +255,7 @@ xla::StatusOr QRBlock( // There is no need to return Y since at termination of the loop it is equal to // vs. xla::StatusOr ComputeWYRepresentation( - xla::PrimitiveType type, gtl::ArraySlice batch_dims, xla::XlaOp vs, + xla::PrimitiveType type, absl::Span batch_dims, xla::XlaOp vs, xla::XlaOp taus, int64 m, int64 n, xla::PrecisionConfigProto::Precision precision) { std::vector batch_dim_indices(batch_dims.size()); @@ -263,7 +263,7 @@ xla::StatusOr ComputeWYRepresentation( int64 n_index = batch_dims.size() + 1; auto body_fn = - [&](xla::XlaOp j, gtl::ArraySlice values, + [&](xla::XlaOp j, absl::Span values, xla::XlaBuilder* builder) -> xla::StatusOr> { auto w = values[0]; auto y = values[1]; diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index bafe5099f2d..89df6683d9f 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -40,9 +40,9 @@ xla::StatusOr XlaScatter( TF_ASSIGN_OR_RETURN(xla::Shape buffer_shape, builder->GetShape(buffer)); TF_RETURN_IF_ERROR(builder->GetShape(updates).status()); TF_ASSIGN_OR_RETURN(xla::Shape indices_shape, builder->GetShape(indices)); - gtl::ArraySlice indices_dims = + absl::Span indices_dims = xla::AsInt64Slice(indices_shape.dimensions()); - gtl::ArraySlice buffer_dims = + absl::Span buffer_dims = xla::AsInt64Slice(buffer_shape.dimensions()); // If the indices are N-dimensional, the minor dimension of indices contains @@ -107,7 +107,7 @@ xla::StatusOr XlaScatter( // index = dynamic-slice(indices, i) // update = dynamic-slice(updates, i) // buffer = dynamic-update-slice(buffer, update, index) - auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, + auto body_fn = [&](xla::XlaOp i, absl::Span loop_vars, xla::XlaBuilder* body_builder) { auto indices = loop_vars[0]; auto updates = loop_vars[1]; diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index 24e5dbbc6d1..c2678485247 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -113,8 +113,8 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, return xla::ConstantLiteral(builder, literal); } -xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice start, - gtl::ArraySlice end) { +xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span start, + absl::Span end) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_RET_CHECK(start.size() == end.size()); @@ -144,8 +144,8 @@ xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice start, }); } -std::vector ConcatVectors(gtl::ArraySlice xs, - gtl::ArraySlice ys) { +std::vector ConcatVectors(absl::Span xs, + absl::Span ys) { std::vector 
output(xs.size() + ys.size()); std::copy(xs.begin(), xs.end(), output.begin()); std::copy(ys.begin(), ys.end(), output.begin() + xs.size()); @@ -153,8 +153,8 @@ std::vector ConcatVectors(gtl::ArraySlice xs, } xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, - gtl::ArraySlice starts, - gtl::ArraySlice sizes) { + absl::Span starts, + absl::Span sizes) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); @@ -173,7 +173,7 @@ xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, } xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, - gtl::ArraySlice start) { + absl::Span start) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { // TODO(phawkins): make int64 work on all backends, remove the int32 cast. @@ -191,7 +191,7 @@ xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, } xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - gtl::ArraySlice start) { + absl::Span start) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); @@ -206,13 +206,13 @@ xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, } xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - gtl::ArraySlice starts) { + absl::Span starts) { auto padded_starts = PrependZerosInMajorDims(x, starts); return xla::DynamicUpdateSlice(x, update, padded_starts); } xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, - gtl::ArraySlice starts) { + absl::Span starts) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index b4905c95282..bb8ab8b407c 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -31,7 +31,7 @@ xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, // Makes a 1D tensor [0, ..., x, y] from two tensors x and y with zeros // prepended until the array is length n_dims. xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, - gtl::ArraySlice starts); + absl::Span starts); // Returns a integer scalar constant of 'type' with 'value'. // If 'type' is complex, returns a real value with zero imaginary component. @@ -41,33 +41,33 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, // Builds a vector of zeros of length rank(x) with the last values being // those in `starts`. xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, - gtl::ArraySlice starts); + absl::Span starts); // Performs a slice in the minor dimensions of a Tensor. -xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice start, - gtl::ArraySlice end); +xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span start, + absl::Span end); // Returns the concatenation of `xs` and `ys`. -std::vector ConcatVectors(gtl::ArraySlice xs, - gtl::ArraySlice ys); +std::vector ConcatVectors(absl::Span xs, + absl::Span ys); // Performs a dynamic slice in the minor dimensions of a Tensor. 
xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, - gtl::ArraySlice starts, - gtl::ArraySlice sizes); + absl::Span starts, + absl::Span sizes); // Updates a slice of 'x', i.e., // x[start[0], ..., start[n]] = update xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, - gtl::ArraySlice start); + absl::Span start); // Updates a slice of 'x', where 'start' contains a list of minor dimensions: // x[..., start[0], ..., start[n]] = update xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - gtl::ArraySlice start); + absl::Span start); xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - gtl::ArraySlice starts); + absl::Span starts); // Transposes a stack of matrices `x` by swapping the last two dimensions. xla::XlaOp TransposeInMinorDims(xla::XlaOp x); diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc index d64394f1401..5300e2c878b 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.cc +++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc @@ -24,7 +24,7 @@ namespace tensorflow { xla::StatusOr> XlaWhileLoop( const LoopConditionFunction& condition_function, const LoopBodyFunction& body_function, - gtl::ArraySlice initial_values, StringPiece name, + absl::Span initial_values, StringPiece name, xla::XlaBuilder* builder) { int arity = initial_values.size(); std::vector var_shapes; @@ -84,15 +84,15 @@ xla::StatusOr> XlaWhileLoop( xla::StatusOr> XlaForEachIndex( int64 num_iterations, xla::PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - gtl::ArraySlice initial_values, StringPiece name, + absl::Span initial_values, StringPiece name, xla::XlaBuilder* builder) { auto while_cond_fn = - [&](gtl::ArraySlice values, + [&](absl::Span values, xla::XlaBuilder* cond_builder) -> xla::StatusOr { return xla::Lt(values[0], IntegerLiteral(cond_builder, num_iterations_type, num_iterations)); }; - auto while_body_fn = [&](gtl::ArraySlice values, + auto while_body_fn = [&](absl::Span values, xla::XlaBuilder* body_builder) -> xla::StatusOr> { xla::XlaOp iteration = values[0]; diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/tf2xla/lib/while_loop.h index 9493b1f109b..ebaaaee6363 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.h +++ b/tensorflow/compiler/tf2xla/lib/while_loop.h @@ -29,14 +29,14 @@ namespace tensorflow { // Function that builds a loop condition. Takes as input a sequence of input // values, and returns a boolean value representing if the condition succeeds. -typedef std::function(gtl::ArraySlice, +typedef std::function(absl::Span, xla::XlaBuilder*)> LoopConditionFunction; // Function that builds a loop body. Takes as input a sequence of input values // and returns a sequence of output values. typedef std::function>( - gtl::ArraySlice, xla::XlaBuilder*)> + absl::Span, xla::XlaBuilder*)> LoopBodyFunction; // Helper function for building an XLA while loop, where the values carried by @@ -50,7 +50,7 @@ typedef std::function>( xla::StatusOr> XlaWhileLoop( const LoopConditionFunction& condition_function, const LoopBodyFunction& body_function, - gtl::ArraySlice initial_values, StringPiece name, + absl::Span initial_values, StringPiece name, xla::XlaBuilder* builder); // Builds an XLA loop that repeats a computation `num_iterations` times. @@ -59,13 +59,13 @@ xla::StatusOr> XlaWhileLoop( // (current iteration number, loop-carried values), and returns an updated // vector of the loop-carried values. 
typedef std::function>( - xla::XlaOp, gtl::ArraySlice, xla::XlaBuilder*)> + xla::XlaOp, absl::Span, xla::XlaBuilder*)> ForEachIndexBodyFunction; xla::StatusOr> XlaForEachIndex( int64 num_iterations, xla::PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - gtl::ArraySlice initial_values, StringPiece name, + absl::Span initial_values, StringPiece name, xla::XlaBuilder* builder); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 77da1bf29ce..20103ec3ae0 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -49,9 +49,8 @@ Status HostTensorToMutableBorrowingLiteral( return Status::OK(); } -Status HostTensorsToBorrowingLiteralTuple( - tensorflow::gtl::ArraySlice host_tensors, - xla::BorrowingLiteral* literal) { +Status HostTensorsToBorrowingLiteralTuple(absl::Span host_tensors, + xla::BorrowingLiteral* literal) { std::vector buf_ptrs; buf_ptrs.reserve(host_tensors.size()); std::vector tensor_shapes(host_tensors.size()); diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index 09d6fa81166..b4e317cee1b 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -43,9 +43,8 @@ Status HostTensorToMutableBorrowingLiteral( // Returns a BorrowingLiteral tuple that utilizes the same underlying buffers // owned by 'host_tensors'. -Status HostTensorsToBorrowingLiteralTuple( - tensorflow::gtl::ArraySlice host_tensors, - xla::BorrowingLiteral* literal); +Status HostTensorsToBorrowingLiteralTuple(absl::Span host_tensors, + xla::BorrowingLiteral* literal); // Copies 'literal' to freshly allocated 'host_tensor', which is allocated of // type . 
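For reference, a minimal self-contained sketch of why these signature changes are largely transparent at call sites: absl::Span<const T> is a non-owning view that binds implicitly to std::vector<T>, C arrays, and (for const element types) braced initializer lists. The patch text above elides the template arguments, so the int64_t element type below is an assumption made purely for illustration, and SumStarts is a hypothetical stand-in for the slice helpers, not part of this change.

#include <cstdint>
#include <vector>

#include "absl/types/span.h"

// Hypothetical stand-in for the helpers above; it only sums the start
// indices so the sketch stays runnable without XLA.
int64_t SumStarts(absl::Span<const int64_t> starts) {
  int64_t total = 0;
  for (int64_t s : starts) total += s;
  return total;
}

int main() {
  std::vector<int64_t> from_vector = {2, 3};
  int64_t from_array[] = {4, 5, 6};

  SumStarts(from_vector);  // a std::vector binds implicitly
  SumStarts(from_array);   // so does a C array
  SumStarts({7, 8, 9});    // and a braced list, since the element type is const
  return 0;
}

Because the span does not own its elements, callees that retain the data (as several of the functions above do) must copy out of it before returning.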
diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc index a3404c2b3df..7dc16b5a467 100644 --- a/tensorflow/compiler/tf2xla/literal_util_test.cc +++ b/tensorflow/compiler/tf2xla/literal_util_test.cc @@ -28,7 +28,7 @@ TEST(LiteralUtil, LiteralToHostTensor) { { std::vector int64_values = {1, 2, 3}; std::unique_ptr int64_values_literal = - xla::LiteralUtil::CreateR1(gtl::ArraySlice(int64_values)); + xla::LiteralUtil::CreateR1(absl::Span(int64_values)); Tensor host_tensor; EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32", LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor) @@ -49,7 +49,7 @@ TEST(LiteralUtil, LiteralToHostTensor) { Tensor host_tensor; std::vector int32_values = {10, 11}; std::unique_ptr int32_values_literal = - xla::LiteralUtil::CreateR1(gtl::ArraySlice(int32_values)); + xla::LiteralUtil::CreateR1(absl::Span(int32_values)); EXPECT_TRUE( LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor) .ok()); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index aa2a521d984..0c300c282e9 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -835,8 +835,8 @@ Status XlaCompiler::GetDeviceToHostChannelHandle(const string& key, namespace { -void SetTransfer(const string& key, gtl::ArraySlice types, - gtl::ArraySlice shapes, +void SetTransfer(const string& key, absl::Span types, + absl::Span shapes, tf2xla::HostTransferMetadata* transfer) { transfer->set_key(key); CHECK(types.size() == shapes.size()); @@ -850,8 +850,8 @@ void SetTransfer(const string& key, gtl::ArraySlice types, } // namespace Status XlaCompiler::SetDeviceToHostMetadata( - const string& key, gtl::ArraySlice types, - gtl::ArraySlice shapes) { + const string& key, absl::Span types, + absl::Span shapes) { if (host_compute_sends_.find(key) != host_compute_sends_.end()) { return errors::InvalidArgument( "Duplicate calls to SetDeviceToHostMetadata with key ", key); @@ -877,8 +877,8 @@ Status XlaCompiler::GetDeviceToHostShapes( } Status XlaCompiler::SetHostToDeviceMetadata( - const string& key, gtl::ArraySlice types, - gtl::ArraySlice shapes) { + const string& key, absl::Span types, + absl::Span shapes) { if (host_compute_recvs_.find(key) != host_compute_sends_.end()) { return errors::InvalidArgument( "Duplicate calls to SetHostToDeviceMetadata with key ", key); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 9e2c64fd421..8f4a9858ed6 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -351,8 +351,8 @@ class XlaCompiler { // Sets the shapes and types for the device to host transfer associated with // 'key'. Status SetDeviceToHostMetadata(const string& key, - gtl::ArraySlice types, - gtl::ArraySlice shapes); + absl::Span types, + absl::Span shapes); // Gets the shapes the device to host transfer associated with 'key'. Status GetDeviceToHostShapes(const string& key, @@ -361,8 +361,8 @@ class XlaCompiler { // Sets the shapes and types for the host to device transfer associated with // 'key'. 
Status SetHostToDeviceMetadata(const string& key, - gtl::ArraySlice types, - gtl::ArraySlice shapes); + absl::Span types, + absl::Span shapes); // In order to avoid deadlocks from dependencies in host computations, it can // be necessary to enforce a partial order on the execution of HostCompute diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 8efb3d55c88..ea10ca64637 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -119,7 +119,7 @@ xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type, } /* static */ Status XlaHelpers::ReshapeLiteral( - const xla::Literal& input, gtl::ArraySlice dimensions, + const xla::Literal& input, absl::Span dimensions, xla::Literal* output) { if (xla::ShapeUtil::IsTuple(input.shape())) { return errors::InvalidArgument("ReshapeLiteral does not support tuples."); diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index e6522157a53..e22b1f0f377 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -50,7 +50,7 @@ class XlaHelpers { // Reshapes literal 'input' to have 'shape'. Both the original shape and // 'shape' must contain the same number of elements. static Status ReshapeLiteral(const xla::Literal& input, - gtl::ArraySlice shape, + absl::Span shape, xla::Literal* output); // Returns the argmax of `input` along `axis`. `output_type` is the type to diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 9e8f5f2a1ad..1499c99ed15 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -119,7 +119,7 @@ Status XlaOpKernelContext::ConstantInput(StringPiece name, } Status XlaOpKernelContext::ConstantInputReshaped( - int index, gtl::ArraySlice new_dims, + int index, absl::Span new_dims, xla::Literal* constant_literal) { const Tensor& tensor = context_->input(index); TensorShape new_shape(new_dims); diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 3e26ba4f015..45cfa7da740 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -113,7 +113,7 @@ class XlaOpKernelContext { // cannot be evaluated, e.g., because it depends on unbound parameters, // returns a non-Ok status. If InputShape(index).num_elements() != // new_shape.num_elements(), returns an error status. - Status ConstantInputReshaped(int index, gtl::ArraySlice new_shape, + Status ConstantInputReshaped(int index, absl::Span new_dims, xla::Literal* constant_literal); // Converts a constant scalar int32 or int64 tensor into an int64. 
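The reshape helpers above only need to walk the dimension span to validate it: as the comment notes, the reshaped dimensions must multiply out to the input's element count. A minimal sketch of that check, assuming int64 dimensions (the element types are elided in the text above); NumElements and ShapesCompatible are illustrative helpers, not functions from this patch.

#include <cstdint>
#include <functional>
#include <numeric>

#include "absl/types/span.h"

// Product of the dimensions in `dims`; an empty span yields 1 (a scalar).
int64_t NumElements(absl::Span<const int64_t> dims) {
  return std::accumulate(dims.begin(), dims.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

// A ConstantInputReshaped-style caller would reject `new_dims` whose product
// does not match the input's element count before building the new shape.
bool ShapesCompatible(absl::Span<const int64_t> old_dims,
                      absl::Span<const int64_t> new_dims) {
  return NumElements(old_dims) == NumElements(new_dims);
}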
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 2f3a4cd3b57..dae2d956ca6 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -105,7 +105,7 @@ XlaOpRegistry::~XlaOpRegistry() = default; /* static */ void XlaOpRegistry::RegisterBackend( const string& compilation_device_name, - gtl::ArraySlice supported_types, BackendOpFilter op_filter) { + absl::Span supported_types, BackendOpFilter op_filter) { XlaOpRegistry& registry = Instance(); mutex_lock lock(registry.mutex_); auto result = registry.backends_.emplace(compilation_device_name, Backend()); @@ -382,7 +382,7 @@ XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(StringPiece name) { } XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device( - gtl::ArraySlice devices) { + absl::Span devices) { registration_->has_device_whitelist = true; for (StringPiece device : devices) { registration_->device_whitelist.emplace(device); @@ -415,7 +415,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( } XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( - StringPiece attr_name, gtl::ArraySlice allowed) { + StringPiece attr_name, absl::Span allowed) { std::set& types = registration_->type_constraints[string(attr_name)]; for (DataType t : allowed) { @@ -452,7 +452,7 @@ XlaOpRegistrar::XlaOpRegistrar( } XlaBackendRegistrar::XlaBackendRegistrar( - StringPiece name, gtl::ArraySlice types, + StringPiece name, absl::Span types, XlaOpRegistry::BackendOpFilter op_filter) { XlaOpRegistry& registry = XlaOpRegistry::Instance(); registry.RegisterBackend(string(name), types, op_filter); diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 6ce0e2580b1..c640842dc0d 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -94,7 +94,7 @@ class XlaOpRegistry { // the device; it may optionally modify the KernelDef. typedef bool (*BackendOpFilter)(KernelDef* kdef); static void RegisterBackend(const string& compilation_device_name, - gtl::ArraySlice supported_types, + absl::Span supported_types, BackendOpFilter op_filter); // Returns the names of the registered backends. @@ -236,7 +236,7 @@ class XlaOpRegistrationBuilder { // Specifies a whitelist of devices on which the operator may run. XlaOpRegistrationBuilder& Device(StringPiece devices); - XlaOpRegistrationBuilder& Device(gtl::ArraySlice devices); + XlaOpRegistrationBuilder& Device(absl::Span devices); // Specifies a type constraint for a type variable attribute. Each constraint // specifies the set of types that the type variable may assume. @@ -244,7 +244,7 @@ class XlaOpRegistrationBuilder { DataType allowed); XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name, - gtl::ArraySlice allowed); + absl::Span allowed); // Specifies that a dummy copy of this operator should not be registered on // XLA_* devices, but may be used during compilation. 
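Because a span is only a view, registration-style setters like Device() and TypeConstraint() above copy the elements into owned storage before returning. A minimal sketch of that pattern under assumed types; the whitelist container, function name, and device strings here are illustrative, not the registry's actual members.

#include <set>
#include <string>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"

// Copies every device name out of the span. The span itself may point at a
// temporary (e.g. a braced list at the call site) and must not be stored.
void AddToWhitelist(absl::Span<const absl::string_view> devices,
                    std::set<std::string>* whitelist) {
  for (absl::string_view device : devices) {
    whitelist->insert(std::string(device));
  }
}

// Usage:
//   std::set<std::string> whitelist;
//   AddToWhitelist({"XLA_CPU_JIT", "XLA_GPU_JIT"}, &whitelist);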
@@ -288,7 +288,7 @@ class XlaOpRegistrar { class XlaBackendRegistrar { public: - XlaBackendRegistrar(StringPiece name, gtl::ArraySlice types, + XlaBackendRegistrar(StringPiece name, absl::Span types, XlaOpRegistry::BackendOpFilter op_filter = nullptr); }; diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index c8e483712ef..0ec04f42a8f 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -97,12 +97,11 @@ class Array { using value_type = T; // Creates a new array with the specified dimensions. - explicit Array(tensorflow::gtl::ArraySlice sizes) - : Array(sizes, T()) {} + explicit Array(absl::Span sizes) : Array(sizes, T()) {} // Creates a new array with the specified dimensions and specified value for // every cell. - Array(tensorflow::gtl::ArraySlice sizes, T value) + Array(absl::Span sizes, T value) : sizes_(sizes.begin(), sizes.end()), values_(new T[num_elements()]) { Fill(value); } @@ -301,7 +300,7 @@ class Array { // Invokes a callback with the (indices, value_ptr) for each cell in the // array. - void Each(std::function, T*)> f) { + void Each(std::function, T*)> f) { std::vector index(sizes_.size()); for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { f(index, &values_[i]); @@ -309,8 +308,7 @@ class Array { } // Invokes a callback with the (indices, value) for each cell in the array. - void Each( - std::function, T)> f) const { + void Each(std::function, T)> f) const { std::vector index(sizes_.size()); for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { f(index, values_[i]); @@ -320,8 +318,7 @@ class Array { // Invokes a callback with the (indices, value_ptr) for each cell in the // array. If a callback returns a non-OK status, returns that else returns // Status::OK(). - Status EachStatus( - std::function, T*)> f) { + Status EachStatus(std::function, T*)> f) { std::vector index(sizes_.size()); for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { Status s = f(index, &values_[i]); @@ -335,8 +332,7 @@ class Array { // Invokes a callback with the (indices, value) for each cell in the array. // If a callback returns a non-OK status, returns that else returns // Status::OK(). - Status EachStatus( - std::function, T)> f) const { + Status EachStatus(std::function, T)> f) const { std::vector index(sizes_.size()); for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { Status s = f(index, values_[i]); @@ -377,13 +373,13 @@ class Array { // Returns the value at the cell specified by the indexes. The number of // arguments have to match with the number of dimensions for the array. - const T& operator()(tensorflow::gtl::ArraySlice indexes) const { + const T& operator()(absl::Span indexes) const { return values_[calculate_index(indexes)]; } // Returns the value at the cell specified by the indexes. The number of // arguments have to match with the number of dimensions for the array. - T& operator()(tensorflow::gtl::ArraySlice indexes) { + T& operator()(absl::Span indexes) { return values_[calculate_index(indexes)]; } @@ -438,8 +434,8 @@ class Array { bool operator!=(const Array& other) const { return !(*this == other); } // Performs the equivalent of a slice operation on this array. 
- Array Slice(tensorflow::gtl::ArraySlice starts, - tensorflow::gtl::ArraySlice limits) const { + Array Slice(absl::Span starts, + absl::Span limits) const { CHECK_EQ(starts.size(), num_dimensions()); CHECK_EQ(limits.size(), num_dimensions()); @@ -464,7 +460,7 @@ class Array { // Performs the equivalent of a DynamicUpdateSlice in-place on this array. void UpdateSlice(const Array& from, - tensorflow::gtl::ArraySlice start_indices) { + absl::Span start_indices) { CHECK_EQ(from.num_dimensions(), num_dimensions()); std::vector limit_indices; std::transform(start_indices.begin(), start_indices.end(), @@ -484,7 +480,7 @@ class Array { // Performs an in-place reshape, modifying the dimensions but not the // underlying data. - void Reshape(tensorflow::gtl::ArraySlice new_dimensions) { + void Reshape(absl::Span new_dimensions) { int64 old_num_elements = num_elements(); sizes_ = std::vector(new_dimensions.begin(), new_dimensions.end()); CHECK_EQ(num_elements(), old_num_elements); diff --git a/tensorflow/compiler/xla/array4d_test.cc b/tensorflow/compiler/xla/array4d_test.cc index 927733ea1ea..3ab67b128f8 100644 --- a/tensorflow/compiler/xla/array4d_test.cc +++ b/tensorflow/compiler/xla/array4d_test.cc @@ -27,8 +27,7 @@ namespace { // Given an Array4D and a 4-tuple index, computes the linear index into the // array idx represents. template -int64 Array4DLinearIndex(const Array4D& arr, - tensorflow::gtl::ArraySlice idx) { +int64 Array4DLinearIndex(const Array4D& arr, absl::Span idx) { EXPECT_EQ(4, idx.size()); return (idx[3] + idx[2] * arr.n4() + idx[1] * arr.n3() * arr.n4() + idx[0] * arr.n2() * arr.n3() * arr.n4()); @@ -51,9 +50,8 @@ TEST(Array4dTest, FillCtor) { EXPECT_EQ(fullof7.n3(), 4); EXPECT_EQ(fullof7.n4(), 5); - fullof7.Each([](tensorflow::gtl::ArraySlice idx, int* cell) { - EXPECT_EQ(*cell, 7); - }); + fullof7.Each( + [](absl::Span idx, int* cell) { EXPECT_EQ(*cell, 7); }); } TEST(Array4dTest, ContainerCtor) { @@ -69,7 +67,7 @@ TEST(Array4dTest, ContainerCtor) { EXPECT_EQ(arr.n3(), 4); EXPECT_EQ(arr.n4(), 5); - arr.Each([&arr](tensorflow::gtl::ArraySlice idx, int* cell) { + arr.Each([&arr](absl::Span idx, int* cell) { EXPECT_EQ(*cell, Array4DLinearIndex(arr, idx)); }); } @@ -129,21 +127,19 @@ TEST(Array3dTest, InitializerListCtorHalf) { TEST(Array4dTest, Fill) { Array4D fullof7(2, 3, 4, 5, 7); - fullof7.Each([](tensorflow::gtl::ArraySlice idx, int* cell) { - EXPECT_EQ(*cell, 7); - }); + fullof7.Each( + [](absl::Span idx, int* cell) { EXPECT_EQ(*cell, 7); }); fullof7.Fill(11); - fullof7.Each([](tensorflow::gtl::ArraySlice idx, int* cell) { - EXPECT_EQ(*cell, 11); - }); + fullof7.Each( + [](absl::Span idx, int* cell) { EXPECT_EQ(*cell, 11); }); } TEST(Array4dTest, FillWithMultiples) { Array4D arr(2, 3, 4, 5); arr.FillWithMultiples(2.0f); - arr.Each([&arr](tensorflow::gtl::ArraySlice idx, float* cell) { + arr.Each([&arr](absl::Span idx, float* cell) { EXPECT_EQ(*cell, 2.0f * Array4DLinearIndex(arr, idx)); }); } diff --git a/tensorflow/compiler/xla/array_test.cc b/tensorflow/compiler/xla/array_test.cc index e8356c9832d..2d0ac98bd4e 100644 --- a/tensorflow/compiler/xla/array_test.cc +++ b/tensorflow/compiler/xla/array_test.cc @@ -163,7 +163,7 @@ TEST(ArrayTest, Each) { arr.FillWithMultiples(1); int64 each_count = 0, each_sum = 0; - arr.Each([&](tensorflow::gtl::ArraySlice idx, int cell) { + arr.Each([&](absl::Span idx, int cell) { int64 lin_idx = idx[0] * 12 + idx[1] * 4 + idx[2]; EXPECT_EQ(lin_idx, cell); each_count++; diff --git a/tensorflow/compiler/xla/client/client.cc 
b/tensorflow/compiler/xla/client/client.cc index 1fdf8f6260d..8818f813127 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -163,8 +163,7 @@ Status Client::ResetDevice() { } StatusOr> Client::ExecuteAndTransfer( - const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + const XlaComputation& computation, absl::Span arguments, const ExecutionOptions* execution_options, ExecutionProfile* execution_profile) { TF_ASSIGN_OR_RETURN( @@ -212,8 +211,7 @@ StatusOr Client::LoadSnapshot(const HloSnapshot& module) { } StatusOr> Client::Execute( - const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + const XlaComputation& computation, absl::Span arguments, const ExecutionOptions* execution_options, ExecutionProfile* execution_profile) { ExecuteGraphRequest request; @@ -252,7 +250,7 @@ StatusOr> Client::Execute( } StatusOr>> Client::ExecuteParallel( - tensorflow::gtl::ArraySlice computations) { + absl::Span computations) { ExecuteGraphParallelRequest request; for (const XlaComputationInstance& computation : computations) { diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index be50cebfcc0..940ddbedd1e 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -53,7 +53,7 @@ class Client { // will be filled with profile data from the execution. StatusOr> Execute( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutionOptions* execution_options = nullptr, ExecutionProfile* execution_profile = nullptr); @@ -82,7 +82,7 @@ class Client { // from each computation. // StatusOr>> ExecuteParallel( - tensorflow::gtl::ArraySlice computations); + absl::Span computations); // Requests device_count device handles available on the target. The returned // device handles are used to specify the devices to execute the computations @@ -134,7 +134,7 @@ class Client { // Execute() and Transfer(). StatusOr> ExecuteAndTransfer( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutionOptions* execution_options = nullptr, ExecutionProfile* execution_profile = nullptr); diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc index 040344c9a65..a6c58cb1757 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.cc +++ b/tensorflow/compiler/xla/client/compile_only_client.cc @@ -23,7 +23,7 @@ namespace xla { StatusOr>> CompileOnlyClient::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, + const absl::Span computations, const AotCompilationOptions& options, std::unique_ptr* metadata) { std::vector service_instances; diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h index d0c83cbfccb..9e3ed237349 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.h +++ b/tensorflow/compiler/xla/client/compile_only_client.h @@ -52,7 +52,7 @@ class CompileOnlyClient : public Client { // code. |metadata|, if provided, is populated during compilation. 
StatusOr>> CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, + const absl::Span computations, const AotCompilationOptions& options, std::unique_ptr* metadata = nullptr); diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index 888d2f28ebb..93334db88bc 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -86,7 +86,7 @@ class ExecutableBuildOptions { void add_disabled_hlo_pass(absl::string_view pass_name) { disabled_hlo_passes_.push_back(std::string(pass_name)); } - const tensorflow::gtl::ArraySlice disabled_hlo_passes() const { + const absl::Span disabled_hlo_passes() const { return disabled_hlo_passes_; } diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index e569610b855..d3d7edb42a3 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -69,8 +69,7 @@ std::array kErfUCoefficient = { // Evaluate the polynomial given coefficients and `x`. // N.B. Coefficients should be supplied in decreasing order. -XlaOp EvaluatePolynomial(XlaOp x, - tensorflow::gtl::ArraySlice coefficients) { +XlaOp EvaluatePolynomial(XlaOp x, absl::Span coefficients) { XlaOp poly = ScalarLike(x, 0.0); for (float c : coefficients) { poly = poly * x + ScalarLike(x, c); diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h index 13db2325569..a6cafd42077 100644 --- a/tensorflow/compiler/xla/client/lib/math.h +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -34,8 +34,7 @@ XlaOp Reciprocal(XlaOp operand); // Evaluates a polynomial given coefficients and `x`. // N.B. Coefficients should be supplied in decreasing order. -XlaOp EvaluatePolynomial(XlaOp x, - tensorflow::gtl::ArraySlice coefficients); +XlaOp EvaluatePolynomial(XlaOp x, absl::Span coefficients); // Computes an approximation of the error function complement (1 - erf(x)). XlaOp Erfc(XlaOp x); diff --git a/tensorflow/compiler/xla/client/lib/numeric.cc b/tensorflow/compiler/xla/client/lib/numeric.cc index 7f90d6c1972..a7a89481c92 100644 --- a/tensorflow/compiler/xla/client/lib/numeric.cc +++ b/tensorflow/compiler/xla/client/lib/numeric.cc @@ -39,7 +39,7 @@ XlaOp GetMatrixDiagonal(XlaOp x) { TF_RET_CHECK(n_dims >= 2); const int64 m = shape.dimensions(n_dims - 2); const int64 n = shape.dimensions(n_dims - 1); - tensorflow::gtl::ArraySlice major_dims = + absl::Span major_dims = AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); auto a = Iota(builder, U32, n); auto b = Iota(builder, U32, m); @@ -66,7 +66,7 @@ XlaOp Triangle(XlaOp x, bool lower) { TF_RET_CHECK(n_dims >= 2); const int64 m = shape.dimensions(n_dims - 2); const int64 n = shape.dimensions(n_dims - 1); - tensorflow::gtl::ArraySlice major_dims = + absl::Span major_dims = AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); auto a = Iota(builder, U32, n); auto b = Iota(builder, U32, m); diff --git a/tensorflow/compiler/xla/client/lib/pooling.cc b/tensorflow/compiler/xla/client/lib/pooling.cc index 3ae9ae36f65..1979c867a4c 100644 --- a/tensorflow/compiler/xla/client/lib/pooling.cc +++ b/tensorflow/compiler/xla/client/lib/pooling.cc @@ -26,11 +26,9 @@ namespace { // element of an image by the count of elements that contributed to that // element during pooling. 
XlaOp AvgPoolDivideByCountWithGeneralPadding( - XlaOp sums, PrimitiveType dtype, - tensorflow::gtl::ArraySlice input_shape, - tensorflow::gtl::ArraySlice> spatial_padding, - tensorflow::gtl::ArraySlice ksize, - tensorflow::gtl::ArraySlice stride, + XlaOp sums, PrimitiveType dtype, absl::Span input_shape, + absl::Span> spatial_padding, + absl::Span ksize, absl::Span stride, const TensorFormat& data_format) { // The padding shouldn't be included in the counts. We use another // ReduceWindow to find the right counts. @@ -73,8 +71,8 @@ XlaOp AvgPoolDivideByCountWithGeneralPadding( // Sums all elements in the window specified by 'kernel_size' and 'stride'. XlaOp ComputeSums(XlaOp operand, XlaOp init_value, - tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, + absl::Span kernel_size, + absl::Span stride, const TensorFormat& data_format) { XlaBuilder* b = operand.builder(); return b->ReportErrorOrReturn([&]() -> StatusOr { @@ -89,8 +87,8 @@ XlaOp ComputeSums(XlaOp operand, XlaOp init_value, // Creates a padding configuration out of spatial padding values. PaddingConfig MakeSpatialPaddingConfig( - tensorflow::gtl::ArraySlice> spatial_padding, - int num_spatial_dims, tensorflow::gtl::ArraySlice stride, + absl::Span> spatial_padding, + int num_spatial_dims, absl::Span stride, const TensorFormat& data_format) { PaddingConfig padding_config; for (int i = 0; i < 2 + num_spatial_dims; ++i) { @@ -107,13 +105,12 @@ PaddingConfig MakeSpatialPaddingConfig( return padding_config; } -XlaOp AvgPoolDivideByCount( - XlaOp pooled, tensorflow::gtl::ArraySlice input_size, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - PrimitiveType dtype, const TensorFormat& data_format, - bool counts_include_padding) { +XlaOp AvgPoolDivideByCount(XlaOp pooled, absl::Span input_size, + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, + PrimitiveType dtype, const TensorFormat& data_format, + bool counts_include_padding) { if (counts_include_padding) { // If counts include padding, all windows have the same number of elements // contributing to each average. 
Divide by the window size everywhere to get @@ -133,8 +130,8 @@ XlaOp AvgPoolDivideByCount( } // namespace -XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, Padding padding, +XlaOp MaxPool(XlaOp operand, absl::Span kernel_size, + absl::Span stride, Padding padding, const TensorFormat& data_format) { XlaBuilder* b = operand.builder(); return b->ReportErrorOrReturn([&]() -> StatusOr { @@ -147,9 +144,9 @@ XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, }); } -XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, - tensorflow::gtl::ArraySlice> padding, +XlaOp AvgPool(XlaOp operand, absl::Span kernel_size, + absl::Span stride, + absl::Span> padding, const TensorFormat& data_format, const bool counts_include_padding) { XlaBuilder* b = operand.builder(); @@ -173,9 +170,8 @@ XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, } std::vector> MakeSpatialPadding( - tensorflow::gtl::ArraySlice input_size, - tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, Padding padding, + absl::Span input_size, absl::Span kernel_size, + absl::Span stride, Padding padding, const TensorFormat& data_format) { const int num_spatial_dims = kernel_size.size() - 2; std::vector input_spatial_dimensions; @@ -193,12 +189,12 @@ std::vector> MakeSpatialPadding( stride_spatial_dimensions, padding); } -XlaOp AvgPoolGrad( - XlaOp out_backprop, tensorflow::gtl::ArraySlice gradients_size, - tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, - tensorflow::gtl::ArraySlice> spatial_padding, - const TensorFormat& data_format, const bool counts_include_padding) { +XlaOp AvgPoolGrad(XlaOp out_backprop, absl::Span gradients_size, + absl::Span kernel_size, + absl::Span stride, + absl::Span> spatial_padding, + const TensorFormat& data_format, + const bool counts_include_padding) { XlaBuilder* b = out_backprop.builder(); return b->ReportErrorOrReturn([&]() -> StatusOr { const int num_dims = kernel_size.size(); diff --git a/tensorflow/compiler/xla/client/lib/pooling.h b/tensorflow/compiler/xla/client/lib/pooling.h index 291c711a005..5c0054857d0 100644 --- a/tensorflow/compiler/xla/client/lib/pooling.h +++ b/tensorflow/compiler/xla/client/lib/pooling.h @@ -25,7 +25,7 @@ namespace xla { class TensorFormat { public: TensorFormat(int batch_dimension, int feature_dimension, - tensorflow::gtl::ArraySlice spatial_dimensions) + absl::Span spatial_dimensions) : batch_dimension_(batch_dimension), feature_dimension_(feature_dimension), spatial_dimensions_(spatial_dimensions.begin(), @@ -49,32 +49,31 @@ class TensorFormat { }; // Computes the max pool of 'operand'. -XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, Padding padding, +XlaOp MaxPool(XlaOp operand, absl::Span kernel_size, + absl::Span stride, Padding padding, const TensorFormat& data_format); // Computes the average pool of 'operand'. -XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, - tensorflow::gtl::ArraySlice> padding, +XlaOp AvgPool(XlaOp operand, absl::Span kernel_size, + absl::Span stride, + absl::Span> padding, const TensorFormat& data_format, const bool counts_include_padding); // Returns the list of low and high padding elements in each spatial dimension // for the given 'padding' specification. 
std::vector> MakeSpatialPadding( - tensorflow::gtl::ArraySlice input_size, - tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, Padding padding, + absl::Span input_size, absl::Span kernel_size, + absl::Span stride, Padding padding, const TensorFormat& data_format); // Computes the average pool gradient. -XlaOp AvgPoolGrad( - XlaOp out_backprop, tensorflow::gtl::ArraySlice gradients_size, - tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, - tensorflow::gtl::ArraySlice> spatial_padding, - const TensorFormat& data_format, const bool counts_include_padding); +XlaOp AvgPoolGrad(XlaOp out_backprop, absl::Span gradients_size, + absl::Span kernel_size, + absl::Span stride, + absl::Span> spatial_padding, + const TensorFormat& data_format, + const bool counts_include_padding); } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/pooling_test.cc b/tensorflow/compiler/xla/client/lib/pooling_test.cc index 18900479189..30adb9b1ad7 100644 --- a/tensorflow/compiler/xla/client/lib/pooling_test.cc +++ b/tensorflow/compiler/xla/client/lib/pooling_test.cc @@ -32,8 +32,8 @@ TensorFormat MakeNCHWFormat(int num_spatial_dims) { } std::vector> MakeGeneralPadding( - XlaOp input, tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, Padding padding, + XlaOp input, absl::Span kernel_size, + absl::Span stride, Padding padding, const xla::TensorFormat& data_format) { XlaBuilder* b = input.builder(); Shape operand_shape = b->GetShape(input).ValueOrDie(); @@ -46,7 +46,7 @@ std::vector> MakeGeneralPadding( // Add singleton batch and feature dimensions to spatial dimensions, according // to 'data_format' specification. std::vector ExpandWithBatchAndFeatureDimensions( - tensorflow::gtl::ArraySlice spatial_dim_sizes, + absl::Span spatial_dim_sizes, const xla::TensorFormat& data_format) { const int num_spatial_dims = spatial_dim_sizes.size(); std::vector tensor_sizes(num_spatial_dims + 2, 1); diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index db7a8fc0475..4402ba8762c 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -51,7 +51,7 @@ LocalExecutable::LocalExecutable(std::unique_ptr executable, } Status LocalExecutable::ValidateExecutionOptions( - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, const ExecutableRunOptions& run_options, const Backend& backend) { const ComputationLayout& computation_layout = executable_->module_config().entry_computation_layout(); @@ -140,7 +140,7 @@ Status LocalExecutable::ValidateExecutionOptions( } StatusOr LocalExecutable::Run( - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, ExecutableRunOptions run_options) { TF_RETURN_IF_ERROR( ValidateExecutionOptions(arguments, run_options, *backend_)); @@ -177,7 +177,7 @@ StatusOr LocalExecutable::Run( StatusOr LocalExecutable::ExecuteAndDump( const ServiceExecutableRunOptions* run_options, - const tensorflow::gtl::ArraySlice arguments) { + const absl::Span arguments) { executable_->hlo_snapshot()->set_execution_platform( backend_->platform()->Name()); TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->hlo_snapshot())); @@ -191,7 +191,7 @@ StatusOr LocalExecutable::ExecuteAndDump( } Status LocalExecutable::RecordArguments( - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, HloSnapshot* hlo_snapshot) { hlo_snapshot->clear_arguments(); for (const 
ShapedBuffer* argument : arguments) { @@ -245,7 +245,7 @@ Backend* LocalClient::mutable_backend() { StatusOr> LocalClient::Compile( const XlaComputation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, + const absl::Span argument_layouts, const ExecutableBuildOptions& options) { ExecutableBuildOptions updated_options = options; if (options.device_ordinal() == -1) { diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index ae238092617..78a1310c551 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -40,7 +40,7 @@ class LocalExecutable { // Run the compiled computation with the given arguments and options and // return the result. StatusOr Run( - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, ExecutableRunOptions run_options); // Return the options used to build the executable. @@ -63,7 +63,7 @@ class LocalExecutable { // The given ExecutableRunOptions override any values from legacy_flags // (TF_XLA_FLAGS environment variable). Status ValidateExecutionOptions( - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, const ExecutableRunOptions& run_options, const Backend& backend); // Records the computation in a SessionModule proto with the arguments used to @@ -73,13 +73,12 @@ class LocalExecutable { // (TF_XLA_FLAGS environment variable). StatusOr ExecuteAndDump( const ServiceExecutableRunOptions* run_options, - const tensorflow::gtl::ArraySlice arguments); + const absl::Span arguments); // Records the arguments used to invoke the computation in a SessionModule // proto. - Status RecordArguments( - const tensorflow::gtl::ArraySlice arguments, - HloSnapshot* hlo_snapshot); + Status RecordArguments(const absl::Span arguments, + HloSnapshot* hlo_snapshot); // Records the result of the computation in a SessionModule proto. Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot); @@ -120,7 +119,7 @@ class LocalClient : public Client { // (TF_XLA_FLAGS environment variable). StatusOr> Compile( const XlaComputation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, + const absl::Span argument_layouts, const ExecutableBuildOptions& options); // Copy the literal data to the device with the given ordinal and return as a diff --git a/tensorflow/compiler/xla/client/padding.cc b/tensorflow/compiler/xla/client/padding.cc index ed4dc8e9f6d..992b13139c4 100644 --- a/tensorflow/compiler/xla/client/padding.cc +++ b/tensorflow/compiler/xla/client/padding.cc @@ -23,10 +23,9 @@ limitations under the License. 
namespace xla { -Status ValidatePaddingValues( - tensorflow::gtl::ArraySlice input_dimensions, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides) { +Status ValidatePaddingValues(absl::Span input_dimensions, + absl::Span window_dimensions, + absl::Span window_strides) { bool ok = input_dimensions.size() == window_dimensions.size() && input_dimensions.size() == window_strides.size(); if (!ok) { @@ -40,9 +39,9 @@ Status ValidatePaddingValues( } std::vector> MakePadding( - tensorflow::gtl::ArraySlice input_dimensions, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding) { + absl::Span input_dimensions, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding) { TF_CHECK_OK(ValidatePaddingValues(input_dimensions, window_dimensions, window_strides)); std::vector> low_high_padding; diff --git a/tensorflow/compiler/xla/client/padding.h b/tensorflow/compiler/xla/client/padding.h index e23b0b3a90a..d5bd26cafbf 100644 --- a/tensorflow/compiler/xla/client/padding.h +++ b/tensorflow/compiler/xla/client/padding.h @@ -41,10 +41,9 @@ enum class Padding { // Validates that the slices are acceptable for determining padding -- this can // be used to check the preconditions of MakePadding below to produce an error // message that can be returned to the user. -Status ValidatePaddingValues( - tensorflow::gtl::ArraySlice input_dimensions, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides); +Status ValidatePaddingValues(absl::Span input_dimensions, + absl::Span window_dimensions, + absl::Span window_strides); // Returns the padding needed for the base area, given the base area dimensions, // window dimensions, strides, and the type of padding. @@ -58,9 +57,9 @@ Status ValidatePaddingValues( // window_dimensions, and strides must match, which is equal to the number // of elements in the result vector. 
std::vector> MakePadding( - tensorflow::gtl::ArraySlice input_dimensions, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding); + absl::Span input_dimensions, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding); } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 531b8dd66b5..e639028ccda 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -90,7 +90,7 @@ StatusOr XlaBuilder::GetShape(const XlaOp& op) const { } StatusOr> XlaBuilder::GetOperandShapes( - tensorflow::gtl::ArraySlice operands) const { + absl::Span operands) const { std::vector operand_shapes; for (const XlaOp& operand : operands) { TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); @@ -291,7 +291,7 @@ StatusOr XlaBuilder::Build(int64 root_id) { StatusOr XlaBuilder::InDimBroadcast( const Shape& shape, const XlaOp& operand, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { TF_RETURN_IF_ERROR(first_error_); HloInstructionProto instr; @@ -352,9 +352,8 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) { }); } -XlaOp XlaBuilder::BinaryOp( - HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -448,12 +447,12 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, } XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kAdd, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kMultiply, lhs, rhs, broadcast_dimensions); } @@ -480,7 +479,7 @@ XlaOp XlaBuilder::Iota(PrimitiveType type, int64 size) { } XlaOp XlaBuilder::Call(const XlaComputation& computation, - tensorflow::gtl::ArraySlice operands) { + absl::Span operands) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; std::vector operand_shape_ptrs; @@ -515,8 +514,8 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, }); } -XlaOp XlaBuilder::Broadcast( - const XlaOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { +XlaOp XlaBuilder::Broadcast(const XlaOp& operand, + absl::Span broadcast_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -541,7 +540,7 @@ XlaOp XlaBuilder::Broadcast( XlaOp XlaBuilder::BroadcastInDim( const XlaOp& operand, const Shape& shape, - const tensorflow::gtl::ArraySlice broadcast_dimensions) { + const absl::Span broadcast_dimensions) { return ReportErrorOrReturn([&]() -> StatusOr { return InDimBroadcast(shape, operand, broadcast_dimensions); }); @@ -556,9 +555,9 @@ StatusOr XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) { } XlaOp XlaBuilder::Slice(const XlaOp& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides) { + 
absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -593,7 +592,7 @@ XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index, } XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes) { + absl::Span slice_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -631,7 +630,7 @@ XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, }); } -XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice operands, +XlaOp XlaBuilder::ConcatInDim(absl::Span operands, int64 dimension) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -671,8 +670,8 @@ XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value, } XlaOp XlaBuilder::Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes) { + absl::Span dimensions, + absl::Span new_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(const Shape& shape, @@ -686,7 +685,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, } XlaOp XlaBuilder::Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice new_sizes) { + absl::Span new_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(auto shape, GetShape(operand)); std::vector dimensions(shape.dimensions_size()); @@ -696,7 +695,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, } XlaOp XlaBuilder::Collapse(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions) { + absl::Span dimensions) { return ReportErrorOrReturn([&]() -> StatusOr { if (dimensions.size() <= 1) { // Not collapsing anything, trivially we can return the operand versus @@ -706,8 +705,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand, // Out-of-order collapse is not supported. // Checks that the collapsed dimensions are in order and consecutive. 
- for (tensorflow::gtl::ArraySlice::size_type i = 1; - i < dimensions.size(); ++i) { + for (absl::Span::size_type i = 1; i < dimensions.size(); ++i) { if (dimensions[i] - 1 != dimensions[i - 1]) { return InvalidArgument( "Collapsed dimensions are not in consecutive order."); @@ -758,7 +756,7 @@ XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true, }); } -XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice elements) { +XlaOp XlaBuilder::Tuple(absl::Span elements) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; std::vector operand_shape_ptrs; @@ -792,32 +790,32 @@ XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { } XlaOp XlaBuilder::Eq(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kEq, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Ne(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kNe, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Ge(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kGe, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Gt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kGt, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Le(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kLe, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kLt, lhs, rhs, broadcast_dimensions); } @@ -899,8 +897,8 @@ Status XlaBuilder::VerifyConvolution( } XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - Padding padding, int64 feature_group_count, + absl::Span window_strides, Padding padding, + int64 feature_group_count, const PrecisionConfigProto* precision_config_proto) { return ConvWithGeneralDimensions( lhs, rhs, window_strides, padding, @@ -909,9 +907,8 @@ XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs, } XlaOp XlaBuilder::ConvWithGeneralPadding( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + absl::Span> padding, int64 feature_group_count, const PrecisionConfigProto* precision_config_proto) { return ConvGeneral(lhs, rhs, window_strides, padding, @@ -920,9 +917,8 @@ XlaOp XlaBuilder::ConvWithGeneralPadding( } XlaOp XlaBuilder::ConvWithGeneralDimensions( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers, + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, const PrecisionConfigProto* precision_config_proto) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -957,9 +953,8 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions( } XlaOp XlaBuilder::ConvGeneral( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice 
window_strides, - tensorflow::gtl::ArraySlice> padding, + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, const PrecisionConfigProto* precision_config_proto) { @@ -969,11 +964,9 @@ XlaOp XlaBuilder::ConvGeneral( } XlaOp XlaBuilder::ConvGeneralDilated( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, const PrecisionConfigProto* precision_config_proto) { @@ -1013,11 +1006,11 @@ XlaOp XlaBuilder::ConvGeneralDilated( } StatusOr XlaBuilder::MakeWindow( - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation) const { + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation) const { const auto verify_size = [&](const size_t x, const char* x_name) { if (x == 0 || x == window_dimensions.size()) { return Status::OK(); @@ -1067,7 +1060,7 @@ StatusOr XlaBuilder::MakeWindow( } XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type, - const tensorflow::gtl::ArraySlice fft_length) { + const absl::Span fft_length) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1276,7 +1269,7 @@ XlaOp XlaBuilder::CreateToken() { }); } -XlaOp XlaBuilder::AfterAll(tensorflow::gtl::ArraySlice tokens) { +XlaOp XlaBuilder::AfterAll(absl::Span tokens) { return ReportErrorOrReturn([&]() -> StatusOr { if (tokens.empty()) { return InvalidArgument("AfterAll requires at least one operand"); @@ -1288,7 +1281,7 @@ XlaOp XlaBuilder::AfterAll(tensorflow::gtl::ArraySlice tokens) { } XlaOp XlaBuilder::CustomCall(const string& call_target_name, - tensorflow::gtl::ArraySlice operands, + absl::Span operands, const Shape& shape) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -1304,9 +1297,8 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name, }); } -XlaOp XlaBuilder::Complex( - const XlaOp& real, const XlaOp& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp XlaBuilder::Complex(const XlaOp& real, const XlaOp& imag, + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kComplex, real, imag, broadcast_dimensions); } @@ -1315,42 +1307,42 @@ XlaOp XlaBuilder::Conj(const XlaOp& operand) { } XlaOp XlaBuilder::Sub(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kSubtract, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Div(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kDivide, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Rem(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kRemainder, lhs, rhs, 
broadcast_dimensions); } XlaOp XlaBuilder::Max(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kMaximum, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Min(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kMinimum, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::And(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kAnd, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Or(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kOr, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Xor(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kXor, lhs, rhs, broadcast_dimensions); } @@ -1358,22 +1350,21 @@ XlaOp XlaBuilder::Not(const XlaOp& operand) { return UnaryOp(HloOpcode::kNot, operand); } -XlaOp XlaBuilder::ShiftLeft( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp XlaBuilder::ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kShiftLeft, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::ShiftRightArithmetic( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::ShiftRightLogical( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs, broadcast_dimensions); } @@ -1382,9 +1373,8 @@ XlaOp XlaBuilder::Abs(const XlaOp& operand) { return UnaryOp(HloOpcode::kAbs, operand); } -XlaOp XlaBuilder::Atan2( - const XlaOp& y, const XlaOp& x, - tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp XlaBuilder::Atan2(const XlaOp& y, const XlaOp& x, + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kAtan2, y, x, broadcast_dimensions); } @@ -1449,7 +1439,7 @@ XlaOp XlaBuilder::IsFinite(const XlaOp& operand) { } XlaOp XlaBuilder::Transpose(const XlaOp& operand, - tensorflow::gtl::ArraySlice permutation) { + absl::Span permutation) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1464,7 +1454,7 @@ XlaOp XlaBuilder::Transpose(const XlaOp& operand, } XlaOp XlaBuilder::Rev(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions) { + absl::Span dimensions) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1506,7 +1496,7 @@ XlaOp XlaBuilder::Sort(XlaOp keys, absl::optional values, } XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kPower, lhs, rhs, broadcast_dimensions); } @@ -1544,10 +1534,10 @@ XlaOp XlaBuilder::Clamp(const XlaOp& min, const XlaOp& operand, return TernaryOp(HloOpcode::kClamp, min, operand, max); } -XlaOp 
XlaBuilder::Map(tensorflow::gtl::ArraySlice operands, +XlaOp XlaBuilder::Map(absl::Span operands, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands) { + absl::Span dimensions, + absl::Span static_operands) { return ReportErrorOrReturn([&]() -> StatusOr { if (!static_operands.empty()) { return Unimplemented("static_operands is not supported in Map"); @@ -1588,7 +1578,7 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice operands, } XlaOp XlaBuilder::RngOp(RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters, + absl::Span parameters, const Shape& shape) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -1649,7 +1639,7 @@ XlaOp XlaBuilder::While(const XlaComputation& condition, XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice slice_sizes) { + absl::Span slice_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -1729,20 +1719,18 @@ XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand, }); } -XlaOp XlaBuilder::Reduce( - const XlaOp& operand, const XlaOp& init_value, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce) { - return Reduce(tensorflow::gtl::ArraySlice({operand}), - tensorflow::gtl::ArraySlice({init_value}), computation, +XlaOp XlaBuilder::Reduce(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + absl::Span dimensions_to_reduce) { + return Reduce(absl::Span({operand}), + absl::Span({init_value}), computation, dimensions_to_reduce); } -XlaOp XlaBuilder::Reduce( - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice init_values, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce) { +XlaOp XlaBuilder::Reduce(absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span dimensions_to_reduce) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -1785,11 +1773,11 @@ XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value, }); } -XlaOp XlaBuilder::ReduceWindow( - const XlaOp& operand, const XlaOp& init_value, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding) { +XlaOp XlaBuilder::ReduceWindow(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -1810,9 +1798,9 @@ XlaOp XlaBuilder::ReduceWindow( XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding) { + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -1907,8 +1895,7 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale, } XlaOp XlaBuilder::CrossReplicaSum( - const XlaOp& operand, - tensorflow::gtl::ArraySlice replica_groups) { + const XlaOp& operand, absl::Span replica_groups) { return ReportErrorOrReturn([&]() -> StatusOr { 
TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {}); @@ -1923,7 +1910,7 @@ XlaOp XlaBuilder::CrossReplicaSum( XlaOp XlaBuilder::CrossReplicaSum( const XlaOp& operand, const XlaComputation& computation, - tensorflow::gtl::ArraySlice replica_groups, + absl::Span replica_groups, const absl::optional& channel_id) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -2023,12 +2010,13 @@ XlaOp XlaBuilder::CollectivePermute( }); } -XlaOp XlaBuilder::SelectAndScatter( - const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter) { +XlaOp XlaBuilder::SelectAndScatter(const XlaOp& operand, + const XlaComputation& select, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding, const XlaOp& source, + const XlaOp& init_value, + const XlaComputation& scatter) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); return SelectAndScatterWithGeneralPadding( @@ -2041,11 +2029,10 @@ XlaOp XlaBuilder::SelectAndScatter( XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding( const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter) { + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, const XlaOp& source, + const XlaOp& init_value, const XlaComputation& scatter) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -2428,9 +2415,9 @@ XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { return Status::OK(); } -StatusOr XlaBuilder::AddInstruction( - HloInstructionProto&& instr, HloOpcode opcode, - tensorflow::gtl::ArraySlice operands) { +StatusOr XlaBuilder::AddInstruction(HloInstructionProto&& instr, + HloOpcode opcode, + absl::Span operands) { TF_RETURN_IF_ERROR(first_error_); const int64 handle = instructions_.size(); @@ -2504,14 +2491,12 @@ XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal) { return builder->ConstantLiteral(literal); } -XlaOp Broadcast(const XlaOp& operand, - tensorflow::gtl::ArraySlice broadcast_sizes) { +XlaOp Broadcast(const XlaOp& operand, absl::Span broadcast_sizes) { return operand.builder()->Broadcast(operand, broadcast_sizes); } -XlaOp BroadcastInDim( - const XlaOp& operand, const Shape& shape, - const tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape, + const absl::Span broadcast_dimensions) { return operand.builder()->BroadcastInDim(operand, shape, broadcast_dimensions); } @@ -2521,26 +2506,22 @@ XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, return operand.builder()->Pad(operand, padding_value, padding_config); } -XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes) { +XlaOp Reshape(const XlaOp& operand, absl::Span dimensions, + absl::Span new_sizes) { return operand.builder()->Reshape(operand, dimensions, new_sizes); } -XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice new_sizes) { +XlaOp Reshape(const XlaOp& operand, absl::Span new_sizes) { return 
operand.builder()->Reshape(operand, new_sizes); } -XlaOp Collapse(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions) { +XlaOp Collapse(const XlaOp& operand, absl::Span dimensions) { return operand.builder()->Collapse(operand, dimensions); } -XlaOp Slice(const XlaOp& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides) { +XlaOp Slice(const XlaOp& operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) { return operand.builder()->Slice(operand, start_indices, limit_indices, strides); } @@ -2552,7 +2533,7 @@ XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, } XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes) { + absl::Span slice_sizes) { return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes); } @@ -2561,8 +2542,7 @@ XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, return operand.builder()->DynamicUpdateSlice(operand, update, start_indices); } -XlaOp ConcatInDim(XlaBuilder* builder, - tensorflow::gtl::ArraySlice operands, +XlaOp ConcatInDim(XlaBuilder* builder, absl::Span operands, int64 dimension) { return builder->ConcatInDim(operands, dimension); } @@ -2575,7 +2555,7 @@ XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false) { return pred.builder()->Select(pred, on_true, on_false); } -XlaOp Tuple(XlaBuilder* builder, tensorflow::gtl::ArraySlice elements) { +XlaOp Tuple(XlaBuilder* builder, absl::Span elements) { return builder->Tuple(elements); } @@ -2584,32 +2564,32 @@ XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index) { } XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Eq(lhs, rhs, broadcast_dimensions); } XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Ne(lhs, rhs, broadcast_dimensions); } XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Ge(lhs, rhs, broadcast_dimensions); } XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Gt(lhs, rhs, broadcast_dimensions); } XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Lt(lhs, rhs, broadcast_dimensions); } XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Le(lhs, rhs, broadcast_dimensions); } @@ -2626,7 +2606,7 @@ XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, } XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, + absl::Span window_strides, Padding padding, int64 feature_group_count, const PrecisionConfigProto* precision_config_proto) { return lhs.builder()->Conv(lhs, rhs, window_strides, padding, @@ -2634,9 +2614,8 @@ XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, } XlaOp ConvWithGeneralPadding( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, + const XlaOp& lhs, const XlaOp& rhs, 
absl::Span window_strides, + absl::Span> padding, int64 feature_group_count, const PrecisionConfigProto* precision_config_proto) { return lhs.builder()->ConvWithGeneralPadding(lhs, rhs, window_strides, @@ -2645,9 +2624,8 @@ XlaOp ConvWithGeneralPadding( } XlaOp ConvWithGeneralDimensions( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers, + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, const PrecisionConfigProto* precision_config_proto) { return lhs.builder()->ConvWithGeneralDimensions( @@ -2656,8 +2634,8 @@ XlaOp ConvWithGeneralDimensions( } XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, + absl::Span window_strides, + absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, const PrecisionConfigProto* precision_config_proto) { @@ -2666,22 +2644,21 @@ XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, precision_config_proto); } -XlaOp ConvGeneralDilated( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { +XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto) { return lhs.builder()->ConvGeneralDilated( lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, precision_config_proto); } XlaOp Fft(const XlaOp& operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length) { + absl::Span fft_length) { return operand.builder()->Fft(operand, fft_type, fft_length); } @@ -2695,105 +2672,102 @@ void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, } XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, - tensorflow::gtl::ArraySlice operands) { + absl::Span operands) { return builder->Call(computation, operands); } XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape) { + absl::Span operands, const Shape& shape) { return builder->CustomCall(call_target_name, operands, shape); } XlaOp Complex(const XlaOp& real, const XlaOp& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return real.builder()->Complex(real, imag, broadcast_dimensions); } XlaOp Conj(const XlaOp& operand) { return operand.builder()->Conj(operand); } XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Add(lhs, rhs, broadcast_dimensions); } XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Sub(lhs, rhs, broadcast_dimensions); } XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { 
+ absl::Span broadcast_dimensions) { return lhs.builder()->Mul(lhs, rhs, broadcast_dimensions); } XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Div(lhs, rhs, broadcast_dimensions); } XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Rem(lhs, rhs, broadcast_dimensions); } XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Max(lhs, rhs, broadcast_dimensions); } XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Min(lhs, rhs, broadcast_dimensions); } XlaOp And(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->And(lhs, rhs, broadcast_dimensions); } XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Or(lhs, rhs, broadcast_dimensions); } XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Xor(lhs, rhs, broadcast_dimensions); } XlaOp Not(const XlaOp& operand) { return operand.builder()->Not(operand); } XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->ShiftLeft(lhs, rhs, broadcast_dimensions); } -XlaOp ShiftRightArithmetic( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions) { return lhs.builder()->ShiftRightArithmetic(lhs, rhs, broadcast_dimensions); } -XlaOp ShiftRightLogical( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions) { return lhs.builder()->ShiftRightLogical(lhs, rhs, broadcast_dimensions); } XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce) { + absl::Span dimensions_to_reduce) { return operand.builder()->Reduce(operand, init_value, computation, dimensions_to_reduce); } // Reduces several arrays simultaneously among the provided dimensions, given // "computation" as a reduction operator. 
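Not part of the patch: a minimal usage sketch of the Reduce entry point whose renamed free-function form appears around here. It assumes the elided span element types are const int64 (dimensions) and const XlaOp (operands/init values), and it uses Parameter and StatusOr::ConsumeValueOrDie, which do not appear in this hunk.

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

void ReduceUsageSketch() {
  // Scalar add computation used as the reduction operator.
  XlaBuilder add_builder("add");
  Add(Parameter(&add_builder, 0, ShapeUtil::MakeShape(F32, {}), "x"),
      Parameter(&add_builder, 1, ShapeUtil::MakeShape(F32, {}), "y"));
  XlaComputation add = add_builder.Build().ConsumeValueOrDie();

  XlaBuilder builder("reduce_usage");
  XlaOp data = ConstantR2<float>(&builder, {{1, 2, 3}, {4, 5, 6}});
  XlaOp zero = ConstantR0<float>(&builder, 0.0f);
  // Column sums: reduce over dimension 0, yielding f32[3] {5, 7, 9}.
  // The braced list converts implicitly to absl::Span<const int64>, exactly
  // as it did to tensorflow::gtl::ArraySlice<int64> before the rename.
  Reduce(data, zero, add, /*dimensions_to_reduce=*/{0});
}

}  // namespace xla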
-XlaOp Reduce(XlaBuilder* builder, tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice init_values, +XlaOp Reduce(XlaBuilder* builder, absl::Span operands, + absl::Span init_values, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce) { + absl::Span dimensions_to_reduce) { return builder->Reduce(operands, init_values, computation, dimensions_to_reduce); } @@ -2805,9 +2779,8 @@ XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - Padding padding) { + absl::Span window_dimensions, + absl::Span window_strides, Padding padding) { return operand.builder()->ReduceWindow(operand, init_value, computation, window_dimensions, window_strides, padding); @@ -2816,22 +2789,21 @@ XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, XlaOp ReduceWindowWithGeneralPadding( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding) { + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding) { return operand.builder()->ReduceWindowWithGeneralPadding( operand, init_value, computation, window_dimensions, window_strides, padding); } -XlaOp CrossReplicaSum( - const XlaOp& operand, - tensorflow::gtl::ArraySlice replica_groups) { +XlaOp CrossReplicaSum(const XlaOp& operand, + absl::Span replica_groups) { return operand.builder()->CrossReplicaSum(operand, replica_groups); } XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation, - tensorflow::gtl::ArraySlice replica_groups, + absl::Span replica_groups, const absl::optional& channel_id) { return operand.builder()->CrossReplicaSum(operand, computation, replica_groups, channel_id); @@ -2851,10 +2823,10 @@ XlaOp CollectivePermute( } XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - Padding padding, const XlaOp& source, - const XlaOp& init_value, const XlaComputation& scatter) { + absl::Span window_dimensions, + absl::Span window_strides, Padding padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter) { return operand.builder()->SelectAndScatter(operand, select, window_dimensions, window_strides, padding, source, init_value, scatter); @@ -2862,11 +2834,10 @@ XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, XlaOp SelectAndScatterWithGeneralPadding( const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter) { + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, const XlaOp& source, + const XlaOp& init_value, const XlaComputation& scatter) { return operand.builder()->SelectAndScatterWithGeneralPadding( operand, select, window_dimensions, window_strides, padding, source, init_value, scatter); @@ -2875,7 +2846,7 @@ XlaOp SelectAndScatterWithGeneralPadding( XlaOp Abs(const XlaOp& operand) { return operand.builder()->Abs(operand); } XlaOp Atan2(const XlaOp& y, const XlaOp& x, - 
tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return y.builder()->Atan2(y, x, broadcast_dimensions); } @@ -2908,7 +2879,7 @@ XlaOp Real(const XlaOp& operand) { return operand.builder()->Real(operand); } XlaOp Imag(const XlaOp& operand) { return operand.builder()->Imag(operand); } XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Pow(lhs, rhs, broadcast_dimensions); } @@ -2926,12 +2897,11 @@ XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) { XlaOp Neg(const XlaOp& operand) { return operand.builder()->Neg(operand); } -XlaOp Transpose(const XlaOp& operand, - tensorflow::gtl::ArraySlice permutation) { +XlaOp Transpose(const XlaOp& operand, absl::Span permutation) { return operand.builder()->Transpose(operand, permutation); } -XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions) { +XlaOp Rev(const XlaOp& operand, absl::Span dimensions) { return operand.builder()->Rev(operand, dimensions); } @@ -2943,10 +2913,9 @@ XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) { return min.builder()->Clamp(min, operand, max); } -XlaOp Map(XlaBuilder* builder, tensorflow::gtl::ArraySlice operands, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands) { +XlaOp Map(XlaBuilder* builder, absl::Span operands, + const XlaComputation& computation, absl::Span dimensions, + absl::Span static_operands) { return builder->Map(operands, computation, dimensions, static_operands); } @@ -2980,7 +2949,7 @@ XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice slice_sizes) { + absl::Span slice_sizes) { return input.builder()->Gather(input, start_indices, dimension_numbers, slice_sizes); } @@ -3036,7 +3005,7 @@ XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, XlaOp CreateToken(XlaBuilder* builder) { return builder->CreateToken(); } -XlaOp AfterAll(XlaBuilder* builder, tensorflow::gtl::ArraySlice tokens) { +XlaOp AfterAll(XlaBuilder* builder, absl::Span tokens) { return builder->AfterAll(tokens); } diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index b9e651f2ae3..f83fef0cf2e 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -294,7 +294,7 @@ class XlaBuilder { template XlaOp ConstantR0(NativeT value); template - XlaOp ConstantR1(tensorflow::gtl::ArraySlice values); + XlaOp ConstantR1(absl::Span values); XlaOp ConstantR1(const tensorflow::core::Bitmap& values); template XlaOp ConstantR2( @@ -336,7 +336,7 @@ class XlaBuilder { // // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] XlaOp Broadcast(const XlaOp& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); + absl::Span broadcast_sizes); // Performs in-dimension-style broadcast. 
// @@ -355,9 +355,8 @@ class XlaBuilder { // will generate output // [1 , 1] // [2 , 2] - XlaOp BroadcastInDim( - const XlaOp& operand, const Shape& shape, - const tensorflow::gtl::ArraySlice broadcast_dimensions); + XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape, + const absl::Span broadcast_dimensions); // Enqueues a pad operation onto the computation that pads the given value on // the edges as well as between the elements of the input. padding_config @@ -370,15 +369,13 @@ class XlaBuilder { // given, followed by reshaping it into the shape with the given dimension // sizes (also major to minor). Conceptually, this is a limited form of // "shape casting". - XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); + XlaOp Reshape(const XlaOp& operand, absl::Span dimensions, + absl::Span new_sizes); // Enqueues an operation onto the computation that collapses the operand, from // first to last dimension (C order), then reshapes it to the given dimension // sizes. Conceptually, this is a limited form of "shape casting". - XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice new_sizes); + XlaOp Reshape(const XlaOp& operand, absl::Span new_sizes); // Wrapper for Reshape. // Enqueues an operation to collapse the provided dimensions; e.g. an @@ -398,8 +395,7 @@ class XlaBuilder { // // This could potentially cause data to be moved -- it provides a more // structured form of reshaping than an arbitrary Reshape operation. - XlaOp Collapse(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions); + XlaOp Collapse(const XlaOp& operand, absl::Span dimensions); // Enqueues a slice operation onto the computation that slices the operand // from the start indices to the limit indices; e.g. @@ -412,10 +408,9 @@ class XlaBuilder { // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D // range notation. // The strides parameter determines the stride over the slice - XlaOp Slice(const XlaOp& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); + XlaOp Slice(const XlaOp& operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); // Enqueues a slice operation in a given dimension, taking all other // dimensions as they are; e.g. if dimno is 1 from start_index 2 to @@ -436,7 +431,7 @@ class XlaBuilder { // Slice index calculations are computed modulo input dimension sizes to // prevent dynamic start indices from generating out-of-bound array accesses. XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); // Enqueues a dynamic update slice operation onto the computation, which // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. @@ -459,8 +454,7 @@ class XlaBuilder { // Enqueues a concatenate instruction onto the computation. 'operands' must // have >= 1 entry. - XlaOp ConcatInDim(tensorflow::gtl::ArraySlice operands, - int64 dimension); + XlaOp ConcatInDim(absl::Span operands, int64 dimension); // Enqueue a tracing operation onto the computation; the computation will emit // a logging message with the operand. @@ -471,34 +465,34 @@ class XlaBuilder { XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false); // Enqueues a tuple-creation instruction onto the computation. 
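Not from the patch itself: a short sketch of the tuple entry points declared here, assuming the elided span element type is const XlaOp.

#include "tensorflow/compiler/xla/client/xla_builder.h"

namespace xla {

void TupleUsageSketch() {
  XlaBuilder builder("tuple_usage");
  XlaOp a = ConstantR0<float>(&builder, 1.0f);
  XlaOp b = ConstantR1<float>(&builder, {2.0f, 3.0f});
  // elements is an absl::Span<const XlaOp>; a braced list of ops works at the
  // call site, just as it did with gtl::ArraySlice<XlaOp>.
  XlaOp t = Tuple(&builder, {a, b});
  // Pull the second element (f32[2]) back out of the tuple.
  GetTupleElement(t, 1);
}

}  // namespace xla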
- XlaOp Tuple(tensorflow::gtl::ArraySlice elements); + XlaOp Tuple(absl::Span elements); // Enqueues a tuple-element-get instruction onto the computation. XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); // Enqueues an equal-to comparison instruction onto the computation. XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a not-equal comparison instruction onto the computation. XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a greater-or-equal comparison instruction onto the computation. XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a greater-than comparison instruction onto the computation. XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a less-than comparison instruction onto the computation. XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a less-or-equal comparison instruction onto the computation. XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a dot instruction onto the computation. XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, @@ -513,7 +507,7 @@ class XlaBuilder { // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, + absl::Span window_strides, Padding padding, int64 feature_group_count = 1, const PrecisionConfigProto* precision_config_proto = nullptr); @@ -521,8 +515,8 @@ class XlaBuilder { // provided padding configuration in the format returned by MakePadding(). XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, + absl::Span window_strides, + absl::Span> padding, int64 feature_group_count = 1, const PrecisionConfigProto* precision_config_proto = nullptr); @@ -530,7 +524,7 @@ class XlaBuilder { // provided dimension numbers configuration. XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, + absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, const PrecisionConfigProto* precision_config_proto = nullptr); @@ -539,8 +533,8 @@ class XlaBuilder { // provided padding configuration as well as the dimension numbers. XlaOp ConvGeneral( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, + absl::Span window_strides, + absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, const PrecisionConfigProto* precision_config_proto = nullptr); @@ -549,10 +543,10 @@ class XlaBuilder { // provided padding configuration, dilation factors and dimension numbers. 
XlaOp ConvGeneralDilated( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, const PrecisionConfigProto* precision_config_proto = nullptr); @@ -560,7 +554,7 @@ class XlaBuilder { // Enqueues an FFT instruction onto the computation, of the given type and // with the given FFT length. XlaOp Fft(const XlaOp& operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); + absl::Span fft_length); // Enqueues an infeed instruction onto the computation, which writes data of // the given shape to the infeed buffer of the device. @@ -582,15 +576,14 @@ class XlaBuilder { // Enqueues a call instruction onto the computation. XlaOp Call(const XlaComputation& computation, - tensorflow::gtl::ArraySlice operands); + absl::Span operands); // Enqueues a custom call instruction onto the computation. // During code generation, a call instruction is emitted which targets a // symbol with the name |call_target_name|. The |operands| are passed to the // call instruction. |shape| is the resultant shape. XlaOp CustomCall(const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape); + absl::Span operands, const Shape& shape); // The following methods enqueue element-wise binary arithmetic operations // onto the computation. The shapes of the operands have to match unless one @@ -599,72 +592,70 @@ class XlaBuilder { // Enqueues a complex compose instruction onto the computation. XlaOp Complex(const XlaOp& real, const XlaOp& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a complex conjugate instruction onto the computation. XlaOp Conj(const XlaOp& operand); // Enqueues an add instruction onto the computation. XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a subtract instruction onto the computation. XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a multiply instruction onto the computation. XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a divide instruction onto the computation. XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a remainder instruction onto the computation. XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a max instruction onto the computation. XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a min instruction onto the computation. 
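A usage sketch (not part of the patch) for the element-wise binary ops declared here. broadcast_dimensions is assumed to be absl::Span<const int64> with a default of {}; both initializer lists and std::vector<int64> convert to it implicitly, so existing call sites compile unchanged after the ArraySlice rename.

#include <vector>

#include "tensorflow/compiler/xla/client/xla_builder.h"

namespace xla {

void BinaryOpUsageSketch() {
  XlaBuilder builder("binary_op_usage");
  XlaOp matrix = ConstantR2<float>(&builder, {{1, 2, 3}, {4, 5, 6}});  // f32[2,3]
  XlaOp row = ConstantR1<float>(&builder, {10, 20, 30});               // f32[3]

  // Rank-broadcast the rank-1 operand along dimension 1 of the matrix.
  Add(matrix, row, /*broadcast_dimensions=*/{1});

  // The same call shape with a std::vector<int64> argument.
  std::vector<int64> dims = {1};
  Max(matrix, row, dims);

  // Same-shape operands need no broadcast_dimensions (default {}).
  Mul(matrix, matrix);
}

}  // namespace xla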
XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Element-wise logical operators XlaOp And(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); XlaOp Not(const XlaOp& operand); XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - XlaOp ShiftRightArithmetic( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - XlaOp ShiftRightLogical( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); + XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions = {}); + XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions = {}); // Reduces an array among the provided dimensions, given "computation" as a // reduction operator. XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); + absl::Span dimensions_to_reduce); // Reduces several arrays simultaneously among the provided dimensions, given // "computation" as a reduction operator. - XlaOp Reduce(tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice init_values, + XlaOp Reduce(absl::Span operands, + absl::Span init_values, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); + absl::Span dimensions_to_reduce); // Convenience wrapper around the above that reduces all the dimensions in the // operand shape. @@ -674,25 +665,23 @@ class XlaBuilder { // Enqueues a windowed reduce instruction onto the computation. XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - Padding padding); + absl::Span window_dimensions, + absl::Span window_strides, Padding padding); // As ReduceWindow(), but the padding is given in the format // returned by MakePadding(). XlaOp ReduceWindowWithGeneralPadding( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding); // Returns the sum of the operand value within each subgroup of replicas. All // replicas supply one input to the sum and all replicas receive the resulting // sum for each subgroup. - XlaOp CrossReplicaSum( - const XlaOp& operand, - tensorflow::gtl::ArraySlice replica_groups = {}); + XlaOp CrossReplicaSum(const XlaOp& operand, + absl::Span replica_groups = {}); // Enqueues an operation that do an AllReduce of the operand cross cores. Here // AllReduce means doing a reduction on the input operand cross cores and then @@ -714,7 +703,7 @@ class XlaBuilder { // TODO(b/79737069): Rename this to AllReduce when it's ready to use. 
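Another illustrative sketch, not from the patch: a 2x2 max pool built on the ReduceWindow declaration above. Window dimensions and strides are assumed to be absl::Span<const int64>; Padding comes from client/padding.h, and Parameter/ConsumeValueOrDie are used for brevity although they do not appear in this hunk.

#include <limits>

#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

void ReduceWindowUsageSketch() {
  // Scalar max computation used as the window reducer.
  XlaBuilder max_builder("max");
  Max(Parameter(&max_builder, 0, ShapeUtil::MakeShape(F32, {}), "a"),
      Parameter(&max_builder, 1, ShapeUtil::MakeShape(F32, {}), "b"));
  XlaComputation max_computation = max_builder.Build().ConsumeValueOrDie();

  XlaBuilder builder("reduce_window_usage");
  XlaOp input = Parameter(&builder, 0,
                          ShapeUtil::MakeShape(F32, {1, 1, 4, 4}), "input");
  XlaOp lowest =
      ConstantR0<float>(&builder, -std::numeric_limits<float>::infinity());
  // 2x2 windows with 2x2 strides over the trailing dimensions:
  // f32[1,1,4,4] -> f32[1,1,2,2].
  ReduceWindow(input, lowest, max_computation,
               /*window_dimensions=*/{1, 1, 2, 2},
               /*window_strides=*/{1, 1, 2, 2}, Padding::kValid);
}

}  // namespace xla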
XlaOp CrossReplicaSum( const XlaOp& operand, const XlaComputation& computation, - tensorflow::gtl::ArraySlice replica_groups = {}, + absl::Span replica_groups = {}, const absl::optional& channel_id = absl::nullopt); // Enqueues an operation that do an Alltoall of the operand cross cores. @@ -731,8 +720,8 @@ class XlaBuilder { // Enqueues an operation that scatters the `source` array to the selected // indices of each window. XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding, const XlaOp& source, const XlaOp& init_value, const XlaComputation& scatter); @@ -741,18 +730,17 @@ class XlaBuilder { // returned by MakePadding(). XlaOp SelectAndScatterWithGeneralPadding( const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, const XlaOp& source, + const XlaOp& init_value, const XlaComputation& scatter); // Enqueues an abs instruction onto the computation. XlaOp Abs(const XlaOp& operand); // Enqueues a atan2 instruction onto the computation. XlaOp Atan2(const XlaOp& y, const XlaOp& x, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues an exp instruction onto the computation. XlaOp Exp(const XlaOp& operand); @@ -799,7 +787,7 @@ class XlaBuilder { // Enqueues a lhs^rhs computation onto the computation. XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues an operator that tests if the operand's values are finite, i.e., // not Inf or NaN. Defined only for floating-point types. Returns an array of @@ -829,14 +817,12 @@ class XlaBuilder { XlaOp Neg(const XlaOp& operand); // Enqueues a transpose instruction onto the computation. - XlaOp Transpose(const XlaOp& operand, - tensorflow::gtl::ArraySlice permutation); + XlaOp Transpose(const XlaOp& operand, absl::Span permutation); // Enqueues a reverse instruction onto the computation. The order of the // elements in the given dimensions is reversed (i.e., the element at index i // is moved to index dimension_size - 1 - i). - XlaOp Rev(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions); + XlaOp Rev(const XlaOp& operand, absl::Span dimensions); // Enqueues a sort (as increasing order) instruction onto the computation. // If only keys are provided: @@ -861,10 +847,9 @@ class XlaBuilder { XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); // Enqueues a map instruction onto the computation. - XlaOp Map(tensorflow::gtl::ArraySlice operands, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands = {}); + XlaOp Map(absl::Span operands, const XlaComputation& computation, + absl::Span dimensions, + absl::Span static_operands = {}); // Enqueues a N(mu, sigma) random number generation instruction onto the // computation. @@ -891,7 +876,7 @@ class XlaBuilder { // Enqueues a Gather node onto the computation. 
XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); // Enqueues a Scatter node onto the computation. XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, @@ -919,7 +904,7 @@ class XlaBuilder { // Enqueues an AfterAll operation with no operands producing a token-shaped // value. - XlaOp AfterAll(tensorflow::gtl::ArraySlice tokens); + XlaOp AfterAll(absl::Span tokens); // Enqueues a Recv node onto the computation. The data comes from a Send // instruction that shares the same channel handle and its shape must @@ -966,9 +951,8 @@ class XlaBuilder { const XlaOp& grad_output, float epsilon, int64 feature_index); - StatusOr AddInstruction( - HloInstructionProto&& instr, HloOpcode opcode, - tensorflow::gtl::ArraySlice operands = {}); + StatusOr AddInstruction(HloInstructionProto&& instr, HloOpcode opcode, + absl::Span operands = {}); void AddCalledComputation(const XlaComputation& computation, HloInstructionProto* instr); @@ -982,19 +966,17 @@ class XlaBuilder { // broadcast_dimensions specifies which dimensions to use for broadcasting // when the operation is between tensors of different ranks. XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); // Internal helper method that does the building for an arbitrary ternary op. XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, const XlaOp& ehs); XlaOp RngOp(RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters, - const Shape& shape); + absl::Span parameters, const Shape& shape); - StatusOr InDimBroadcast( - const Shape& shape, const XlaOp& operand, - tensorflow::gtl::ArraySlice broadcast_dimensions); + StatusOr InDimBroadcast(const Shape& shape, const XlaOp& operand, + absl::Span broadcast_dimensions); // Internal helper method that creates a sequence of instructions that // performs an explicit broadcast of the operand to the target shape. @@ -1010,7 +992,7 @@ class XlaBuilder { // Returns shapes for the operands. StatusOr> GetOperandShapes( - tensorflow::gtl::ArraySlice operands) const; + absl::Span operands) const; // A visitor which checks whether an operation is a compile-time constant, // meaning that it doesn't depend on any parameters, or on any stateful @@ -1027,12 +1009,11 @@ class XlaBuilder { // Helper function for creating a Window proto from user-supplied data. // Returns error if the user-supplied data was invalid. - StatusOr MakeWindow( - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation) const; + StatusOr MakeWindow(absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation) const; string name_; // Name to use for the built computation. 
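A small sketch (not in the patch) of the token plumbing whose span-taking AfterAll overload is declared above; it assumes the elided element type is const XlaOp.

#include "tensorflow/compiler/xla/client/xla_builder.h"

namespace xla {

void TokenUsageSketch() {
  XlaBuilder builder("token_usage");
  XlaOp t0 = CreateToken(&builder);
  XlaOp t1 = CreateToken(&builder);
  // Joins several token-shaped values into a single token; the braced list
  // converts to absl::Span<const XlaOp> at the call site.
  AfterAll(&builder, {t0, t1});
}

}  // namespace xla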
@@ -1076,7 +1057,7 @@ class XlaBuilder { friend XlaOp ConstantR0(XlaBuilder* builder, NativeT value); template friend XlaOp ConstantR1(XlaBuilder* builder, - tensorflow::gtl::ArraySlice values); + absl::Span values); friend XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values); template @@ -1116,193 +1097,183 @@ class XlaBuilder { friend XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value); friend XlaOp Broadcast(const XlaOp& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); + absl::Span broadcast_sizes); friend XlaOp BroadcastInDim( const XlaOp& operand, const Shape& shape, - const tensorflow::gtl::ArraySlice broadcast_dimensions); + const absl::Span broadcast_dimensions); friend XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, const PaddingConfig& padding_config); - friend XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); + friend XlaOp Reshape(const XlaOp& operand, absl::Span dimensions, + absl::Span new_sizes); - friend XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice new_sizes); + friend XlaOp Reshape(const XlaOp& operand, absl::Span new_sizes); friend XlaOp Collapse(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); friend XlaOp Slice(const XlaOp& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); friend XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno); friend XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices); friend XlaOp ConcatInDim(XlaBuilder* builder, - tensorflow::gtl::ArraySlice operands, - int64 dimension); + absl::Span operands, int64 dimension); friend void Trace(const string& tag, const XlaOp& operand); friend XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false); - friend XlaOp Tuple(XlaBuilder* builder, - tensorflow::gtl::ArraySlice elements); + friend XlaOp Tuple(XlaBuilder* builder, absl::Span elements); friend XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); friend XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, const PrecisionConfigProto* precision_config_proto); friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_number, const PrecisionConfigProto* precision_config_proto); 
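To go with the slicing and concatenation friends listed above, a usage sketch (not from the patch). It assumes the elided span types are const int64 for indices and const XlaOp for operands, and that DynamicSlice takes its start indices as a rank-1 operand, as in this version of the API.

#include "tensorflow/compiler/xla/client/xla_builder.h"

namespace xla {

void SliceUsageSketch() {
  XlaBuilder builder("slice_usage");
  XlaOp x = ConstantR1<float>(&builder, {0.f, 1.f, 2.f, 3.f, 4.f, 5.f});

  // Static slice: elements [1, 5) with stride 2 -> f32[2] {1, 3}.
  XlaOp sliced = Slice(x, /*start_indices=*/{1}, /*limit_indices=*/{5},
                       /*strides=*/{2});

  // Dynamic slice: the start index is an operand, the sizes are a span.
  XlaOp start = ConstantR1<int32>(&builder, {2});
  XlaOp window = DynamicSlice(x, start, /*slice_sizes=*/{3});

  // Concatenate the two results along dimension 0 -> f32[5].
  ConcatInDim(&builder, {sliced, window}, /*dimension=*/0);
}

}  // namespace xla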
friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - Padding padding, int64 feature_group_count, + absl::Span window_strides, Padding padding, + int64 feature_group_count, const PrecisionConfigProto* precision_config_proto); friend XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, + absl::Span window_strides, + absl::Span> padding, int64 feature_group_count, const PrecisionConfigProto* precision_config_proto); friend XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto); - friend XlaOp ConvGeneral( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, + absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, const PrecisionConfigProto* precision_config_proto); + friend XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto); friend XlaOp ConvGeneralDilated( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, const PrecisionConfigProto* precision_config_proto); friend XlaOp Fft(const XlaOp& operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); + absl::Span fft_length); friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config); friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, const string& outfeed_config); friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, - tensorflow::gtl::ArraySlice operands); + absl::Span operands); friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape); + absl::Span operands, const Shape& shape); friend XlaOp Complex(const XlaOp& real, const XlaOp& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Conj(const XlaOp& operand); friend XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice 
broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp And(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Not(const XlaOp& operand); - friend XlaOp ShiftLeft( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions); friend XlaOp ShiftRightArithmetic( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); - friend XlaOp ShiftRightLogical( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); + friend XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions); friend XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); - friend XlaOp Reduce(XlaBuilder* builder, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice init_values, + absl::Span dimensions_to_reduce); + friend XlaOp Reduce(XlaBuilder* builder, absl::Span operands, + absl::Span init_values, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); + absl::Span dimensions_to_reduce); friend XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation); - friend XlaOp ReduceWindow( - const XlaOp& operand, const XlaOp& init_value, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding); + friend XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding); friend XlaOp ReduceWindowWithGeneralPadding( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); - friend XlaOp CrossReplicaSum( - const XlaOp& operand, - tensorflow::gtl::ArraySlice replica_groups); - friend XlaOp CrossReplicaSum( - const XlaOp& operand, const XlaComputation& computation, - tensorflow::gtl::ArraySlice replica_groups, - const absl::optional& channel_id); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding); + friend XlaOp CrossReplicaSum(const XlaOp& operand, + absl::Span replica_groups); + friend XlaOp CrossReplicaSum(const XlaOp& operand, + const XlaComputation& computation, + absl::Span replica_groups, + const absl::optional& channel_id); friend XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, int64 concat_dimension, int64 split_count, const std::vector& replica_groups); friend XlaOp CollectivePermute( const XlaOp& operand, const std::vector>& source_target_pairs); - friend XlaOp SelectAndScatter( - const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice 
window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter); + friend XlaOp SelectAndScatter(const XlaOp& operand, + const XlaComputation& select, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding, const XlaOp& source, + const XlaOp& init_value, + const XlaComputation& scatter); friend XlaOp SelectAndScatterWithGeneralPadding( const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, const XlaOp& source, + const XlaOp& init_value, const XlaComputation& scatter); friend XlaOp Abs(const XlaOp& operand); friend XlaOp Atan2(const XlaOp& y, const XlaOp& x, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Exp(const XlaOp& operand); friend XlaOp Expm1(const XlaOp& operand); friend XlaOp Floor(const XlaOp& operand); @@ -1318,7 +1289,7 @@ class XlaBuilder { friend XlaOp Real(const XlaOp& operand); friend XlaOp Imag(const XlaOp& operand); friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp IsFinite(const XlaOp& operand); friend XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension); @@ -1329,16 +1300,14 @@ class XlaBuilder { PrimitiveType new_element_type); friend XlaOp Neg(const XlaOp& operand); friend XlaOp Transpose(const XlaOp& operand, - tensorflow::gtl::ArraySlice permutation); - friend XlaOp Rev(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions); + absl::Span permutation); + friend XlaOp Rev(const XlaOp& operand, absl::Span dimensions); friend XlaOp Sort(XlaOp keys, absl::optional values, int64 dimension); friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); - friend XlaOp Map(XlaBuilder* builder, - tensorflow::gtl::ArraySlice operands, + friend XlaOp Map(XlaBuilder* builder, absl::Span operands, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands); + absl::Span dimensions, + absl::Span static_operands); friend XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape); friend XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape); @@ -1352,7 +1321,7 @@ class XlaBuilder { const int mantissa_bits); friend XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); friend XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, const XlaOp& updates, const XlaComputation& update_computation, @@ -1386,8 +1355,7 @@ class XlaBuilder { const Shape& shape_with_layout, const string& outfeed_config); friend XlaOp CreateToken(XlaBuilder* builder); - friend XlaOp AfterAll(XlaBuilder* builder, - tensorflow::gtl::ArraySlice tokens); + friend XlaOp AfterAll(XlaBuilder* builder, absl::Span tokens); }; // RAII-style object: sets the current sharding assignment in builder on @@ -1451,8 +1419,7 @@ XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal); template XlaOp ConstantR0(XlaBuilder* builder, NativeT value); template -XlaOp 
ConstantR1(XlaBuilder* builder, - tensorflow::gtl::ArraySlice values); +XlaOp ConstantR1(XlaBuilder* builder, absl::Span values); XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values); template XlaOp ConstantR2(XlaBuilder* builder, @@ -1501,8 +1468,7 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value); // The new dimensions index into copies of the operand, i.e. // // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] -XlaOp Broadcast(const XlaOp& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); +XlaOp Broadcast(const XlaOp& operand, absl::Span broadcast_sizes); // Performs in-dimension-style broadcast. // @@ -1521,9 +1487,8 @@ XlaOp Broadcast(const XlaOp& operand, // will generate output // [1 , 1] // [2 , 2] -XlaOp BroadcastInDim( - const XlaOp& operand, const Shape& shape, - const tensorflow::gtl::ArraySlice broadcast_dimensions); +XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape, + const absl::Span broadcast_dimensions); // Enqueues a pad operation onto the computation that pads the given value on // the edges as well as between the elements of the input. padding_config @@ -1536,15 +1501,13 @@ XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, // given, followed by reshaping it into the shape with the given dimension // sizes (also major to minor). Conceptually, this is a limited form of // "shape casting". -XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); +XlaOp Reshape(const XlaOp& operand, absl::Span dimensions, + absl::Span new_sizes); // Enqueues an operation onto the computation that collapses the operand, from // first to last dimension (C order), then reshapes it to the given dimension // sizes. Conceptually, this is a limited form of "shape casting". -XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice new_sizes); +XlaOp Reshape(const XlaOp& operand, absl::Span new_sizes); // Wrapper for Reshape. // Enqueues an operation to collapse the provided dimensions; e.g. an @@ -1564,8 +1527,7 @@ XlaOp Reshape(const XlaOp& operand, // // This could potentially cause data to be moved -- it provides a more // structured form of reshaping than an arbitrary Reshape operation. -XlaOp Collapse(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions); +XlaOp Collapse(const XlaOp& operand, absl::Span dimensions); // Enqueues a slice operation onto the computation that slices the operand // from the start indices to the limit indices; e.g. @@ -1578,10 +1540,9 @@ XlaOp Collapse(const XlaOp& operand, // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D // range notation. // The strides parameter determines the stride over the slice -XlaOp Slice(const XlaOp& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); +XlaOp Slice(const XlaOp& operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); // Enqueues a slice operation in a given dimension, taking all other // dimensions as they are; e.g. if dimno is 1 from start_index 2 to @@ -1602,7 +1563,7 @@ XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, // Slice index calculations are computed modulo input dimension sizes to // prevent dynamic start indices from generating out-of-bound array accesses. 
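A sketch of the shape-manipulation free functions declared above (illustrative only, not part of the patch); dimension lists are assumed to be absl::Span<const int64>.

#include "tensorflow/compiler/xla/client/xla_builder.h"

namespace xla {

void ShapeOpsUsageSketch() {
  XlaBuilder builder("shape_ops_usage");
  XlaOp v = ConstantR1<float>(&builder, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});

  // f32[6] -> f32[2,3].
  XlaOp m = Reshape(v, /*new_sizes=*/{2, 3});

  // Prepend two broadcast dimensions: f32[2,3] -> f32[4,5,2,3].
  XlaOp b = Broadcast(m, /*broadcast_sizes=*/{4, 5});

  // Collapse the two leading dimensions: f32[4,5,2,3] -> f32[20,2,3].
  Collapse(b, /*dimensions=*/{0, 1});
}

}  // namespace xla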
XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); // Enqueues a dynamic update slice operation onto the computation, which // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. @@ -1625,8 +1586,8 @@ XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, // Enqueues a concatenate instruction onto the computation. 'operands' must // have >= 1 entry. -XlaOp ConcatInDim(XlaBuilder* builder, - tensorflow::gtl::ArraySlice operands, int64 dimension); +XlaOp ConcatInDim(XlaBuilder* builder, absl::Span operands, + int64 dimension); // Enqueue a tracing operation onto the computation; the computation will emit // a logging message with the operand. @@ -1637,34 +1598,34 @@ void Trace(const string& tag, const XlaOp& operand); XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false); // Enqueues a tuple-creation instruction onto the computation. -XlaOp Tuple(XlaBuilder* builder, tensorflow::gtl::ArraySlice elements); +XlaOp Tuple(XlaBuilder* builder, absl::Span elements); // Enqueues a tuple-element-get instruction onto the computation. XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); // Enqueues an equal-to comparison instruction onto the computation. XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a not-equal comparison instruction onto the computation. XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a greater-or-equal comparison instruction onto the computation. XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a greater-than comparison instruction onto the computation. XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a less-than comparison instruction onto the computation. XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a less-or-equal comparison instruction onto the computation. XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a dot instruction onto the computation. XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, @@ -1678,33 +1639,31 @@ XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, + absl::Span window_strides, Padding padding, int64 feature_group_count = 1, const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration in the format returned by MakePadding(). 
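One caveat worth keeping in mind for builder entry points like DynamicSlice and ConcatInDim above: absl::Span does not own or extend the lifetime of the values it views. Passing a temporary container as an argument is fine because the temporary outlives the call, but keeping the span around afterwards is not. A small sketch of the distinction; First is a hypothetical helper, not an XLA API.

#include <cstdint>
#include <vector>
#include "absl/types/span.h"

using int64 = std::int64_t;  // stand-in for TensorFlow's int64 typedef

int64 First(absl::Span<const int64> xs) { return xs.empty() ? -1 : xs[0]; }

int main() {
  // Fine: the temporary vector outlives the full call expression.
  int64 x = First(std::vector<int64>{7, 8, 9});

  // Dangling if kept: the temporary dies at the end of the statement.
  // absl::Span<const int64> kept = std::vector<int64>{1, 2};

  return x == 7 ? 0 : 1;
}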
XlaOp ConvWithGeneralPadding( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + absl::Span> padding, int64 feature_group_count = 1, const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided dimension numbers configuration. XlaOp ConvWithGeneralDimensions( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers, + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration as well as the dimension numbers. XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, + absl::Span window_strides, + absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, const PrecisionConfigProto* precision_config_proto = nullptr); @@ -1712,11 +1671,9 @@ XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration, dilation factors and dimension numbers. XlaOp ConvGeneralDilated( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, const PrecisionConfigProto* precision_config_proto = nullptr); @@ -1724,7 +1681,7 @@ XlaOp ConvGeneralDilated( // Enqueues an FFT instruction onto the computation, of the given type and // with the given FFT length. XlaOp Fft(const XlaOp& operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); + absl::Span fft_length); // Enqueues an infeed instruction onto the computation, which writes data of // the given shape to the infeed buffer of the device. @@ -1756,15 +1713,14 @@ XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, // Enqueues a call instruction onto the computation. XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, - tensorflow::gtl::ArraySlice operands); + absl::Span operands); // Enqueues a custom call instruction onto the computation. // During code generation, a call instruction is emitted which targets a // symbol with the name |call_target_name|. The |operands| are passed to the // call instruction. |shape| is the resultant shape. XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape); + absl::Span operands, const Shape& shape); // The following methods enqueue element-wise binary arithmetic operations // onto the computation. The shapes of the operands have to match unless one @@ -1773,72 +1729,70 @@ XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, // Enqueues a complex compose instruction onto the computation. 
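The *WithGeneralPadding and ConvGeneral* overloads above take the per-dimension (low, high) edge padding produced by MakePadding() as a span of pairs; presumably the element type is a std::pair of two int64 values, matching the format MakePadding() is documented to return. A sketch of consuming such a span under that assumption; TotalPadding is hypothetical.

#include <cstdint>
#include <utility>
#include <vector>
#include "absl/types/span.h"

using int64 = std::int64_t;  // stand-in for TensorFlow's int64 typedef

// Sums the low and high edge padding over all dimensions.
int64 TotalPadding(absl::Span<const std::pair<int64, int64>> padding) {
  int64 total = 0;
  for (const auto& p : padding) total += p.first + p.second;
  return total;
}

int main() {
  std::vector<std::pair<int64, int64>> padding = {{0, 1}, {2, 2}};
  return TotalPadding(padding) == 5 ? 0 : 1;
}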
XlaOp Complex(const XlaOp& real, const XlaOp& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a complex conjugate instruction onto the computation. XlaOp Conj(const XlaOp& operand); // Enqueues an add instruction onto the computation. XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a subtract instruction onto the computation. XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a multiply instruction onto the computation. XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a divide instruction onto the computation. XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a remainder instruction onto the computation. XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a max instruction onto the computation. XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a min instruction onto the computation. XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Element-wise logical operators XlaOp And(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); XlaOp Not(const XlaOp& operand); XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); -XlaOp ShiftRightArithmetic( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); -XlaOp ShiftRightLogical( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); +XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions = {}); +XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions = {}); // Reduces an array among the provided dimensions, given "computation" as a // reduction operator. XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); + absl::Span dimensions_to_reduce); // Reduces several arrays simultaneously among the provided dimensions, given // "computation" as a reduction operator. 
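All of the element-wise binary ops above default broadcast_dimensions to {}, which for an absl::Span of const int64 simply means an empty but valid view; callers either omit the argument or pass dimensions explicitly. A tiny sketch of that default; CountDims is hypothetical.

#include <cstdint>
#include "absl/types/span.h"

using int64 = std::int64_t;  // stand-in for TensorFlow's int64 typedef

// A default of {} yields an empty span, mirroring the broadcast_dimensions
// defaults in the declarations above.
int64 CountDims(absl::Span<const int64> broadcast_dimensions = {}) {
  return static_cast<int64>(broadcast_dimensions.size());
}

int main() {
  return (CountDims() == 0 && CountDims({0, 2}) == 2) ? 0 : 1;
}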
-XlaOp Reduce(XlaBuilder* builder, tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice init_values, +XlaOp Reduce(XlaBuilder* builder, absl::Span operands, + absl::Span init_values, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); + absl::Span dimensions_to_reduce); // Convenience wrapper around the above that reduces all the dimensions in the // operand shape. @@ -1848,25 +1802,23 @@ XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, // Enqueues a windowed reduce instruction onto the computation. XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - Padding padding); + absl::Span window_dimensions, + absl::Span window_strides, Padding padding); // As ReduceWindow(), but the padding is given in the format // returned by MakePadding(). XlaOp ReduceWindowWithGeneralPadding( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding); // Returns the sum of the operand value within each subgroup of replicas. All // replicas supply one input to the sum and all replicas receive the resulting // sum for each subgroup. -XlaOp CrossReplicaSum( - const XlaOp& operand, - tensorflow::gtl::ArraySlice replica_groups = {}); +XlaOp CrossReplicaSum(const XlaOp& operand, + absl::Span replica_groups = {}); // Enqueues an operation that do an AllReduce of the operand cross cores. Here // AllReduce means doing a reduction on the input operand cross cores and then @@ -1887,7 +1839,7 @@ XlaOp CrossReplicaSum( // TODO(b/79737069): Rename this to AllReduce when it's ready to use. XlaOp CrossReplicaSum( const XlaOp& operand, const XlaComputation& computation, - tensorflow::gtl::ArraySlice replica_groups = {}, + absl::Span replica_groups = {}, const absl::optional& channel_id = absl::nullopt); // Enqueues an operation that do an Alltoall of the operand cross cores. @@ -1910,27 +1862,26 @@ XlaOp CollectivePermute( // Enqueues an operation that scatters the `source` array to the selected // indices of each window. XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - Padding padding, const XlaOp& source, - const XlaOp& init_value, const XlaComputation& scatter); + absl::Span window_dimensions, + absl::Span window_strides, Padding padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter); // As SelectAndScatter(), but the padding is given in the format // returned by MakePadding(). XlaOp SelectAndScatterWithGeneralPadding( const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, const XlaOp& source, + const XlaOp& init_value, const XlaComputation& scatter); // Enqueues an abs instruction onto the computation. XlaOp Abs(const XlaOp& operand); // Enqueues a atan2 instruction onto the computation. 
XlaOp Atan2(const XlaOp& y, const XlaOp& x, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues an exp instruction onto the computation. XlaOp Exp(const XlaOp& operand); @@ -1977,7 +1928,7 @@ XlaOp Imag(const XlaOp& operand); // Enqueues a lhs^rhs computation onto the computation. XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues an operator that tests if the operand's values are finite, i.e., // not Inf or NaN. Defined only for floating-point types. Returns an array of @@ -2005,13 +1956,12 @@ XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type); XlaOp Neg(const XlaOp& operand); // Enqueues a transpose instruction onto the computation. -XlaOp Transpose(const XlaOp& operand, - tensorflow::gtl::ArraySlice permutation); +XlaOp Transpose(const XlaOp& operand, absl::Span permutation); // Enqueues a reverse instruction onto the computation. The order of the // elements in the given dimensions is reversed (i.e., the element at index i // is moved to index dimension_size - 1 - i). -XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions); +XlaOp Rev(const XlaOp& operand, absl::Span dimensions); // Enqueues a sort (as increasing order) instruction onto the computation. // If only keys are provided: @@ -2036,10 +1986,9 @@ XlaOp Sort(XlaOp keys, absl::optional values = absl::nullopt, XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); // Enqueues a map instruction onto the computation. -XlaOp Map(XlaBuilder* builder, tensorflow::gtl::ArraySlice operands, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands = {}); +XlaOp Map(XlaBuilder* builder, absl::Span operands, + const XlaComputation& computation, absl::Span dimensions, + absl::Span static_operands = {}); // Enqueues a N(mu, sigma) random number generation instruction onto the // computation. @@ -2066,7 +2015,7 @@ XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, // Enqueues a Gather node onto the computation. XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); // Enqueues a Scatter node onto the computation. XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, @@ -2124,7 +2073,7 @@ XlaOp CreateToken(XlaBuilder* builder); // Enqueues an AfterAll instruction which produces a token-shaped value and // takes a variadic number of token-shaped operands. The number of operands must // be greater than zero. Used for joining tokens. -XlaOp AfterAll(XlaBuilder* builder, tensorflow::gtl::ArraySlice tokens); +XlaOp AfterAll(XlaBuilder* builder, absl::Span tokens); // Normalizes operand across spatial and batch dimensions for each feature. 
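Transpose above takes its permutation as a span as well; assuming the convention spelled out later in this patch for LiteralBase::Transpose, output dimension i is drawn from input dimension permutation[i]. A sketch of that indexing with spans; Permute is a hypothetical helper, not the XLA op itself.

#include <cstddef>
#include <cstdint>
#include <vector>
#include "absl/types/span.h"

using int64 = std::int64_t;  // stand-in for TensorFlow's int64 typedef

// Output dimension i comes from input dimension permutation[i].
std::vector<int64> Permute(absl::Span<const int64> dims,
                           absl::Span<const int64> permutation) {
  std::vector<int64> out(permutation.size());
  for (size_t i = 0; i < permutation.size(); ++i) {
    out[i] = dims[permutation[i]];
  }
  return out;
}

int main() {
  // [3 x 8 x 4] permuted by {2, 0, 1} yields [4 x 3 x 8], matching the
  // example given for LiteralBase::Transpose further down in this patch.
  std::vector<int64> out = Permute({3, 8, 4}, {2, 0, 1});
  return (out[0] == 4 && out[1] == 3 && out[2] == 8) ? 0 : 1;
}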
// @@ -2172,7 +2121,7 @@ XlaOp XlaBuilder::ConstantR0(NativeT value) { } template -XlaOp XlaBuilder::ConstantR1(tensorflow::gtl::ArraySlice values) { +XlaOp XlaBuilder::ConstantR1(absl::Span values) { return ConstantLiteral(*LiteralUtil::CreateR1(values)); } @@ -2249,8 +2198,7 @@ XlaOp ConstantR0(XlaBuilder* builder, NativeT value) { } template -XlaOp ConstantR1(XlaBuilder* builder, - tensorflow::gtl::ArraySlice values) { +XlaOp ConstantR1(XlaBuilder* builder, absl::Span values) { return ConstantLiteral(builder, *LiteralUtil::CreateR1(values)); } diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index 693dcb3a3ee..3fadabcf520 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -27,7 +27,7 @@ limitations under the License. namespace xla { /* static */ int64 IndexUtil::MultidimensionalIndexToLinearIndex( - const Shape& shape, tensorflow::gtl::ArraySlice multi_index) { + const Shape& shape, absl::Span multi_index) { DCHECK_EQ(shape.dimensions_size(), multi_index.size()); // Padding and nested layouts not supported yet. DCHECK_EQ(0, shape.layout().padded_dimensions_size()); @@ -118,8 +118,8 @@ namespace xla { return multi_index; } -/* static */ bool IndexUtil::BumpIndices( - const Shape& shape, tensorflow::gtl::MutableArraySlice indices) { +/* static */ bool IndexUtil::BumpIndices(const Shape& shape, + absl::Span indices) { for (int64 dimno = indices.size() - 1; dimno >= 0; --dimno) { int64 limit = shape.dimensions(dimno); if (indices[dimno] + 1 < limit) { @@ -149,8 +149,8 @@ namespace xla { return stride; } -/* static */ bool IndexUtil::IndexInBounds( - const Shape& shape, tensorflow::gtl::ArraySlice index) { +/* static */ bool IndexUtil::IndexInBounds(const Shape& shape, + absl::Span index) { int64 rank = ShapeUtil::Rank(shape); if (rank != index.size()) { return false; @@ -163,9 +163,8 @@ namespace xla { return true; } -/* static */ int IndexUtil::CompareIndices( - tensorflow::gtl::ArraySlice lhs, - tensorflow::gtl::ArraySlice rhs) { +/* static */ int IndexUtil::CompareIndices(absl::Span lhs, + absl::Span rhs) { int64 rank = lhs.size(); CHECK_EQ(rhs.size(), rank); for (int64 dim = 0; dim < rank; ++dim) { diff --git a/tensorflow/compiler/xla/index_util.h b/tensorflow/compiler/xla/index_util.h index 142006f2626..d86437f83ca 100644 --- a/tensorflow/compiler/xla/index_util.h +++ b/tensorflow/compiler/xla/index_util.h @@ -35,7 +35,7 @@ class IndexUtil { // on the shape and its layout. The first index in the multi_index is // dimension 0. static int64 MultidimensionalIndexToLinearIndex( - const Shape& shape, tensorflow::gtl::ArraySlice multi_index); + const Shape& shape, absl::Span multi_index); // Converts a linear index into multidimensional index (eg {x, y, z}) based on // the shape and its layout. The first index in the returned multidimensional @@ -58,8 +58,7 @@ class IndexUtil { // // Returns true iff the indices were successfully bumped; false if we've hit // the limit where it can no longer be bumped in-bounds. - static bool BumpIndices(const Shape& shape, - tensorflow::gtl::MutableArraySlice indices); + static bool BumpIndices(const Shape& shape, absl::Span indices); // Calculates the stride size (in number of elements, not byte size) of a // given logical shape dimension (from 0 to rank-1). If available, padded @@ -71,15 +70,14 @@ class IndexUtil { // Returns true iff the given multi-index is contained in the bounds for the // shape. 
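BumpIndices in the index_util hunks above is the clearest example of the other half of the rename: tensorflow::gtl::MutableArraySlice becomes absl::Span with a non-const element type, and writes through the span land in the caller's storage. A small odometer-style sketch of that behavior, assuming int64 elements as in BumpIndices; Bump and its fixed limits argument are hypothetical.

#include <cstdint>
#include <vector>
#include "absl/types/span.h"

using int64 = std::int64_t;  // stand-in for TensorFlow's int64 typedef

// Bumps a multi-index in place, last dimension varying fastest, and reports
// whether it stayed in bounds; mutation goes through absl::Span<int64>.
bool Bump(absl::Span<int64> indices, absl::Span<const int64> limits) {
  for (int64 i = static_cast<int64>(indices.size()) - 1; i >= 0; --i) {
    if (indices[i] + 1 < limits[i]) {
      ++indices[i];
      return true;
    }
    indices[i] = 0;
  }
  return false;
}

int main() {
  std::vector<int64> idx = {0, 1};
  bool ok = Bump(absl::MakeSpan(idx), {2, 2});  // idx becomes {1, 0}
  return (ok && idx[0] == 1 && idx[1] == 0) ? 0 : 1;
}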
- static bool IndexInBounds(const Shape& shape, - tensorflow::gtl::ArraySlice index); + static bool IndexInBounds(const Shape& shape, absl::Span index); // Compares the given indices in lexicographic order. lhs[0] and rhs[0] are // compared first, and lhs[rank-1] and rhs[rank-1] last. If lhs is larger, // then -1 is returned. If rhs is larger, then 1 is returned. Otherwise, 0 is // returned. - static int CompareIndices(tensorflow::gtl::ArraySlice lhs, - tensorflow::gtl::ArraySlice rhs); + static int CompareIndices(absl::Span lhs, + absl::Span rhs); private: TF_DISALLOW_COPY_AND_ASSIGN(IndexUtil); diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index cce1838ef35..d310335618d 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -56,7 +56,7 @@ void SetDefaultLayoutToContainer( } // namespace /* static */ Layout LayoutUtil::MakeLayout( - tensorflow::gtl::ArraySlice minor_to_major) { + absl::Span minor_to_major) { Layout layout; layout.set_format(DENSE); for (int64 dimension_number : minor_to_major) { @@ -66,7 +66,7 @@ void SetDefaultLayoutToContainer( } /* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor( - tensorflow::gtl::ArraySlice major_to_minor) { + absl::Span major_to_minor) { Layout layout; layout.set_format(DENSE); for (int i = major_to_minor.size() - 1; i >= 0; i--) { @@ -307,7 +307,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return false; } -/* static */ tensorflow::gtl::ArraySlice LayoutUtil::PaddedDimensions( +/* static */ absl::Span LayoutUtil::PaddedDimensions( const Shape& shape) { CHECK(IsDenseArray(shape)); return AsInt64Slice(shape.layout().padded_dimensions()); @@ -363,13 +363,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return protobuf_util::ProtobufEquals(lhs, rhs); } -/* static */ tensorflow::gtl::ArraySlice LayoutUtil::MinorToMajor( +/* static */ absl::Span LayoutUtil::MinorToMajor( const Shape& shape) { CHECK(IsDenseArray(shape)); return AsInt64Slice(shape.layout().minor_to_major()); } -/* static */ tensorflow::gtl::ArraySlice LayoutUtil::MinorToMajor( +/* static */ absl::Span LayoutUtil::MinorToMajor( const Layout& layout) { CHECK(layout.format() == DENSE); return AsInt64Slice(layout.minor_to_major()); @@ -472,7 +472,7 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { } /* static */ bool LayoutUtil::AreDimensionsConsecutive( - const Layout& layout, tensorflow::gtl::ArraySlice dims) { + const Layout& layout, absl::Span dims) { CHECK(IsDense(layout)); std::vector positions_in_layout; for (int64 dim : dims) { diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 739bbe73675..2ffef8688d9 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -34,11 +34,11 @@ class LayoutUtil { public: // Creates a layout with the given minor-to-major dimension order. (This is a // convenience function for protobuf construction.) - static Layout MakeLayout(tensorflow::gtl::ArraySlice minor_to_major); + static Layout MakeLayout(absl::Span minor_to_major); // Similar to MakeLayout, but take indices in reverse order. static Layout MakeLayoutFromMajorToMinor( - tensorflow::gtl::ArraySlice major_to_minor); + absl::Span major_to_minor); // Creates a sparse layout with the given maximum number of elements. (This is // a convenience function for protobuf construction.) @@ -104,8 +104,7 @@ class LayoutUtil { // Returns the padded_dimensions array for the given Shape. 
Requires that the // shape is an array and has a dense layout. - static tensorflow::gtl::ArraySlice PaddedDimensions( - const Shape& shape); + static absl::Span PaddedDimensions(const Shape& shape); // Returns the given index of the padded_dimensions array for the given Shape. // Requires that the shape is an array and has a dense layout. @@ -138,8 +137,8 @@ class LayoutUtil { // Returns the minor_to_major array for the given Shape. Requires that the // shape is an array and has a dense layout. - static tensorflow::gtl::ArraySlice MinorToMajor(const Shape& shape); - static tensorflow::gtl::ArraySlice MinorToMajor(const Layout& layout); + static absl::Span MinorToMajor(const Shape& shape); + static absl::Span MinorToMajor(const Layout& layout); // Major(0) is the most major logical dimension number, Major(1) is the // second-most-major logical dimension number and so on. @@ -196,7 +195,7 @@ class LayoutUtil { // Returns whether the given dimensions are consecutive in the given layout, // not necessarily in the order given. static bool AreDimensionsConsecutive(const Layout& layout, - tensorflow::gtl::ArraySlice dims); + absl::Span dims); // Compute a hash for `layout`. static size_t Hash(const Layout& layout); diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index e4c825450dc..f25dae6ff41 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -27,15 +27,15 @@ namespace { class LayoutUtilTest : public ::testing::Test { protected: Shape MakeShapeWithLayout(PrimitiveType element_type, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice minor_to_major) { + absl::Span dimensions, + absl::Span minor_to_major) { Shape shape = ShapeUtil::MakeShape(element_type, dimensions); *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); return shape; } Shape MakeShapeWithSparseLayout(PrimitiveType element_type, - tensorflow::gtl::ArraySlice dimensions, + absl::Span dimensions, int64 max_sparse_elements) { Shape shape = ShapeUtil::MakeShape(element_type, dimensions); *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements); diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 2fc3613650e..3f7635bd400 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -73,7 +73,7 @@ std::ostream& operator<<(std::ostream& out, const Literal& literal) { MutableLiteralBase::StrideConfig::StrideConfig( const Shape& source_shape, const Shape& dest_shape, - tensorflow::gtl::ArraySlice dimensions) + absl::Span dimensions) : dimensions(dimensions), base(dimensions.size(), 0), step(dimensions.size(), 1) { @@ -197,14 +197,13 @@ SparseIndexArray* MutableLiteralBase::sparse_indices( template Status MutableLiteralBase::CopySliceFromInternal( - const LiteralBase& src_literal, tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size) { + const LiteralBase& src_literal, absl::Span src_base, + absl::Span dest_base, absl::Span copy_size) { TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size()); TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size()); auto linear_index = [](const Shape& shape, - tensorflow::gtl::ArraySlice multi_index) { + absl::Span multi_index) { return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index); }; @@ -232,7 +231,7 @@ Status MutableLiteralBase::CopySliceFromInternal( 
MutableLiteralBase::StrideConfig stride_config(src_literal.shape(), shape(), copy_size); - auto copy_proc = [&](tensorflow::gtl::ArraySlice indexes) { + auto copy_proc = [&](absl::Span indexes) { // Map from multi-dimensional index, to source index. std::transform(indexes.begin(), indexes.end(), src_base.begin(), src_indexes.begin(), std::plus()); @@ -257,10 +256,9 @@ Status MutableLiteralBase::CopySliceFromInternal( return Status::OK(); } -Status MutableLiteralBase::CopyElementFrom( - const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_index, - tensorflow::gtl::ArraySlice dest_index) { +Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, + absl::Span src_index, + absl::Span dest_index) { DCHECK_EQ(shape().element_type(), src_literal.shape().element_type()); const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex( src_literal.shape(), src_index); @@ -355,9 +353,9 @@ namespace { // Copies the elements in 'src' to 'dest'. The shape and layout of the data in // the array slices are indicated by dest_shape and src_shape respectively. template -void CopyElementsBetween(tensorflow::gtl::MutableArraySlice dest, - tensorflow::gtl::ArraySlice src, - const Shape& dest_shape, const Shape& src_shape) { +void CopyElementsBetween(absl::Span dest, + absl::Span src, const Shape& dest_shape, + const Shape& src_shape) { CHECK(ShapeUtil::Compatible(dest_shape, src_shape)); if (ShapeUtil::IsZeroElementArray(dest_shape)) { return; @@ -487,11 +485,10 @@ Status Literal::MoveFrom(Literal&& src_literal, return Status::OK(); } -Status MutableLiteralBase::CopySliceFrom( - const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size) { +Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal, + absl::Span src_base, + absl::Span dest_base, + absl::Span copy_size) { TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape()); TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape())) << ShapeUtil::HumanString(src_literal.shape()); @@ -591,8 +588,7 @@ std::unique_ptr LiteralBase::Relayout( } StatusOr> LiteralBase::Broadcast( - const Shape& result_shape, - tensorflow::gtl::ArraySlice dimensions) const { + const Shape& result_shape, absl::Span dimensions) const { if (!ShapeUtil::IsArray(shape())) { return InvalidArgument("Broadcast only supports arrays."); } @@ -615,7 +611,7 @@ StatusOr> LiteralBase::Broadcast( ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); ShapeUtil::ForEachIndex( - result_shape, [&](tensorflow::gtl::ArraySlice output_index) { + result_shape, [&](absl::Span output_index) { for (int64 i = 0; i < dimensions.size(); ++i) { scratch_source_index[i] = output_index[dimensions[i]]; } @@ -632,7 +628,7 @@ StatusOr> LiteralBase::Broadcast( } StatusOr> LiteralBase::Reshape( - tensorflow::gtl::ArraySlice dimensions) const { + absl::Span dimensions) const { if (!ShapeUtil::IsArray(shape())) { return InvalidArgument("Reshape does not support tuples."); } @@ -661,7 +657,7 @@ StatusOr> LiteralBase::Reshape( } std::unique_ptr LiteralBase::Transpose( - tensorflow::gtl::ArraySlice permutation) const { + absl::Span permutation) const { CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) << "Given permutation is not a permutation of dimension numbers"; @@ -700,12 +696,11 @@ std::unique_ptr LiteralBase::Transpose( template std::unique_ptr 
LiteralBase::SliceInternal( - const Shape& result_shape, - tensorflow::gtl::ArraySlice start_indices) const { + const Shape& result_shape, absl::Span start_indices) const { auto result_literal = absl::make_unique(result_shape); DimensionVector new_indices(ShapeUtil::Rank(result_shape)); result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, NativeT /*value*/) { + [&](absl::Span indices, NativeT /*value*/) { for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { new_indices[i] = indices[i] + start_indices[i]; } @@ -716,8 +711,8 @@ std::unique_ptr LiteralBase::SliceInternal( } std::unique_ptr LiteralBase::Slice( - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices) const { + absl::Span start_indices, + absl::Span limit_indices) const { CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; DimensionVector result_dimensions; @@ -761,7 +756,7 @@ std::unique_ptr LiteralBase::CloneToUnique() const { return result; } -string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice multi_index, +string LiteralBase::GetAsString(absl::Span multi_index, const ShapeIndex& shape_index) const { const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); CHECK(LayoutUtil::IsDenseArray(subshape)); @@ -858,7 +853,7 @@ string LiteralBase::GetSparseElementAsString( } StatusOr LiteralBase::GetIntegralAsS64( - tensorflow::gtl::ArraySlice multi_index) const { + absl::Span multi_index) const { CHECK(LayoutUtil::IsDenseArray(shape())); switch (shape().element_type()) { case PRED: @@ -900,8 +895,8 @@ size_t LiteralBase::Hash() const { return hash_value; } -Status MutableLiteralBase::SetIntegralAsS64( - tensorflow::gtl::ArraySlice multi_index, int64 value) { +Status MutableLiteralBase::SetIntegralAsS64(absl::Span multi_index, + int64 value) { CHECK(LayoutUtil::IsDenseArray(shape())); switch (shape().element_type()) { case PRED: @@ -929,7 +924,7 @@ Status MutableLiteralBase::SetIntegralAsS64( return Status::OK(); } -tensorflow::gtl::ArraySlice LiteralBase::GetSparseIndex( +absl::Span LiteralBase::GetSparseIndex( int64 sparse_element_number, const ShapeIndex& shape_index) const { const Piece& p = piece(shape_index); CHECK_GE(sparse_element_number, 0); @@ -998,7 +993,7 @@ void LiteralBase::Piece::SortSparseElementsInternal() { auto values = data(); CHECK_LE(num_elements, values.size()); sparse_indices()->SortWithValues( - tensorflow::gtl::MutableArraySlice(values.data(), num_elements)); + absl::Span(values.data(), num_elements)); } namespace { @@ -1064,8 +1059,7 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, CHECK(LayoutUtil::IsDenseArray(subshape)); - auto element_to_string = - [&](tensorflow::gtl::ArraySlice indices) -> string { + auto element_to_string = [&](absl::Span indices) -> string { PrimitiveType element_type = subshape.element_type(); if (element_type == PRED) { // We display predicates in a densely packed form. 
@@ -1160,7 +1154,7 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, pieces->push_back(shape_to_string(subshape)); pieces->push_back(" {"); literal.EachCellAsString( - [&](tensorflow::gtl::ArraySlice indices, const string& value) { + [&](absl::Span indices, const string& value) { pieces->push_back(" "); pieces->push_back(value); }); @@ -1183,7 +1177,7 @@ string LiteralBase::ToString(bool print_layout) const { } void LiteralBase::EachCellAsString( - const std::function indices, + const std::function indices, const string& value)>& per_cell) const { if (ShapeUtil::IsZeroElementArray(shape())) { return; @@ -1250,10 +1244,8 @@ std::unique_ptr ConvertToC64(const LiteralBase& src_literal) { ShapeUtil::ChangeElementType(src_literal.shape(), C64)); using NativeSrcT = typename primitive_util::PrimitiveTypeToNative::type; - tensorflow::gtl::ArraySlice src_data = - src_literal.data(); - tensorflow::gtl::MutableArraySlice dest_data = - result_literal->data(); + absl::Span src_data = src_literal.data(); + absl::Span dest_data = result_literal->data(); int64 num_elements = src_literal.element_count(); for (int64 i = 0; i < num_elements; ++i) { dest_data[i] = complex64(static_cast(src_data[i]), 0); @@ -1397,7 +1389,7 @@ StatusOr> LiteralBase::ConvertToShape( } /* static */ Literal MutableLiteralBase::MoveIntoTuple( - tensorflow::gtl::MutableArraySlice elements) { + absl::Span elements) { std::vector element_shapes; for (const Literal& element : elements) { element_shapes.push_back(element.shape()); @@ -1488,7 +1480,7 @@ bool LiteralBase::operator==(const LiteralBase& other) const { namespace { template -static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice data, +static bool AllElementsEqualValue(absl::Span data, NativeT value) { for (int64 i = 0; i < data.size(); ++i) { if (data[i] != value) { @@ -1742,7 +1734,7 @@ bool LiteralBase::IsR1Iota() const { return true; } -bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice indices) const { +bool LiteralBase::IsZero(absl::Span indices) const { CHECK(ShapeUtil::IsArray(shape())); switch (shape().element_type()) { case U8: @@ -1778,7 +1770,7 @@ namespace { template void CopyToRepeatedField(RepeatedFieldT* dest, - const tensorflow::gtl::ArraySlice src) { + const absl::Span src) { *dest = RepeatedFieldT(src.begin(), src.end()); } @@ -1856,7 +1848,7 @@ void* LiteralBase::Piece::untyped_data() { namespace { template -Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice dest, +Status CopyFromRepeatedField(absl::Span dest, const RepeatedFieldT& src) { if (dest.size() != src.size()) { return InvalidArgument( @@ -2126,8 +2118,8 @@ BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) root_piece_.set_subshape(shape_.get()); } -BorrowingLiteral::BorrowingLiteral( - tensorflow::gtl::ArraySlice src_buf_ptrs, const Shape& shape) +BorrowingLiteral::BorrowingLiteral(absl::Span src_buf_ptrs, + const Shape& shape) : LiteralBase(), shape_(absl::make_unique(shape)) { CHECK(ShapeUtil::IsTuple(*shape_)); CHECK(!ShapeUtil::IsNestedTuple(*shape_)); diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index c6ef99db0ff..d690241e4ef 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -70,13 +70,12 @@ class LiteralBase { // Serialize to proto. LiteralProto ToProto() const; - // Returns an ArraySlice of the array for this literal for the given NativeT + // Returns a Span of the array for this literal for the given NativeT // (e.g., float). 
CHECKs if the subshape of the literal at the given // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type // to native type. template - tensorflow::gtl::ArraySlice data( - const ShapeIndex& shape_index = {}) const; + absl::Span data(const ShapeIndex& shape_index = {}) const; // Returns a const pointer to the sparse index array. Returns nullptr if the // literal is not a sparse array. @@ -100,12 +99,12 @@ class LiteralBase { // Gets an element in the literal at the given index. The multi_index is // CHECKed against the dimension sizes. template - NativeT Get(tensorflow::gtl::ArraySlice multi_index, + NativeT Get(absl::Span multi_index, const ShapeIndex& shape_index) const; // Overloads of Get for array literals. CHECKs if the literal is not // array-shaped and dense. template - NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; + NativeT Get(absl::Span multi_index) const; // Returns the element value at index (0, ..., 0), however many zeroes are // required for that index. @@ -114,7 +113,7 @@ class LiteralBase { // As Get(), but determines the correct type and converts the value // into text. - string GetAsString(tensorflow::gtl::ArraySlice multi_index, + string GetAsString(absl::Span multi_index, const ShapeIndex& shape_index = {}) const; // As GetSparseElement(), but determines the correct type and converts the // value into text. @@ -122,14 +121,13 @@ class LiteralBase { const ShapeIndex& shape_index = {}) const; // As Get(), but determines the correct type and converts the value into // int64. This literal must be an array. - StatusOr GetIntegralAsS64( - tensorflow::gtl::ArraySlice multi_index) const; + StatusOr GetIntegralAsS64(absl::Span multi_index) const; // Returns the multi-index of the element in a sparse literal at the given // sparse element number. The sparse element number is the position with in // the sparse array's list of (index, value) pairs, and is checked against the // total number of (index, value) pairs in the sparse array. - tensorflow::gtl::ArraySlice GetSparseIndex( + absl::Span GetSparseIndex( int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; // Returns the value of the element in a sparse literal at the given sparse @@ -150,12 +148,12 @@ class LiteralBase { // // This literal must have a dense layout. void EachCellAsString( - const std::function indices, + const std::function indices, const string& value)>& per_cell) const; template - void EachCell(std::function indices, - NativeT value)> - per_cell) const; + void EachCell( + std::function indices, NativeT value)> + per_cell) const; // Returns whether every element in this literal is equal to value. // @@ -200,7 +198,7 @@ class LiteralBase { // Returns whether this literal is zero at the specified index. This literal // must be an array with a dense layout. - bool IsZero(tensorflow::gtl::ArraySlice indices) const; + bool IsZero(absl::Span indices) const; // Returns the count of the elements in the array at the given shape index in // this literal. @@ -273,13 +271,12 @@ class LiteralBase { // implementation currently only supports monotonic dim0-major layouts. // This literal must be an array. StatusOr> Reshape( - tensorflow::gtl::ArraySlice dimensions) const; + absl::Span dimensions) const; // Creates a new literal by broadcasting this literal with `dimensions` to // yield a literal of shape `result_shape`. 
StatusOr> Broadcast( - const Shape& result_shape, - tensorflow::gtl::ArraySlice dimensions) const; + const Shape& result_shape, absl::Span dimensions) const; // Creates a new literal by reordering the dimensions of this literal. // The given `permutation` must be a permutation of the dimension numbers @@ -288,8 +285,7 @@ class LiteralBase { // For example, a transpose call on a literal of shape [3 x 8 x 4] and // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. // This literal must be an array. - std::unique_ptr Transpose( - tensorflow::gtl::ArraySlice permutation) const; + std::unique_ptr Transpose(absl::Span permutation) const; // Creates a sub-array from this literal by extracting the indices // [start_index, limit_index) of each dimension. The result literal has the @@ -297,9 +293,8 @@ class LiteralBase { // start_indices and limit_indices must be the rank of the literal, and the // indices follow the order of the dimensions. // This literal must be an array. - std::unique_ptr Slice( - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices) const; + std::unique_ptr Slice(absl::Span start_indices, + absl::Span limit_indices) const; // Creates a literal with a prepended dimension with bound "times"; e.g. a // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this @@ -328,9 +323,9 @@ class LiteralBase { // Returns the buffer holding the array data for this piece as an array // slice. This piece must be array-shaped. template - tensorflow::gtl::ArraySlice data() const; + absl::Span data() const; template - tensorflow::gtl::MutableArraySlice data(); + absl::Span data(); // Returns the buffer holding the array data for this piece as a void*. This // piece must be array-shaped. @@ -341,9 +336,9 @@ class LiteralBase { // is CHECKed against the dimension sizes of the array. This piece must be // array-shaped. template - NativeT Get(tensorflow::gtl::ArraySlice index) const; + NativeT Get(absl::Span index) const; template - void Set(tensorflow::gtl::ArraySlice index, NativeT value); + void Set(absl::Span index, NativeT value); // Gets/sets the buffer holding the array data. char* buffer() const { return buffer_; } @@ -545,8 +540,7 @@ class LiteralBase { private: template std::unique_ptr SliceInternal( - const Shape& result_shape, - tensorflow::gtl::ArraySlice start_indices) const; + const Shape& result_shape, absl::Span start_indices) const; }; // Abstract base class representing a mutable literal in XLA. @@ -554,13 +548,12 @@ class MutableLiteralBase : public LiteralBase { public: virtual ~MutableLiteralBase() = 0; - // Returns a MutableArraySlice view of the array for this literal for the + // Returns a Span view of the array for this literal for the // given NativeT (e.g., float). CHECKs if the subshape of the literal at the // given ShapeIndex is not array. See primitive_util.h for the mapping from // XLA type to native type. template - tensorflow::gtl::MutableArraySlice data( - const ShapeIndex& shape_index = {}); + absl::Span data(const ShapeIndex& shape_index = {}); // Unhide const method from parent class. using LiteralBase::data; @@ -587,8 +580,7 @@ class MutableLiteralBase : public LiteralBase { // are populated. template void PopulateSparse(SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, - bool sort = true); + absl::Span values, bool sort = true); // Copy values from 'src_literal' rooted at 'src_shape_index' into this // literal rooted at 'dest_shape_index'. 
The subshape of this literal rooted @@ -609,39 +601,38 @@ class MutableLiteralBase : public LiteralBase { // corresponding base indices being 0. // This literal and 'src_literal' must be arrays. Status CopySliceFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); + absl::Span src_base, + absl::Span dest_base, + absl::Span copy_size); // Copies one element from src_literal[src_index] to (*this)[dest_index]. Status CopyElementFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_index, - tensorflow::gtl::ArraySlice dest_index); + absl::Span src_index, + absl::Span dest_index); // Sets an element in the literal at the given index. The multi_index is // CHECKed against the dimension sizes. template - void Set(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index, NativeT value); + void Set(absl::Span multi_index, const ShapeIndex& shape_index, + NativeT value); // Overloads of Set for array literals. CHECKs if the literal is not // array-shaped and dense. template - void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); + void Set(absl::Span multi_index, NativeT value); // Appends the given element to the literal. If the elements are not appended // in sorted order, then SortSparseElements should be called before calling // other methods. This literal must have a sparse layout. template - void AppendSparseElement(tensorflow::gtl::ArraySlice multi_index, - NativeT value, const ShapeIndex& shape_index = {}); + void AppendSparseElement(absl::Span multi_index, NativeT value, + const ShapeIndex& shape_index = {}); // Sorts the elements in a sparse array. void SortSparseElements(const ShapeIndex& shape_index = {}); // As Set(), but truncates `value` to the literal element type before storing. // This literal must be an array. - Status SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, - int64 value); + Status SetIntegralAsS64(absl::Span multi_index, int64 value); // Populate this literal with the given values. Examples: // @@ -656,7 +647,7 @@ class MutableLiteralBase : public LiteralBase { // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2 // array of S32. template - void PopulateR1(tensorflow::gtl::ArraySlice values); + void PopulateR1(absl::Span values); void PopulateR1(const tensorflow::core::Bitmap& values); template void PopulateR2(std::initializer_list> values); @@ -673,7 +664,7 @@ class MutableLiteralBase : public LiteralBase { // in this literal object. // // generator must be a callable of the type - // NativeT(tensorflow::gtl::ArraySlice indexes) or compatible. + // NativeT(absl::Span indexes) or compatible. // // This literal must have a dense layout. template @@ -693,8 +684,7 @@ class MutableLiteralBase : public LiteralBase { // moved into the tuple elements of a new tuple-shaped Literal which is // returned. Upon return, each of the Literals in 'elements' is set to a nil // shape (empty tuple). - static Literal MoveIntoTuple( - tensorflow::gtl::MutableArraySlice elements); + static Literal MoveIntoTuple(absl::Span elements); // Serialize from a proto. static StatusOr> CreateFromProto( @@ -712,20 +702,20 @@ class MutableLiteralBase : public LiteralBase { // arguments one by one. 
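Besides the implicit conversions from containers, the patch also constructs spans explicitly from a raw pointer and an element count, which is how the Piece::data() accessors further down wrap the literal's backing buffer. A minimal sketch of that constructor, plus absl::MakeSpan/MakeConstSpan as the usual shorthands; the buffer here is just a local array, not the literal's storage.

#include <cstdint>
#include "absl/types/span.h"

using int64 = std::int64_t;  // stand-in for TensorFlow's int64 typedef

int main() {
  int64 buffer[4] = {1, 2, 3, 4};

  // Pointer + element count, as in the Piece::data() definitions below.
  absl::Span<int64> mutable_view(buffer, 4);
  mutable_view[2] = 30;  // writes through to buffer

  // MakeConstSpan/MakeSpan deduce the element type from the argument.
  absl::Span<const int64> read_view = absl::MakeConstSpan(buffer);

  return read_view[2] == 30 ? 0 : 1;
}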
template Status CopySliceFromInternal(const LiteralBase& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); + absl::Span src_base, + absl::Span dest_base, + absl::Span copy_size); // Utility structure which is used to create the optimal configuration for // a ShapeUtil::ForEachIndex() scan across two literals. struct StrideConfig { StrideConfig(const Shape& source_shape, const Shape& dest_shape, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); // The dimensions of the stride operation. Essentially every dimension // will be iterated from base[i] to base[i]+dimensions[i], in step[i] // steps. - tensorflow::gtl::ArraySlice dimensions; + absl::Span dimensions; DimensionVector base; DimensionVector step; int64 minor_dimension = 0; @@ -854,7 +844,7 @@ class BorrowingLiteral : public LiteralBase { // This constructor is only used for array shapes. BorrowingLiteral(const char* src_buf_ptr, const Shape& shape); // Similar as above, except to be used for constructing non-nested tuples. - BorrowingLiteral(tensorflow::gtl::ArraySlice src_buf_ptrs, + BorrowingLiteral(absl::Span src_buf_ptrs, const Shape& shape); // TODO(b/79707221): adding constructors for nested tuples as well. @@ -874,7 +864,7 @@ class BorrowingLiteral : public LiteralBase { }; template -tensorflow::gtl::ArraySlice LiteralBase::Piece::data() const { +absl::Span LiteralBase::Piece::data() const { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); CHECK_EQ(subshape().element_type(), primitive_util::NativeToPrimitiveType()) @@ -882,12 +872,12 @@ tensorflow::gtl::ArraySlice LiteralBase::Piece::data() const { << PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) << " type, but literal element type is " << PrimitiveType_Name(subshape().element_type()); - return tensorflow::gtl::ArraySlice( - reinterpret_cast(buffer()), element_count()); + return absl::Span(reinterpret_cast(buffer()), + element_count()); } template -tensorflow::gtl::MutableArraySlice LiteralBase::Piece::data() { +absl::Span LiteralBase::Piece::data() { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); CHECK_EQ(subshape().element_type(), primitive_util::NativeToPrimitiveType()) @@ -895,20 +885,19 @@ tensorflow::gtl::MutableArraySlice LiteralBase::Piece::data() { << PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) << " type, but literal element type is " << PrimitiveType_Name(subshape().element_type()); - return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(buffer()), element_count()); + return absl::Span(reinterpret_cast(buffer()), + element_count()); } template -NativeT LiteralBase::Piece::Get( - tensorflow::gtl::ArraySlice multi_index) const { +NativeT LiteralBase::Piece::Get(absl::Span multi_index) const { CHECK(LayoutUtil::IsDenseArray(subshape())); return data()[IndexUtil::MultidimensionalIndexToLinearIndex( subshape(), multi_index)]; } template -void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice multi_index, +void LiteralBase::Piece::Set(absl::Span multi_index, NativeT value) { CHECK(LayoutUtil::IsDenseArray(subshape())); data()[IndexUtil::MultidimensionalIndexToLinearIndex( @@ -916,39 +905,37 @@ void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice multi_index, } template -tensorflow::gtl::ArraySlice LiteralBase::data( +absl::Span LiteralBase::data( const ShapeIndex& shape_index) const { return piece(shape_index).data(); } template -tensorflow::gtl::MutableArraySlice 
MutableLiteralBase::data( - const ShapeIndex& shape_index) { +absl::Span MutableLiteralBase::data(const ShapeIndex& shape_index) { return piece(shape_index).data(); } template -inline NativeT LiteralBase::Get(tensorflow::gtl::ArraySlice multi_index, +inline NativeT LiteralBase::Get(absl::Span multi_index, const ShapeIndex& shape_index) const { return piece(shape_index).Get(multi_index); } template -inline NativeT LiteralBase::Get( - tensorflow::gtl::ArraySlice multi_index) const { +inline NativeT LiteralBase::Get(absl::Span multi_index) const { return root_piece().Get(multi_index); } template -inline void MutableLiteralBase::Set( - tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index, NativeT value) { +inline void MutableLiteralBase::Set(absl::Span multi_index, + const ShapeIndex& shape_index, + NativeT value) { return piece(shape_index).Set(multi_index, value); } template -inline void MutableLiteralBase::Set( - tensorflow::gtl::ArraySlice multi_index, NativeT value) { +inline void MutableLiteralBase::Set(absl::Span multi_index, + NativeT value) { return root_piece().Set(multi_index, value); } @@ -967,7 +954,7 @@ NativeT LiteralBase::GetSparseElement(int64 sparse_element_number, template void MutableLiteralBase::AppendSparseElement( - tensorflow::gtl::ArraySlice multi_index, NativeT value, + absl::Span multi_index, NativeT value, const ShapeIndex& shape_index) { Piece& p = piece(shape_index); const Shape& subshape = p.subshape(); @@ -983,8 +970,7 @@ void MutableLiteralBase::AppendSparseElement( template void LiteralBase::EachCell( - std::function indices, - NativeT value)> + std::function indices, NativeT value)> per_cell) const { if (ShapeUtil::IsZeroElementArray(shape())) { return; @@ -996,8 +982,7 @@ void LiteralBase::EachCell( } template -inline void MutableLiteralBase::PopulateR1( - tensorflow::gtl::ArraySlice values) { +inline void MutableLiteralBase::PopulateR1(absl::Span values) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(ShapeUtil::Rank(shape()), 1); CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size()); @@ -1042,8 +1027,9 @@ void MutableLiteralBase::PopulateFromArray(const Array& values) { for (int dim = 0; dim < values.num_dimensions(); ++dim) { CHECK_EQ(values.dim(dim), shape().dimensions(dim)); } - values.Each([this](tensorflow::gtl::ArraySlice indices, - NativeT value) { this->Set(indices, value); }); + values.Each([this](absl::Span indices, NativeT value) { + this->Set(indices, value); + }); } template @@ -1062,9 +1048,9 @@ void MutableLiteralBase::PopulateR4FromArray4D(const Array4D& values) { } template -void MutableLiteralBase::PopulateSparse( - SparseIndexArray indices, tensorflow::gtl::ArraySlice values, - bool sort) { +void MutableLiteralBase::PopulateSparse(SparseIndexArray indices, + absl::Span values, + bool sort) { CHECK(LayoutUtil::IsSparseArray(shape())); int rank = ShapeUtil::Rank(shape()); CHECK_EQ(indices.rank(), rank); @@ -1074,7 +1060,7 @@ void MutableLiteralBase::PopulateSparse( CHECK_LE(num_elements, max_elements); CHECK_EQ(num_elements, indices.index_count()); auto root_data = root_piece().data(); - // Piece::data() returns an ArraySlice of size equal to the number of indices + // Piece::data() returns a Span of size equal to the number of indices // in the SparseIndexArray. So there is no need to adjust the size of the data // here. It is enough to just copy the incoming values into the data buffer. 
std::copy(values.begin(), values.end(), root_data.begin()); @@ -1094,14 +1080,14 @@ Status MutableLiteralBase::PopulateInternal(const FnType& generator, TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape)); TF_RET_CHECK(this_shape.element_type() == primitive_util::NativeToPrimitiveType()); - tensorflow::gtl::MutableArraySlice literal_data = data(); + absl::Span literal_data = data(); if (rank > 0) { StrideConfig stride_config(this_shape, this_shape, AsInt64Slice(this_shape.dimensions())); int64 minor_dimension_size = ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension); - auto init_function = [&](tensorflow::gtl::ArraySlice indexes) { + auto init_function = [&](absl::Span indexes) { DimensionVector minor_scan_indexes(rank, 0); const int64 index = IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes); @@ -1119,7 +1105,7 @@ Status MutableLiteralBase::PopulateInternal(const FnType& generator, ShapeUtil::ForEachIndex( this_shape, stride_config.base, stride_config.dimensions, stride_config.step, - [&init_function](tensorflow::gtl::ArraySlice indexes) { + [&init_function](absl::Span indexes) { init_function(indexes); return true; }); @@ -1165,7 +1151,7 @@ std::unique_ptr LiteralBase::Replicate(int64 times) const { } DimensionVector output_indices(bounds.size(), 0); - tensorflow::gtl::ArraySlice input_indices = output_indices; + absl::Span input_indices = output_indices; input_indices.remove_prefix(1); bool done = false; diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index f6ce69eaee8..3d8725ed705 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -38,8 +38,8 @@ namespace { // between the left-hand-side and right-hand-side, by bit-casting to UnsignedT // -- on miscompare, a nice error message is given in the AssertionFailure. template -Status CompareFloatsBitwiseEqual( - FloatT lhs, FloatT rhs, tensorflow::gtl::ArraySlice multi_index) { +Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs, + absl::Span multi_index) { auto ulhs = tensorflow::bit_cast(lhs); auto urhs = tensorflow::bit_cast(rhs); auto lhs_double = static_cast(lhs); @@ -60,7 +60,7 @@ Status CompareFloatsBitwiseEqual( // default gunit implementation). template Status CompareEqual(NativeT lhs, NativeT rhs, - tensorflow::gtl::ArraySlice multi_index) { + absl::Span multi_index) { if (lhs == rhs) { return Status::OK(); } @@ -74,28 +74,27 @@ Status CompareEqual(NativeT lhs, NativeT rhs, // comparison is requested. 
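The Replicate() hunk above rebinds a span over existing storage and then calls remove_prefix(1) on it; absl::Span provides remove_prefix, remove_suffix, and subspan for trimming a view without copying or touching the underlying container. A sketch with illustrative names only.

#include <cstdint>
#include <vector>
#include "absl/types/span.h"

using int64 = std::int64_t;  // stand-in for TensorFlow's int64 typedef

int main() {
  std::vector<int64> output_indices = {9, 1, 2, 3};

  // As in Replicate() above: view the same storage, then drop the leading
  // element from the view; the vector itself is untouched.
  absl::Span<const int64> input_indices = output_indices;
  input_indices.remove_prefix(1);  // now views {1, 2, 3}

  // subspan gives a trimmed view in one step: two elements starting at 1.
  absl::Span<const int64> middle =
      absl::MakeConstSpan(output_indices).subspan(1, 2);

  return (input_indices.size() == 3 && input_indices[0] == 1 &&
          middle.size() == 2 && middle[0] == 1 && middle[1] == 2)
             ? 0
             : 1;
}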
template <> Status CompareEqual(bfloat16 lhs, bfloat16 rhs, - tensorflow::gtl::ArraySlice multi_index) { + absl::Span multi_index) { return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); } template <> -Status CompareEqual( - Eigen::half lhs, Eigen::half rhs, - tensorflow::gtl::ArraySlice multi_index) { +Status CompareEqual(Eigen::half lhs, Eigen::half rhs, + absl::Span multi_index) { return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); } template <> Status CompareEqual(float lhs, float rhs, - tensorflow::gtl::ArraySlice multi_index) { + absl::Span multi_index) { return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); } template <> Status CompareEqual(double lhs, double rhs, - tensorflow::gtl::ArraySlice multi_index) { + absl::Span multi_index) { return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); } template <> Status CompareEqual(complex64 lhs, complex64 rhs, - tensorflow::gtl::ArraySlice multi_index) { + absl::Span multi_index) { auto res = CompareEqual(lhs.real(), rhs.real(), multi_index); if (!res.ok()) { return res; @@ -108,8 +107,7 @@ Status CompareEqual(complex64 lhs, complex64 rhs, // elements are equal. template Status Equal(LiteralSlice expected, LiteralSlice actual, - tensorflow::gtl::MutableArraySlice multi_index, - int64 dimension) { + absl::Span multi_index, int64 dimension) { if (dimension == expected.shape().dimensions_size()) { NativeT expected_value = expected.Get(multi_index); NativeT actual_value = actual.Get(multi_index); @@ -305,8 +303,7 @@ class NearComparator { } // Insert the given error into the given error bucket vector. - void UpdateErrorBucket( - float error, tensorflow::gtl::MutableArraySlice error_buckets) { + void UpdateErrorBucket(float error, absl::Span error_buckets) { CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size()); for (int i = 0; i < error_buckets.size(); ++i) { if (error >= kErrorBucketBounds[i]) { @@ -410,10 +407,8 @@ class NearComparator { // Fast path optimization for the case were layouts match. if (LayoutUtil::Equal(actual_.shape().layout(), expected_.shape().layout())) { - tensorflow::gtl::ArraySlice expected_data = - expected_.data(); - tensorflow::gtl::ArraySlice actual_data = - actual_.data(); + absl::Span expected_data = expected_.data(); + absl::Span actual_data = actual_.data(); const int64 len = expected_data.size(); for (int64 i = 0; i < len; ++i) { CompareValues(expected_data[i], actual_data[i], i); @@ -488,7 +483,7 @@ class NearComparator { } auto print_accum_buckets = [&](const string& header, int64 total, - tensorflow::gtl::ArraySlice buckets) { + absl::Span buckets) { StrAppend(&out, header, ":\n"); StrAppendFormat(&out, " < %-6g : %7d (%s)\n", kErrorBucketBounds[0], total - buckets[0], diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index 5d7d4dbb36e..922196b1fd3 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -36,7 +36,6 @@ limitations under the License. 
namespace xla { namespace { -using tensorflow::gtl::ArraySlice; using ::testing::ElementsAre; using ::testing::HasSubstr; @@ -222,9 +221,9 @@ TEST_F(LiteralUtilTest, CreateSparse) { std::vector expected_values = {8, 9, 7, 10}; EXPECT_EQ(literal->sparse_indices()->data(), - ArraySlice(expected_indices.data(), - expected_indices.num_elements())); - EXPECT_EQ(literal->data(), ArraySlice(expected_values)); + absl::Span(expected_indices.data(), + expected_indices.num_elements())); + EXPECT_EQ(literal->data(), absl::Span(expected_values)); } TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { @@ -296,7 +295,7 @@ TEST_F(LiteralUtilTest, EachCellR2F32) { // clang-format on std::vector> seen; literal->EachCellAsString( - [&seen](ArraySlice indices, const string& value) { + [&seen](absl::Span indices, const string& value) { seen.emplace_back(indices[0], indices[1], value); }); @@ -649,7 +648,7 @@ TEST_F(LiteralUtilTest, TransposeR4) { // clang-format on auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1}); - reshape->EachCell([&](ArraySlice indices, float value) { + reshape->EachCell([&](absl::Span indices, float value) { EXPECT_EQ(value, original->Get( {indices[2], indices[3], indices[0], indices[1]})); }); @@ -889,7 +888,7 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { const int64 zero_base[] = {0, 0, 0, 0}; const int64 step[] = {1, 1, 1, 1}; uint32 seqnr = 0; - auto init_proc = [&](ArraySlice indexes) { + auto init_proc = [&](absl::Span indexes) { source->Set(indexes, ++seqnr); return true; }; @@ -905,7 +904,7 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { std::vector source_indexes(TF_ARRAYSIZE(dimensions), 0); std::vector blank_indexes(TF_ARRAYSIZE(dimensions), 0); bool matched = true; - auto check_proc = [&](ArraySlice indexes) { + auto check_proc = [&](absl::Span indexes) { std::copy(indexes.begin(), indexes.end(), source_indexes.begin()); std::transform(source_indexes.begin(), source_indexes.end(), src_base, source_indexes.begin(), std::plus()); @@ -1093,7 +1092,7 @@ TEST_F(LiteralUtilTest, Populate) { primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); auto literal = absl::make_unique(shape); - auto generator = [&](ArraySlice indexes) -> uint32 { + auto generator = [&](absl::Span indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(), @@ -1105,7 +1104,7 @@ TEST_F(LiteralUtilTest, Populate) { std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); bool matched = true; - auto check_function = [&](ArraySlice indexes) { + auto check_function = [&](absl::Span indexes) { auto value = literal->Get(indexes); matched = matched && (value == generator(indexes)); return matched; @@ -1135,7 +1134,7 @@ TEST_F(LiteralUtilTest, PopulateParallel) { primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); auto literal = absl::make_unique(shape); - auto generator = [&](ArraySlice indexes) -> uint32 { + auto generator = [&](absl::Span indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. 
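The test hunks above build Spans both from a container and from an explicit (pointer, length) pair and compare them against literal data. A standalone sketch of those constructions and of Span's element-wise comparison (not part of the patch):

#include <cstdint>
#include <vector>
#include "absl/types/span.h"

void SpanConstructionExample() {
  std::vector<uint32_t> expected = {8, 9, 7, 10};

  absl::Span<const uint32_t> from_container(expected);
  absl::Span<const uint32_t> from_pointer(expected.data(), expected.size());

  // absl::Span compares element-wise, which is what the EXPECT_EQ calls above
  // rely on when comparing literal data against an expected Span.
  const bool same = (from_container == from_pointer);  // true
  (void)same;
}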
return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(), @@ -1147,7 +1146,7 @@ TEST_F(LiteralUtilTest, PopulateParallel) { std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); bool matched = true; - auto check_function = [&](ArraySlice indexes) { + auto check_function = [&](absl::Span indexes) { auto value = literal->Get(indexes); matched = matched && (value == generator(indexes)); return matched; diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 931d2c631bc..613449cf10c 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -84,8 +84,7 @@ std::unique_ptr ConvertType(LiteralSlice literal) { } // namespace /* static */ std::unique_ptr LiteralUtil::CreateFromDimensions( - PrimitiveType primitive_type, - tensorflow::gtl::ArraySlice dimensions) { + PrimitiveType primitive_type, absl::Span dimensions) { return Literal::CreateFromShape( ShapeUtil::MakeShape(primitive_type, dimensions)); } @@ -301,9 +300,8 @@ std::unique_ptr ConvertType(LiteralSlice literal) { } /* static */ std::unique_ptr LiteralUtil::ReshapeSlice( - tensorflow::gtl::ArraySlice new_dimensions, - tensorflow::gtl::ArraySlice minor_to_major, - const LiteralSlice& literal) { + absl::Span new_dimensions, + absl::Span minor_to_major, const LiteralSlice& literal) { int64 new_num_elements = 1; for (int64 i = 0; i < new_dimensions.size(); ++i) { new_num_elements *= new_dimensions[i]; @@ -430,7 +428,7 @@ std::unique_ptr ConvertType(LiteralSlice literal) { } /* static */ std::unique_ptr LiteralUtil::MakeTuple( - tensorflow::gtl::ArraySlice elements) { + absl::Span elements) { std::vector element_shapes; for (const auto* element : elements) { element_shapes.push_back(element->shape()); @@ -444,7 +442,7 @@ std::unique_ptr ConvertType(LiteralSlice literal) { } /* static */ std::unique_ptr LiteralUtil::MakeTupleFromSlices( - tensorflow::gtl::ArraySlice elements) { + absl::Span elements) { std::vector element_shapes; for (const auto& element : elements) { element_shapes.push_back(element.shape()); @@ -474,7 +472,7 @@ std::unique_ptr ConvertType(LiteralSlice literal) { } /* static */ string LiteralUtil::MultiIndexAsString( - tensorflow::gtl::ArraySlice multi_index) { + absl::Span multi_index) { return StrCat("{", absl::StrJoin(multi_index, ","), "}"); } diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 3d28c070f29..fa336ae0de9 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -71,8 +71,7 @@ class LiteralUtil { template static std::unique_ptr CreateR0(NativeT value); template - static std::unique_ptr CreateR1( - tensorflow::gtl::ArraySlice values); + static std::unique_ptr CreateR1(absl::Span values); static std::unique_ptr CreateR1( const tensorflow::core::Bitmap& values); template @@ -141,8 +140,8 @@ class LiteralUtil { // template static std::unique_ptr CreateSparse( - tensorflow::gtl::ArraySlice dimensions, SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, bool sort = true); + absl::Span dimensions, SparseIndexArray indices, + absl::Span values, bool sort = true); // Creates a scalar literal value zero of the given primitive type. static Literal Zero(PrimitiveType primitive_type); @@ -157,7 +156,7 @@ class LiteralUtil { // Creates a literal of the given shape where each element is `value`. 
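One reason this rename is mechanical: call sites that already pass vectors or braced lists keep compiling, since absl::Span<const T> accepts the same implicit conversions. A standalone sketch with a hypothetical helper (the function name is illustrative, not from this patch):

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>
#include "absl/types/span.h"

// Hypothetical helper with the post-rename parameter type.
int64_t Product(absl::Span<const int64_t> dims) {
  return std::accumulate(dims.begin(), dims.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

void CallSites() {
  std::vector<int64_t> dims = {2, 3, 4};
  (void)Product(dims);       // std::vector binds implicitly
  (void)Product({2, 3, 4});  // so does a braced initializer list
}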
template static std::unique_ptr CreateFullWithDescendingLayout( - tensorflow::gtl::ArraySlice dimensions, NativeT value); + absl::Span dimensions, NativeT value); // Creates a new literal from an Array type. The variants not ending with // WithLayout use the default XLA layout for the literal's linear @@ -215,10 +214,10 @@ class LiteralUtil { // Returns a tuple literal composed of given literals. Data is copied from the // given elements into the returned literal. static std::unique_ptr MakeTuple( - tensorflow::gtl::ArraySlice elements); + absl::Span elements); static std::unique_ptr MakeTupleFromSlices( - tensorflow::gtl::ArraySlice elements); + absl::Span elements); // As above, but intended to be invoked with move semantics; i.e. // @@ -259,8 +258,7 @@ class LiteralUtil { // The content of the literal values is the default value of the primitive // type of literal itself (0 for numeric types, and false for predicates). static std::unique_ptr CreateFromDimensions( - PrimitiveType primitive_type, - tensorflow::gtl::ArraySlice dimensions); + PrimitiveType primitive_type, absl::Span dimensions); // If the given literal's data type is bfloat16, converts it to a float // literal; otherwise, returns a copy of it. If the literal is a tuple, @@ -279,9 +277,8 @@ class LiteralUtil { // buffer of the input literal is assumed to have the given minor_to_major // layout order. static std::unique_ptr ReshapeSlice( - tensorflow::gtl::ArraySlice new_dimensions, - tensorflow::gtl::ArraySlice minor_to_major, - const LiteralSlice& literal); + absl::Span new_dimensions, + absl::Span minor_to_major, const LiteralSlice& literal); // Creates a literal with the supplied shape, and uses the provided value // generator to populate the literal's values. @@ -291,7 +288,7 @@ class LiteralUtil { typename T = typename primitive_util::PrimitiveTypeToNative::type> static StatusOr> CreateRandomLiteral( const Shape& shape, - const std::function)>& generator); + const std::function)>& generator); // Creates a literal with the supplied shape, and initializes the literal // values using a normal distribution with given mean and stddev standard @@ -319,8 +316,7 @@ class LiteralUtil { // Returns a multi-dimensional index as a string. For example: '{7, 8}' will // be returned for a 2-dimensional index with dimension 0 index equal to 7, // dimension 1 equal to 8. 
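A Span is an ordinary range, so helpers like the multi-index formatter above can hand it straight to absl::StrJoin. A standalone sketch of that pattern (the function name is a hypothetical stand-in, not the XLA implementation):

#include <cstdint>
#include <string>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"

std::string IndexString(absl::Span<const int64_t> multi_index) {
  return absl::StrCat("{", absl::StrJoin(multi_index, ","), "}");
}
// IndexString({7, 8}) yields "{7,8}".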
- static string MultiIndexAsString( - tensorflow::gtl::ArraySlice multi_index); + static string MultiIndexAsString(absl::Span multi_index); }; std::ostream& operator<<(std::ostream& out, const Literal& literal); @@ -335,7 +331,7 @@ template template /* static */ std::unique_ptr LiteralUtil::CreateR1( - tensorflow::gtl::ArraySlice values) { + absl::Span values) { auto literal = absl::make_unique( ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {static_cast(values.size())})); @@ -427,8 +423,8 @@ template template /* static */ std::unique_ptr LiteralUtil::CreateSparse( - tensorflow::gtl::ArraySlice dimensions, SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, bool sort) { + absl::Span dimensions, SparseIndexArray indices, + absl::Span values, bool sort) { int64 num_elements = values.size(); int64 rank = dimensions.size(); CHECK_EQ(num_elements, indices.index_count()); @@ -570,8 +566,8 @@ template template /* static */ std::unique_ptr -LiteralUtil::CreateFullWithDescendingLayout( - tensorflow::gtl::ArraySlice dimensions, NativeT value) { +LiteralUtil::CreateFullWithDescendingLayout(absl::Span dimensions, + NativeT value) { auto literal = absl::make_unique(ShapeUtil::MakeShapeWithDescendingLayout( primitive_util::NativeToPrimitiveType(), dimensions)); @@ -583,14 +579,12 @@ template /* static */ StatusOr> LiteralUtil::CreateRandomLiteral( const Shape& shape, - const std::function)>& generator) { + const std::function)>& generator) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; TF_RET_CHECK(shape.element_type() == type); auto literal = absl::make_unique(shape); TF_RETURN_IF_ERROR(literal.get()->Populate( - [&](tensorflow::gtl::ArraySlice indexes) { - return generator(indexes); - })); + [&](absl::Span indexes) { return generator(indexes); })); return std::move(literal); } @@ -601,9 +595,8 @@ LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean, using NativeT = typename primitive_util::PrimitiveTypeToNative::type; std::normal_distribution generator(mean, stddev); return CreateRandomLiteral( - shape, [&](tensorflow::gtl::ArraySlice /*indexes*/) { - return generator(*engine); - }); + shape, + [&](absl::Span /*indexes*/) { return generator(*engine); }); } template diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index 6e42775f6fb..83429b8fd3a 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -61,7 +61,7 @@ StatusOr> PackedLiteralReader::Read( result->PopulateWithValue(std::numeric_limits::quiet_NaN()); int64 elements = ShapeUtil::ElementsIn(shape); - tensorflow::gtl::ArraySlice field = result->data(); + absl::Span field = result->data(); char* data = tensorflow::bit_cast(field.data()); uint64 bytes = elements * sizeof(float); tensorflow::StringPiece sp; // non-absl OK diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index b5fd747cfab..cd6e20b6936 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -259,7 +259,7 @@ StatusOr> CompiledLocalComputation::Execute( } LocalShapedBuffer* CompiledLocalComputation::ExecuteWithShapedBuffers( - tensorflow::gtl::ArraySlice argument_handles) { + absl::Span argument_handles) { LocalClient* client = GetOrCreateLocalClient(); std::vector argument_buffers; @@ -369,8 +369,7 @@ LocalOp 
LocalComputationBuilder::ConstantLiteral(const Literal& literal) { } LocalOp LocalComputationBuilder::Broadcast( - const LocalOp& operand, - tensorflow::gtl::ArraySlice broadcast_sizes) { + const LocalOp& operand, absl::Span broadcast_sizes) { return xla::Broadcast(operand.op(), broadcast_sizes); } @@ -380,14 +379,14 @@ LocalOp LocalComputationBuilder::Pad(const LocalOp& operand, return xla::Pad(operand.op(), padding_value.op(), padding_config); } -LocalOp LocalComputationBuilder::Reshape( - const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes) { +LocalOp LocalComputationBuilder::Reshape(const LocalOp& operand, + absl::Span dimensions, + absl::Span new_sizes) { return xla::Reshape(operand.op(), dimensions, new_sizes); } -LocalOp LocalComputationBuilder::Collapse( - const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { +LocalOp LocalComputationBuilder::Collapse(const LocalOp& operand, + absl::Span dimensions) { return xla::Collapse(operand.op(), dimensions); } @@ -395,10 +394,10 @@ LocalOp LocalComputationBuilder::CrossReplicaSum(const LocalOp& operand) { return xla::CrossReplicaSum(operand.op()); } -LocalOp LocalComputationBuilder::Slice( - const LocalOp& operand, tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides) { +LocalOp LocalComputationBuilder::Slice(const LocalOp& operand, + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) { return xla::Slice(operand.op(), start_indices, limit_indices, strides); } @@ -411,7 +410,7 @@ LocalOp LocalComputationBuilder::SliceInDim(const LocalOp& operand, LocalOp LocalComputationBuilder::DynamicSlice( const LocalOp& operand, const LocalOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes) { + absl::Span slice_sizes) { return xla::DynamicSlice(operand.op(), start_indices.op(), slice_sizes); } @@ -421,8 +420,8 @@ LocalOp LocalComputationBuilder::DynamicUpdateSlice( return xla::DynamicUpdateSlice(operand.op(), update.op(), start_indices.op()); } -LocalOp LocalComputationBuilder::ConcatInDim( - tensorflow::gtl::ArraySlice operands, int64 dimension) { +LocalOp LocalComputationBuilder::ConcatInDim(absl::Span operands, + int64 dimension) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { @@ -433,18 +432,16 @@ LocalOp LocalComputationBuilder::ConcatInDim( LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding( const LocalOp& operand, const LocalComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const LocalOp& source, const LocalOp& init_value, - const LocalComputation& scatter) { + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, const LocalOp& source, + const LocalOp& init_value, const LocalComputation& scatter) { return xla::SelectAndScatterWithGeneralPadding( operand.op(), select.computation(), window_dimensions, window_strides, padding, source.op(), init_value.op(), scatter.computation()); } -LocalOp LocalComputationBuilder::Tuple( - tensorflow::gtl::ArraySlice elements) { +LocalOp LocalComputationBuilder::Tuple(absl::Span elements) { std::vector xla_ops; xla_ops.reserve(elements.size()); for (const auto& op : elements) { @@ -471,10 +468,9 @@ LocalOp LocalComputationBuilder::DotGeneral( LocalOp LocalComputationBuilder::ConvGeneralDilated( const LocalOp& lhs, const LocalOp& rhs, - 
tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers) { return xla::ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers); @@ -490,9 +486,8 @@ LocalOp LocalComputationBuilder::BitcastConvertType( return xla::BitcastConvertType(operand.op(), new_element_type); } -LocalOp LocalComputationBuilder::Call( - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice operands) { +LocalOp LocalComputationBuilder::Call(const LocalComputation& local_computation, + absl::Span operands) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { @@ -502,19 +497,18 @@ LocalOp LocalComputationBuilder::Call( } LocalOp LocalComputationBuilder::Transpose( - const LocalOp& operand, tensorflow::gtl::ArraySlice permutation) { + const LocalOp& operand, absl::Span permutation) { return xla::Transpose(operand.op(), permutation); } -LocalOp LocalComputationBuilder::Rev( - const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { +LocalOp LocalComputationBuilder::Rev(const LocalOp& operand, + absl::Span dimensions) { return xla::Rev(operand.op(), dimensions); } -LocalOp LocalComputationBuilder::Map( - tensorflow::gtl::ArraySlice operands, - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions) { +LocalOp LocalComputationBuilder::Map(absl::Span operands, + const LocalComputation& local_computation, + absl::Span dimensions) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { @@ -528,7 +522,7 @@ LocalOp LocalComputationBuilder::Map( LocalOp LocalComputationBuilder::Reduce( const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce) { + absl::Span dimensions_to_reduce) { return xla::Reduce(operand.op(), init_value.op(), local_computation.computation(), dimensions_to_reduce); } @@ -536,9 +530,9 @@ LocalOp LocalComputationBuilder::Reduce( LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding) { + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding) { return xla::ReduceWindowWithGeneralPadding( operand.op(), init_value.op(), local_computation.computation(), window_dimensions, window_strides, padding); @@ -599,10 +593,10 @@ StatusOr LocalComputationBuilder::BuildConstantSubGraph( #define _FORWARD_UNOP(method_name) \ _FORWARD(method_name, LocalOp, (const LocalOp& operand), (operand.op())) -#define _FORWARD_BINOP(method_name) \ - _FORWARD(method_name, LocalOp, \ - (const LocalOp& lhs, const LocalOp& rhs, \ - tensorflow::gtl::ArraySlice broadcast_dimensions), \ +#define _FORWARD_BINOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, \ + absl::Span broadcast_dimensions), \ (lhs.op(), rhs.op(), broadcast_dimensions)) #define _FORWARD_TRIOP(method_name) \ diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h 
index d9543b958dc..df135935c37 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -122,7 +122,7 @@ class CompiledLocalComputation { const std::vector >& shapes_with_layout); LocalShapedBuffer* ExecuteWithShapedBuffers( - tensorflow::gtl::ArraySlice argument_handles); + absl::Span argument_handles); private: std::unique_ptr executable_; @@ -199,46 +199,41 @@ class LocalComputationBuilder { LocalOp ConstantLiteral(const Literal& literal); LocalOp Broadcast(const LocalOp& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); + absl::Span broadcast_sizes); LocalOp Pad(const LocalOp& operand, const LocalOp& padding_value, const PaddingConfig& padding_config); - LocalOp Reshape(const LocalOp& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); + LocalOp Reshape(const LocalOp& operand, absl::Span dimensions, + absl::Span new_sizes); - LocalOp Collapse(const LocalOp& operand, - tensorflow::gtl::ArraySlice dimensions); + LocalOp Collapse(const LocalOp& operand, absl::Span dimensions); LocalOp CrossReplicaSum(const LocalOp& operand); - LocalOp Slice(const LocalOp& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); + LocalOp Slice(const LocalOp& operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); LocalOp SliceInDim(const LocalOp& operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno); LocalOp DynamicSlice(const LocalOp& operand, const LocalOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); LocalOp DynamicUpdateSlice(const LocalOp& operand, const LocalOp& update, const LocalOp& start_indices); - LocalOp ConcatInDim(tensorflow::gtl::ArraySlice operands, - int64 dimension); + LocalOp ConcatInDim(absl::Span operands, int64 dimension); LocalOp SelectAndScatterWithGeneralPadding( const LocalOp& operand, const LocalComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice > padding, - const LocalOp& source, const LocalOp& init_value, - const LocalComputation& scatter); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span > padding, const LocalOp& source, + const LocalOp& init_value, const LocalComputation& scatter); - LocalOp Tuple(tensorflow::gtl::ArraySlice elements); + LocalOp Tuple(absl::Span elements); LocalOp GetTupleElement(const LocalOp& tuple_data, int64 index); @@ -249,10 +244,10 @@ class LocalComputationBuilder { LocalOp ConvGeneralDilated( const LocalOp& lhs, const LocalOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice > padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, + absl::Span window_strides, + absl::Span > padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers); LocalOp ConvertElementType(const LocalOp& operand, @@ -262,28 +257,27 @@ class LocalComputationBuilder { PrimitiveType new_element_type); LocalOp Call(const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice operands); + absl::Span operands); LocalOp Transpose(const LocalOp& operand, - tensorflow::gtl::ArraySlice permutation); + absl::Span permutation); - LocalOp Rev(const LocalOp& operand, - tensorflow::gtl::ArraySlice dimensions); + LocalOp Rev(const LocalOp& 
operand, absl::Span dimensions); - LocalOp Map(tensorflow::gtl::ArraySlice operands, + LocalOp Map(absl::Span operands, const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); + absl::Span dimensions_to_reduce); LocalOp ReduceWindowWithGeneralPadding( const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice > padding); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span > padding); LocalOp RngNormal(const LocalOp& mu, const LocalOp& sigma, const Shape& shape); @@ -316,7 +310,7 @@ class LocalComputationBuilder { #define _FORWARD_BINOP(method_name) \ _FORWARD(method_name, LocalOp, \ (const LocalOp& lhs, const LocalOp& rhs, \ - tensorflow::gtl::ArraySlice broadcast_dimensions)) + absl::Span broadcast_dimensions)) #define _FORWARD_TRIOP(method_name) \ _FORWARD(method_name, LocalOp, \ diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index f6169ebf190..e6034296d73 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -22,15 +22,15 @@ limitations under the License. // // C++ Python // -------------------------------------+--------------------------------------- -// ArraySlice <- sequence of int -// ArraySlice <- sequence of LocalOp +// Span <- sequence of int +// Span <- sequence of LocalOp // Literal <-> (nested tuple of) numpy ndarray // std::vector <- sequence of (nested tuple of) ndarray // Shape -> pair holding (dtype, dimensions) // <- object duck-typed as xla_client.Shape // std::vector <- sequence of xla_client.Shape objects // PrimitiveType <- int -// ArraySlice> <- sequence of int pairs +// Span> <- sequence of int pairs // PaddingConfig proto <- corresponding Python proto // ConvolutionDimensionNumbers proto <- corresponding Python proto // DotDimensionNumbers proto <- corresponding Python proto @@ -267,9 +267,9 @@ tensorflow::ImportNumpy(); $result = Py_None; } -// ArraySlice +// Span -%typemap(in) tensorflow::gtl::ArraySlice +%typemap(in) absl::Span (std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); @@ -299,9 +299,9 @@ tensorflow::ImportNumpy(); $1 = temps; } -// ArraySlice +// Span -%typemap(in) tensorflow::gtl::ArraySlice( +%typemap(in) absl::Span( std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); @@ -323,7 +323,7 @@ tensorflow::ImportNumpy(); // LocalShapedBuffer* -%typemap(in) tensorflow::gtl::ArraySlice +%typemap(in) absl::Span (std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); @@ -496,9 +496,9 @@ tensorflow::ImportNumpy(); $1 = static_cast(value); } -// ArraySlice> +// Span> -%typemap(in) tensorflow::gtl::ArraySlice > +%typemap(in) absl::Span > (std::vector > temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 3de7ee2bc8c..a4854f593f0 100644 --- 
a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -108,17 +108,15 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated( // array by adding a fourth dummy dimension of size 1 without stride, padding // and dilation. Array4D a4dlhs(lhs.n1(), lhs.n2(), lhs.n3(), 1); - a4dlhs.Each( - [&](tensorflow::gtl::ArraySlice indices, float* value_ptr) { - CHECK_EQ(indices[3], 0); - *value_ptr = lhs.operator()(indices[0], indices[1], indices[2]); - }); + a4dlhs.Each([&](absl::Span indices, float* value_ptr) { + CHECK_EQ(indices[3], 0); + *value_ptr = lhs.operator()(indices[0], indices[1], indices[2]); + }); Array4D a4drhs(rhs.n1(), rhs.n2(), rhs.n3(), 1); - a4drhs.Each( - [&](tensorflow::gtl::ArraySlice indices, float* value_ptr) { - CHECK_EQ(indices[3], 0); - *value_ptr = rhs.operator()(indices[0], indices[1], indices[2]); - }); + a4drhs.Each([&](absl::Span indices, float* value_ptr) { + CHECK_EQ(indices[3], 0); + *value_ptr = rhs.operator()(indices[0], indices[1], indices[2]); + }); // Add a second dummy spatial dimensions. ConvolutionDimensionNumbers dnums2d = dnums; dnums2d.add_input_spatial_dimensions(3); @@ -130,11 +128,10 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated( auto convr3 = absl::make_unique>( convr4->planes(), convr4->depth(), convr4->height()); - convr4->Each( - [&](tensorflow::gtl::ArraySlice indices, float* value_ptr) { - CHECK_EQ(indices[3], 0); - convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr; - }); + convr4->Each([&](absl::Span indices, float* value_ptr) { + CHECK_EQ(indices[3], 0); + convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr; + }); return convr3; } @@ -189,11 +186,11 @@ ReferenceUtil::SeparableConvArray4D(const Array4D& input, /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow1DGeneric( - const tensorflow::gtl::ArraySlice& operand, float init, + const absl::Span& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, - const tensorflow::gtl::ArraySlice>& padding) { + const absl::Span& window, + const absl::Span& stride, + const absl::Span>& padding) { std::vector dim_lengths{static_cast(operand.size())}; std::vector window_counts(window.size(), 0); std::vector pad_low(window.size(), 0); @@ -221,10 +218,11 @@ ReferenceUtil::ReduceWindow1DGeneric( } /* static */ std::unique_ptr> -ReferenceUtil::ReduceWindow1DAdd( - const tensorflow::gtl::ArraySlice& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding) { +ReferenceUtil::ReduceWindow1DAdd(const absl::Span& operand, + float init, + const absl::Span& window, + const absl::Span& stride, + Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; std::vector dim_lengths{static_cast(operand.size())}; return ReduceWindow1DGeneric( @@ -236,9 +234,9 @@ ReferenceUtil::ReduceWindow1DAdd( ReferenceUtil::ReduceWindow2DGeneric( const Array2D& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, - const tensorflow::gtl::ArraySlice>& padding) { + const absl::Span& window, + const absl::Span& stride, + const absl::Span>& padding) { std::vector dim_lengths{operand.height(), operand.width()}; std::vector window_counts(window.size(), 0); @@ -276,8 +274,8 @@ ReferenceUtil::ReduceWindow2DGeneric( /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow2DAdd( 
const Array2D& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding) { + const absl::Span& window, + const absl::Span& stride, Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; std::vector dim_lengths{operand.height(), operand.width()}; return ReduceWindow2DGeneric( @@ -287,8 +285,8 @@ ReferenceUtil::ReduceWindow2DGeneric( /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow3DAdd( const Array3D& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding) { + const absl::Span& window, + const absl::Span& stride, Padding padding) { std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3()}; auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); @@ -334,8 +332,8 @@ ReferenceUtil::ReduceWindow2DGeneric( ReferenceUtil::ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding) { + const absl::Span& window, + const absl::Span& stride, Padding padding) { std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3(), operand.n4()}; return ReduceWindow4DGeneric( @@ -347,9 +345,9 @@ ReferenceUtil::ReduceWindow4DGeneric( ReferenceUtil::ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, - const tensorflow::gtl::ArraySlice>& padding) { + const absl::Span& window, + const absl::Span& stride, + const absl::Span>& padding) { std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3(), operand.n4()}; @@ -402,8 +400,8 @@ ReferenceUtil::ReduceWindow4DGeneric( /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow4DAdd( const Array4D& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding) { + const absl::Span& window, + const absl::Span& stride, Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride, padding); @@ -424,10 +422,12 @@ ReferenceUtil::ReduceWindow4DGeneric( } /* static */ std::unique_ptr> -ReferenceUtil::SelectAndScatter4DGePlus( - const Array4D& operand, const Array4D& source, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, bool same_padding) { +ReferenceUtil::SelectAndScatter4DGePlus(const Array4D& operand, + const Array4D& source, + float init, + const absl::Span& window, + const absl::Span& stride, + bool same_padding) { Padding padding = same_padding ? 
Padding::kSame : Padding::kValid; auto result = absl::make_unique>(operand.n1(), operand.n2(), operand.n3(), operand.n4()); @@ -591,7 +591,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( result_literal->shape().dimensions(2), result_literal->shape().dimensions(3)); - result->Each([&](tensorflow::gtl::ArraySlice indices, float* value) { + result->Each([&](absl::Span indices, float* value) { *value = result_literal->Get(indices); }); @@ -633,8 +633,7 @@ ReferenceUtil::ReduceToRowArray2D( } /*static*/ std::vector ReferenceUtil::Reduce4DTo1D( - const Array4D& array, float init, - tensorflow::gtl::ArraySlice dims, + const Array4D& array, float init, absl::Span dims, const std::function& reduce_function) { std::vector result; CHECK_EQ(dims.size(), 3); @@ -707,8 +706,7 @@ ReferenceUtil::ReduceToRowArray2D( } /* static */ std::unique_ptr> ReferenceUtil::Reduce3DTo2D( - const Array3D& array, float init, - tensorflow::gtl::ArraySlice dims, + const Array3D& array, float init, absl::Span dims, const std::function& reduce_function) { CHECK_EQ(dims.size(), 1); int64 rows = dims[0] == 0 ? array.n2() : array.n1(); diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 88f853a3591..b1d530c59eb 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -144,8 +144,7 @@ class ReferenceUtil { // Returns the result of reducing the 4D array to a vector, reducing away // the dimensions specified in dims. static std::vector Reduce4DTo1D( - const Array4D& array, float init, - tensorflow::gtl::ArraySlice dims, + const Array4D& array, float init, absl::Span dims, const std::function& reduce_function); // Broadcast 1D dimension to 4D, from the dimension `broadcast_from_dim`. @@ -156,8 +155,7 @@ class ReferenceUtil { // Returns the result of reducing the 3D array to a 2D array, reducing away // the dimensions specified in dims. static std::unique_ptr> Reduce3DTo2D( - const Array3D& array, float init, - tensorflow::gtl::ArraySlice dims, + const Array3D& array, float init, absl::Span dims, const std::function& reduce_function); // Applies map_function to each element in the input (2D array) and returns @@ -179,47 +177,47 @@ class ReferenceUtil { // Windowed reductions with Add as the function to apply. static std::unique_ptr> ReduceWindow1DAdd( - const tensorflow::gtl::ArraySlice& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding); + const absl::Span& operand, float init, + const absl::Span& window, + const absl::Span& stride, Padding padding); static std::unique_ptr> ReduceWindow2DAdd( const Array2D& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding); + const absl::Span& window, + const absl::Span& stride, Padding padding); static std::unique_ptr> ReduceWindow3DAdd( const Array3D& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding); + const absl::Span& window, + const absl::Span& stride, Padding padding); static std::unique_ptr> ReduceWindow4DAdd( const Array4D& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding); + const absl::Span& window, + const absl::Span& stride, Padding padding); // Windowed reductions with a generic reduce function. 
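For orientation, the 1D windowed reduction these declarations describe can be sketched standalone as follows; this handles valid (no-padding) windows only and takes scalar window/stride parameters rather than the Span parameters ReferenceUtil uses, so it is illustrative rather than the real implementation:

#include <cstdint>
#include <functional>
#include <vector>
#include "absl/types/span.h"

std::vector<float> ReduceWindow1DValid(
    absl::Span<const float> operand, float init,
    const std::function<float(float, float)>& reduce_func, int64_t window,
    int64_t stride) {
  std::vector<float> result;
  for (int64_t start = 0;
       start + window <= static_cast<int64_t>(operand.size());
       start += stride) {
    float accum = init;
    for (int64_t i = 0; i < window; ++i) {
      accum = reduce_func(accum, operand[start + i]);
    }
    result.push_back(accum);
  }
  return result;
}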
static std::unique_ptr> ReduceWindow1DGeneric( - const tensorflow::gtl::ArraySlice& operand, float init, + const absl::Span& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, - const tensorflow::gtl::ArraySlice>& padding); + const absl::Span& window, + const absl::Span& stride, + const absl::Span>& padding); static std::unique_ptr> ReduceWindow2DGeneric( const Array2D& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, - const tensorflow::gtl::ArraySlice>& padding); + const absl::Span& window, + const absl::Span& stride, + const absl::Span>& padding); static std::unique_ptr> ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding); + const absl::Span& window, + const absl::Span& stride, Padding padding); // With arbitrary padding. static std::unique_ptr> ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, - const tensorflow::gtl::ArraySlice>& padding); + const absl::Span& window, + const absl::Span& stride, + const absl::Span>& padding); // Batch normalize data. static std::unique_ptr> BatchNorm4D( @@ -232,8 +230,8 @@ class ReferenceUtil { // TODO(b/74533103) Switch tests to evaluator and remove this implementation. static std::unique_ptr> SelectAndScatter4DGePlus( const Array4D& operand, const Array4D& source, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, bool same_padding); + const absl::Span& window, + const absl::Span& stride, bool same_padding); // Concatenates the lhs and rhs arrays along the concatenate_dimension. // E.g. if concatenate_dimension is 0, the "n1"/height dimension is @@ -334,8 +332,8 @@ class ReferenceUtil { // Slices with index clamping template - static std::vector ClampSlice1D( - const tensorflow::gtl::ArraySlice& input, int64 start, int64 size) { + static std::vector ClampSlice1D(const absl::Span& input, + int64 start, int64 size) { start = std::min(std::max(0, start), input.size() - size); std::vector result; for (int64 i = 0; i < size; ++i) { @@ -633,7 +631,7 @@ class ReferenceUtil { Array4D result(output_bounds[0], output_bounds[1], output_bounds[2], output_bounds[3]); result.Each( - [&](tensorflow::gtl::ArraySlice indices, NativeT* value) { + [&](absl::Span indices, NativeT* value) { for (int i = 0; i < 4; ++i) { bool in_low_padding = indices[i] < pad_low[i]; bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i]; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 196865f3332..a7a0044308d 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -449,8 +449,7 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { Status AlgebraicSimplifierVisitor::HandleConcatenate( HloInstruction* concatenate) { - tensorflow::gtl::ArraySlice operands( - concatenate->operands()); + absl::Span operands(concatenate->operands()); if (operands.size() == 1) { // Unary concatenates are useless. 
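The HandleConcatenate hunk above views an instruction's operand list through a Span of pointers. A standalone sketch of that construction, with a stand-in type instead of HloInstruction: the const applies to the pointers, giving a read-only view without copying the vector.

#include <vector>
#include "absl/types/span.h"

struct Instr {};  // stand-in for HloInstruction

void PointerSpanExample(const std::vector<Instr*>& operand_vector) {
  absl::Span<Instr* const> operands(operand_vector);
  if (operands.size() == 1) {
    Instr* only_operand = operands[0];  // same pointer as operand_vector[0]
    (void)only_operand;
  }
}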
ReplaceInstructionIfSameShape(concatenate, operands[0]); @@ -588,7 +587,7 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { namespace { template Status InvertConstant(const HloInstruction& constant, Literal* result) { - return result->Populate([&](tensorflow::gtl::ArraySlice indices) { + return result->Populate([&](absl::Span indices) { return T{1.0} / constant.literal().Get(indices); }); } @@ -1249,8 +1248,7 @@ namespace { // // Precondition: input_dim_indices is sorted. absl::optional> ReshapeLeavesDimensionsUnmodified( - const HloInstruction* hlo, - tensorflow::gtl::ArraySlice input_dim_indices) { + const HloInstruction* hlo, absl::Span input_dim_indices) { CHECK_EQ(HloOpcode::kReshape, hlo->opcode()); CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end())); @@ -1853,7 +1851,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { auto arg = reduce->mutable_operand(0); auto init_value = reduce->mutable_operand(1); - tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); + absl::Span dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); if (ShapeUtil::IsZeroElementArray(arg->shape()) || ShapeUtil::IsZeroElementArray(reduce->shape())) { diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index cbce98ef131..182c581ad8d 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -2226,7 +2226,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { auto out_dims = in_dims; out_dims[in_channel_idx] = options.f_output_channels; - auto make_shape = [](tensorflow::gtl::ArraySlice dims, + auto make_shape = [](absl::Span dims, bool minor_to_major_layout) { if (minor_to_major_layout) { return ShapeUtil::MakeShapeWithLayout(F32, dims, {0, 1, 2, 3}); @@ -2838,8 +2838,8 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { // a and b are parallel bounds we can either turn into a B F S0 S1 or // `B S0 S1 F` kind of pattern. 
- auto decorate_spatials = [¶m](tensorflow::gtl::ArraySlice spatials, - int64 a, int64 b) { + auto decorate_spatials = [¶m](absl::Span spatials, int64 a, + int64 b) { std::vector result; if (param.prepend_a) { result.push_back(a); diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index a6889cb171b..5c180cbdd49 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -112,11 +112,11 @@ StatusOr Backend::BorrowStream(se::StreamExecutor* executor) { return stream_pools_.at(executor).BorrowStream(executor); } -Backend::Backend( - se::Platform* platform, Compiler* compiler, - tensorflow::gtl::ArraySlice stream_executors, - TransferManager* transfer_manager, ComputationPlacer* computation_placer, - int intra_op_parallelism_threads) +Backend::Backend(se::Platform* platform, Compiler* compiler, + absl::Span stream_executors, + TransferManager* transfer_manager, + ComputationPlacer* computation_placer, + int intra_op_parallelism_threads) : platform_(platform), compiler_(compiler), transfer_manager_(transfer_manager), diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index 4a6a78daf07..fdf8d9cab28 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -149,7 +149,7 @@ class Backend { private: struct EigenThreadPoolWrapper; Backend(se::Platform* platform, Compiler* compiler, - tensorflow::gtl::ArraySlice stream_executors, + absl::Span stream_executors, TransferManager* transfer_manager, ComputationPlacer* computation_placer, int intra_op_parallelism_threads); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index a6f77db3b02..b5cf245af64 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -69,8 +69,7 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { // Inserts conversion HLOs to replace the called computations' BF16 // operands/outputs to F32. Status ConvertCalledComputations( - HloInstruction* hlo, - tensorflow::gtl::ArraySlice bf16_called_comps); + HloInstruction* hlo, absl::Span bf16_called_comps); HloComputation* computation_; const BFloat16Support* bfloat16_support_; @@ -114,8 +113,7 @@ Status BFloat16NormalizationVisitor::InsertConvertBeforeOperand( } Status BFloat16NormalizationVisitor::ConvertCalledComputations( - HloInstruction* hlo, - tensorflow::gtl::ArraySlice bf16_called_comps) { + HloInstruction* hlo, absl::Span bf16_called_comps) { std::map cloned_computations; for (auto& comp : bf16_called_comps) { auto cloned = comp->parent()->AddEmbeddedComputation(comp->Clone()); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index 2fb401c4289..545a6ecfb1f 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -407,7 +407,7 @@ void BFloat16Propagation::AdjustCalledComputationParameters( HloInstruction* hlo) { auto adjust_computation = [this, hlo](HloComputation* computation, - tensorflow::gtl::ArraySlice operands) { + absl::Span operands) { // Adjust parameters. 
CHECK_EQ(operands.size(), computation->num_parameters()); for (int64 i = 0; i < operands.size(); ++i) { diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index e9751cc2693..8bd15339724 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -118,7 +118,7 @@ class BufferAssignmentTest : public HloVerifiedTestBase { std::unique_ptr RunBufferAssignmentWithInstructionSequence( HloModule* module, - tensorflow::gtl::ArraySlice instruction_sequence, + absl::Span instruction_sequence, int64 alignment = 1) { SequentialHloOrdering::HloModuleSequence module_sequence; module_sequence[module->entry_computation()] = diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 3079695e967..e5a6c28478a 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -62,7 +62,7 @@ CompileOnlyService::CompileOnlyService(const ServiceOptions& options, StatusOr>> CompileOnlyService::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, + const absl::Span computations, const AotCompilationOptions& options, std::unique_ptr* metadata) { std::vector> hlo_modules; diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h index 1ac950bdd66..61136a3e11f 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.h +++ b/tensorflow/compiler/xla/service/compile_only_service.h @@ -50,12 +50,12 @@ class CompileOnlyService : public Service { // |CompileOnlyClient::CompileAheadOfTime| for additional details. StatusOr>> CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, + const absl::Span computations, const AotCompilationOptions& options); StatusOr>> CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, + const absl::Span computations, const AotCompilationOptions& options, std::unique_ptr* metadata); diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 1b7a7b36eac..b65dfef9c95 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -479,7 +479,7 @@ class CopyRemover { // 'values' an entry is created in value_to_node which indicates the // respective ValueNode representing that value. 
void AddValueList( - tensorflow::gtl::ArraySlice values, + absl::Span values, tensorflow::gtl::FlatMap* value_to_node) { ValueNode* tail = nullptr; ValueNode* head = nullptr; diff --git a/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc b/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc index 408fe0f5bf5..1942ea1a2af 100644 --- a/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc +++ b/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc @@ -40,7 +40,7 @@ std::vector CreateBufferInfosFromBufferAssignment( } std::vector CreateArgIndexTableFromBufferInfos( - tensorflow::gtl::ArraySlice buffer_infos) { + absl::Span buffer_infos) { std::vector result; for (int64 i = 0; i < buffer_infos.size(); i++) { if (buffer_infos[i].is_entry_parameter()) { diff --git a/tensorflow/compiler/xla/service/cpu/buffer_info_util.h b/tensorflow/compiler/xla/service/cpu/buffer_info_util.h index 05de70c7268..0c5a60f13f6 100644 --- a/tensorflow/compiler/xla/service/cpu/buffer_info_util.h +++ b/tensorflow/compiler/xla/service/cpu/buffer_info_util.h @@ -34,7 +34,7 @@ CreateBufferInfosFromBufferAssignment( // If this function returns V then entry parameter i has buffer allocation index // V[i]. std::vector CreateArgIndexTableFromBufferInfos( - tensorflow::gtl::ArraySlice<::tensorflow::cpu_function_runtime::BufferInfo> + absl::Span buffer_infos); } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index bf2efc4d14d..9b00f2eaa57 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -77,7 +77,7 @@ StatusOr, std::vector>> CpuExecutable::CreateTempArray( DeviceMemoryAllocator* memory_allocator, int device_ordinal, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { std::vector unowning_buffers( assignment_->Allocations().size()); std::vector owning_buffers( @@ -136,7 +136,7 @@ CpuExecutable::CreateTempArray( Status CpuExecutable::ExecuteComputeFunction( const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice buffers, + absl::Span buffers, HloExecutionProfile* hlo_execution_profile) { // The calling convention for JITed functions is: // @@ -207,7 +207,7 @@ Status CpuExecutable::ExecuteComputeFunction( StatusOr CpuExecutable::CreateResultShapedBuffer( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::MutableArraySlice buffers) { + absl::Span buffers) { se::Stream* stream = run_options->stream(); ScopedShapedBuffer result_buffer( /*on_host_shape=*/result_shape(), @@ -245,7 +245,7 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( StatusOr CpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) { TF_ASSIGN_OR_RETURN( auto result, @@ -256,7 +256,7 @@ StatusOr CpuExecutable::ExecuteOnStream( StatusOr CpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { if (hlo_profiling_enabled()) { return Unimplemented( "Asynchronous execution on stream with hlo profiling is not yet " @@ -267,7 +267,7 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( StatusOr CpuExecutable::ExecuteAsyncOnStreamImpl( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) { if 
(GetRootPointsToSet().IsAmbiguous()) { return Unimplemented("Points-to set of root instruction is ambiguous"); @@ -299,7 +299,7 @@ StatusOr CpuExecutable::ExecuteAsyncOnStreamImpl( // // We also need to change the types of some of the variables we capture: // run_options needs to change from a pointer to a value type, and arguments - // needs to change from an ArraySlice into a vector. We use a struct instead + // needs to change from a Span into a vector. We use a struct instead // of a lambda to make this explicit. struct AsyncRunTask { CpuExecutable* executable; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 96e53de57ee..236de8f14f2 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -57,12 +57,12 @@ class CpuExecutable : public Executable { StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) override; StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) override; + absl::Span arguments) override; // This should be called after set_ir_module_string. const string& ir_module_string() const { return ir_module_string_; } @@ -92,7 +92,7 @@ class CpuExecutable : public Executable { // exists) must out-live the task. StatusOr ExecuteAsyncOnStreamImpl( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile); // Creates an array suitable for passing as the "temps" argument to the JIT @@ -112,21 +112,20 @@ class CpuExecutable : public Executable { StatusOr, std::vector>> CreateTempArray(DeviceMemoryAllocator* memory_allocator, int device_ordinal, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); // Calls the generated function performing the computation with the given // arguments using the supplied buffers. - Status ExecuteComputeFunction( - const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice buffers, - HloExecutionProfile* hlo_execution_profile); + Status ExecuteComputeFunction(const ExecutableRunOptions* run_options, + absl::Span buffers, + HloExecutionProfile* hlo_execution_profile); // Creates a ScopedShapedBuffer for holding the result of the computation, // moving buffers out of allocated_buffers and into the result as appropriate. // The addresses are set according to buffer assignment. StatusOr CreateResultShapedBuffer( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::MutableArraySlice buffers); + absl::Span buffers); // Returns the points-to set of the root instruction of the entry // computation. Uses points-to analysis from buffer assignment. diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index 0df2abf0012..5519a43b2f6 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -179,7 +179,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( int64 size = GetByteSizeRequirement(literal_shape); // Note: OSS build didn't like implicit conversion from // literal_shape.dimensions() to the array slice on 2017-07-10. 
- tensorflow::gtl::ArraySlice dimensions( + absl::Span dimensions( tensorflow::bit_cast(literal_shape.dimensions().data()), literal_shape.dimensions().size()); TF_ASSIGN_OR_RETURN( @@ -225,7 +225,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( StatusOr CpuTransferManager::TransferTupleBuffersFromOutfeed( se::StreamExecutor* executor, - tensorflow::gtl::ArraySlice> buffer_data) { + absl::Span> buffer_data) { return TransferBuffersFromOutfeedInternal(executor, buffer_data, /*is_tuple=*/true); } @@ -238,8 +238,7 @@ StatusOr CpuTransferManager::TransferArrayBufferFromOutfeed( StatusOr CpuTransferManager::TransferBuffersFromOutfeedInternal( se::StreamExecutor* executor, - tensorflow::gtl::ArraySlice> buffer_data, - bool is_tuple) { + absl::Span> buffer_data, bool is_tuple) { std::vector> buffers; for (auto b : buffer_data) { int64 size = b.second; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h index 7b938e9fd7d..6927edff868 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h @@ -56,7 +56,7 @@ class CpuTransferManager : public GenericTransferManager { // Helper that transfers a tuple of element buffers from the device's outfeed. StatusOr TransferTupleBuffersFromOutfeed( se::StreamExecutor* executor, - tensorflow::gtl::ArraySlice> buffer_data); + absl::Span> buffer_data); // Helper that transfers an array buffer from the device's outfeed. StatusOr TransferArrayBufferFromOutfeed(se::StreamExecutor* executor, @@ -68,8 +68,7 @@ class CpuTransferManager : public GenericTransferManager { // for the given buffers. StatusOr TransferBuffersFromOutfeedInternal( se::StreamExecutor* executor, - tensorflow::gtl::ArraySlice> buffer_data, - bool is_tuple); + absl::Span> buffer_data, bool is_tuple); TF_DISALLOW_COPY_AND_ASSIGN(CpuTransferManager); }; diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index dd060f54a29..99fa707c959 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -80,7 +80,7 @@ class MemoryTile { // `minor_dim_offset`}. // // Note: `major_dim_offset` is a parameter to the constructor. 
- void StoreTile(tensorflow::gtl::ArraySlice tile, + void StoreTile(absl::Span tile, llvm::Value* minor_dim_offset) const { CHECK_EQ(tile.size(), pointers_.size()); for (int64 i = 0; i < pointers_.size(); i++) { diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 1c828cc02c0..7839d973177 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -506,8 +506,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { llvm::Value* IrEmitter::EmitElementalMap( const HloMapInstruction& map_instr, - tensorflow::gtl::ArraySlice elemental_operands, - absl::string_view name) { + absl::Span elemental_operands, absl::string_view name) { return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name); } @@ -1455,7 +1454,7 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( const ReductionGenerator& reduction_generator, const llvm_ir::IrArray::Index& output_index, const ShardedVectorType& accumulator_type, HloInstruction* init_value, - HloInstruction* arg, gtl::ArraySlice dimensions, + HloInstruction* arg, absl::Span dimensions, unsigned element_alignment) { ShardedVector accumulator; accumulator.reserve(accumulator_type.size()); @@ -1551,7 +1550,7 @@ void IrEmitter::EmitShardedVectorStore( StatusOr IrEmitter::EmitVectorizedReduce( HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, - gtl::ArraySlice dimensions, HloComputation* function, + absl::Span dimensions, HloComputation* function, string* failure_reason) { if (!ReductionPreservesLayout(*reduce)) { return false; @@ -1701,7 +1700,7 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( HloReduceInstruction* reduce, const llvm_ir::IrArray::Index& index) { const HloInstruction* arg = reduce->mutable_operand(0); const HloInstruction* init_value = reduce->mutable_operand(1); - gtl::ArraySlice dimensions(reduce->dimensions()); + absl::Span dimensions(reduce->dimensions()); // Initialize an accumulator with init_value. 
PrimitiveType accumulator_type = reduce->shape().element_type(); @@ -1758,7 +1757,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { } auto arg = reduce->mutable_operand(0); auto init_value = reduce->mutable_operand(1); - gtl::ArraySlice dimensions(reduce->dimensions()); + absl::Span dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); if (!options::VectorizedReduceDisabled(hlo_module_config_)) { string vectorization_failure_reason; @@ -2113,7 +2112,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) { } Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { - gtl::ArraySlice operands(custom_call->operands()); + absl::Span operands(custom_call->operands()); absl::string_view custom_call_target(custom_call->custom_call_target()); llvm::Type* i8_ptr_type = b_.getInt8PtrTy(); llvm::AllocaInst* operands_alloca = @@ -2233,7 +2232,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { } StatusOr IrEmitter::EmitFastConcatenate( - HloInstruction* concatenate, gtl::ArraySlice operands, + HloInstruction* concatenate, absl::Span operands, string* failure_reason) { if (ShouldEmitParallelLoopFor(*concatenate)) { *failure_reason = @@ -2369,7 +2368,7 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, } Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) { - gtl::ArraySlice operands(concatenate->operands()); + absl::Span operands(concatenate->operands()); string failure_reason; TF_ASSIGN_OR_RETURN( bool successful, @@ -2800,8 +2799,8 @@ Status IrEmitter::EmitMemcpy(const HloInstruction& source, Status IrEmitter::ElementTypesSameAndSupported( const HloInstruction& instruction, - gtl::ArraySlice operands, - gtl::ArraySlice supported_types) { + absl::Span operands, + absl::Span supported_types) { for (auto operand : operands) { TF_RET_CHECK( ShapeUtil::SameElementType(operands[0]->shape(), operand->shape())); @@ -2831,8 +2830,7 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { } llvm::Value* IrEmitter::EmitThreadLocalCall( - const HloComputation& callee, - tensorflow::gtl::ArraySlice parameters, + const HloComputation& callee, absl::Span parameters, absl::string_view name) { const Shape& return_shape = callee.root_instruction()->shape(); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index f98891246b0..015724b65dc 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -111,7 +111,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, // Emit code to map one element according to `map_instr`. llvm::Value* EmitElementalMap( const HloMapInstruction& map_instr, - tensorflow::gtl::ArraySlice elemental_operands, + absl::Span elemental_operands, absl::string_view name); protected: @@ -252,10 +252,9 @@ class IrEmitter : public DfsHloVisitorWithDefault, // // `parameters` holds the *scalar values* that need to be passed to the // callee. The return value is the scalar returned by the callee. - llvm::Value* EmitThreadLocalCall( - const HloComputation& callee, - tensorflow::gtl::ArraySlice parameters, - absl::string_view name); + llvm::Value* EmitThreadLocalCall(const HloComputation& callee, + absl::Span parameters, + absl::string_view name); // Emits a call to a "global" function (e.g. to the computation nested within // a kWhile or a kCall). 
Buffer assignment unabiguously assignes buffers to @@ -271,8 +270,8 @@ class IrEmitter : public DfsHloVisitorWithDefault, // match and are of one of the given supported types. Status ElementTypesSameAndSupported( const HloInstruction& instruction, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice supported_types); + absl::Span operands, + absl::Span supported_types); // Emit IR to perform a computation for every element in the given target op. // This produces a series of nested loops (one for each dimension of the op's @@ -319,10 +318,12 @@ class IrEmitter : public DfsHloVisitorWithDefault, // concepts that generalize over other vectorizable operations. We should // consider pulling out these abstractions into a VectorizingIrEmitter or // something similar. - StatusOr EmitVectorizedReduce( - HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions, HloComputation* function, - string* failure_reason); + StatusOr EmitVectorizedReduce(HloInstruction* reduce, + HloInstruction* arg, + HloInstruction* init_value, + absl::Span dimensions, + HloComputation* function, + string* failure_reason); // We'd like to keep one or two one cache-line's worth of data in registers // without generating IR with illegal (e.g. excessively large or @@ -372,16 +373,15 @@ class IrEmitter : public DfsHloVisitorWithDefault, const ReductionGenerator& reduction_generator, const llvm_ir::IrArray::Index& output_index, const ShardedVectorType& accumulator_type, HloInstruction* init_value, - HloInstruction* arg, tensorflow::gtl::ArraySlice dimensions, + HloInstruction* arg, absl::Span dimensions, unsigned element_alignment); // Tries to emit a fast concatenate operation using memcpy. Returns true if // successful, and false on failure. On failure, sets "failure_reason" to a // string describing why it could not emit a fast concatenate. - StatusOr EmitFastConcatenate( - HloInstruction* concatenate, - tensorflow::gtl::ArraySlice operands, - string* failure_reason); + StatusOr EmitFastConcatenate(HloInstruction* concatenate, + absl::Span operands, + string* failure_reason); // Emits LLVM IR to transfer "element_count" elements of type "primitive_type" // from the address "source" to the address "target". diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc index 784045313df..3ecf4b69b7f 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc @@ -200,10 +200,10 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { // Returns an array of compute function call arguments (including parameter // address buffer). 
std::vector GetArrayFunctionCallArguments( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::IRBuilder<>* b, absl::string_view name, - llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, - llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) { + absl::Span parameter_addresses, llvm::IRBuilder<>* b, + absl::string_view name, llvm::Value* return_value_buffer, + llvm::Value* exec_run_options_arg, llvm::Value* temp_buffers_arg, + llvm::Value* profile_counters_arg) { llvm::Value* parameter_addresses_buffer; if (parameter_addresses.empty()) { diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h index ee7595f6e97..076ca219bc7 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.h +++ b/tensorflow/compiler/xla/service/cpu/ir_function.h @@ -115,10 +115,10 @@ class IrFunction { // Returns an array of compute function call argument ir values. std::vector GetArrayFunctionCallArguments( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::IRBuilder<>* b, absl::string_view name, - llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, - llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg); + absl::Span parameter_addresses, llvm::IRBuilder<>* b, + absl::string_view name, llvm::Value* return_value_buffer, + llvm::Value* exec_run_options_arg, llvm::Value* temp_buffers_arg, + llvm::Value* profile_counters_arg); // Emits a call to a runtime fork/join function which dispatches parallel // calls to 'parallel_function' (and joins threads before returning). diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc index 962ea69c094..1bd4b59dd60 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc @@ -428,7 +428,7 @@ std::vector TileVariable::Get() const { return result; } -void TileVariable::Set(tensorflow::gtl::ArraySlice value) { +void TileVariable::Set(absl::Span value) { CHECK_EQ(value.size(), storage_.size()); for (int64 i = 0, e = value.size(); i < e; i++) { storage_[i].Set(value[i]); diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h index c728f6df0ae..3dfe941a3ab 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h @@ -324,7 +324,7 @@ class TileVariable { std::vector initial_value); std::vector Get() const; - void Set(tensorflow::gtl::ArraySlice value); + void Set(absl::Span value); private: std::vector storage_; diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc b/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc index 47543b2082f..b9e47f5aade 100644 --- a/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc @@ -37,7 +37,7 @@ void XfeedQueueManager::Reset() { } void XfeedQueueManager::EnqueueBuffersAtomically( - tensorflow::gtl::ArraySlice buffers) { + absl::Span buffers) { tensorflow::mutex_lock l(mu_); bool was_empty = enqueued_buffers_.empty(); for (XfeedBuffer* b : buffers) { diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager.h b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h index b4ace232607..fac1722b107 100644 --- a/tensorflow/compiler/xla/service/cpu/xfeed_manager.h +++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h @@ 
-63,8 +63,7 @@ class XfeedQueueManager { // called when the buffer will no longer be accessed by the XfeedManager, // either as a result of a call to Reset or because the runtime has dequeued // and used the buffer. - void EnqueueBuffersAtomically( - tensorflow::gtl::ArraySlice buffers); + void EnqueueBuffersAtomically(absl::Span buffers); // Blocks until the queue is non-empty, then returns the buffer at the head of // the queue. Sets the current buffer to be the returned buffer. It is an diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.cc b/tensorflow/compiler/xla/service/device_memory_allocator.cc index 1d0297cfbfc..edbcb252474 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.cc +++ b/tensorflow/compiler/xla/service/device_memory_allocator.cc @@ -25,7 +25,7 @@ namespace xla { StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator( const se::Platform* platform, - tensorflow::gtl::ArraySlice stream_executors) + absl::Span stream_executors) : DeviceMemoryAllocator(platform), stream_executors_(stream_executors.begin(), stream_executors.end()) {} diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.h b/tensorflow/compiler/xla/service/device_memory_allocator.h index d87b86caf0d..28a3539373e 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.h +++ b/tensorflow/compiler/xla/service/device_memory_allocator.h @@ -80,7 +80,7 @@ class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator { public: StreamExecutorMemoryAllocator( const se::Platform* platform, - tensorflow::gtl::ArraySlice stream_executors); + absl::Span stream_executors); StatusOr Allocate(int device_ordinal, uint64 size, bool retry_on_failure) override; diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index def42f9c770..4bb1e071d8d 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -856,7 +856,7 @@ StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, auto getFloat = [&](const float f) { return llvm::ConstantFP::get(b_->getFloatTy(), f); }; - auto multiply_add = [&](tensorflow::gtl::ArraySlice coefficients, + auto multiply_add = [&](absl::Span coefficients, llvm::Value* w) { llvm::Value* p = getFloat(coefficients.front()); coefficients.remove_prefix(1); @@ -893,7 +893,7 @@ StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, SetToFirstInsertPoint(if_data.true_block, b_); { llvm::Value* lw = FSub(w, getFloat(2.5f)); - tensorflow::gtl::ArraySlice lq{ + absl::Span lq{ 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, -4.39150654e-06f, 0.00021858087f, -0.00125372503f, -0.00417768164f, 0.246640727f, 1.50140941f}; @@ -908,7 +908,7 @@ StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()}); llvm::Value* gw = FSub(Call(sqrtf_fn, w), getFloat(3.0f)); - tensorflow::gtl::ArraySlice gq{ + absl::Span gq{ -0.000200214257f, 0.000100950558f, 0.00134934322f, -0.00367342844f, 0.00573950773f, -0.0076224613f, 0.00943887047f, 1.00167406f, 2.83297682f}; diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc index 5ab07562194..1b3be199f63 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc @@ -28,8 +28,7 @@ using absl::nullopt; class 
ElementalIrEmitterExecutionTest : public HloTestBase { protected: - void RunTest(const string& hlo_text, - tensorflow::gtl::ArraySlice args) { + void RunTest(const string& hlo_text, absl::Span args) { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 78edf918a4d..47c56e2f7fb 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -26,13 +26,12 @@ limitations under the License. #include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/platform/env.h" -using tensorflow::gtl::ArraySlice; namespace xla { StatusOr> Executable::ExecuteOnStreams( - ArraySlice run_options, - ArraySlice> arguments) { + absl::Span run_options, + absl::Span> arguments) { TF_RET_CHECK(run_options.size() == arguments.size()); std::vector return_values; @@ -63,7 +62,7 @@ StatusOr> Executable::ExecuteOnStreams( StatusOr Executable::ExecuteOnStreamWrapper( const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile, - ArraySlice arguments) { + absl::Span arguments) { se::Stream* stream = run_options->stream(); std::unique_ptr timer; if (profile != nullptr) { diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 6e055edc035..4b8d955b286 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -81,14 +81,14 @@ class Executable { // Returns a shaped buffer containing the result of the computation. virtual StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) = 0; // Same as ExecuteOnStream(), but this call is non-blocking and returns as // soon as all of the operations are enqueued for launch on the stream. virtual StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) = 0; + absl::Span arguments) = 0; // Starts the given program executing on the given stream/executor. // @@ -119,11 +119,8 @@ class Executable { // run_options[i]->stream() and the returned value is at index i of the // returned vector. virtual StatusOr> ExecuteOnStreams( - tensorflow::gtl::ArraySlice - run_options, - tensorflow::gtl::ArraySlice< - tensorflow::gtl::ArraySlice> - arguments); + absl::Span run_options, + absl::Span> arguments); // Populates `hlo_execution_profile` from `executor`. This is implicit in any // Execute* API call that takes a hlo_execution_profile argument, but must be @@ -139,7 +136,7 @@ class Executable { // given ExecutionProfile if non-null. StatusOr ExecuteOnStreamWrapper( const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); // Returns the ExecutionProfile from executing on the device. This includes // the number of cycles taken for the computation or the compilation time. diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index 3f1a8813721..cb86c985793 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -25,7 +25,6 @@ limitations under the License. 
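ExecuteOnStreams above trades a nested ArraySlice-of-ArraySlice for a nested Span-of-Span. A minimal sketch of how such a parameter is built and passed at a call site, using toy int data instead of ShapedBuffer pointers (TotalSize is an illustrative name):

    #include <vector>

    #include "absl/types/span.h"

    // A function taking a "span of spans", the shape ExecuteOnStreams'
    // arguments parameter has after the rename.
    int TotalSize(absl::Span<const absl::Span<const int>> groups) {
      int n = 0;
      for (absl::Span<const int> g : groups) n += static_cast<int>(g.size());
      return n;
    }

    int main() {
      std::vector<int> a = {1, 2, 3}, b = {4, 5};
      // Each inner span is a non-owning view into a vector that must outlive
      // the call; the outer vector merely holds those views.
      std::vector<absl::Span<const int>> groups = {a, b};
      return TotalSize(groups) == 5 ? 0 : 1;
    }
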
#include "tensorflow/compiler/xla/util.h" namespace xla { -using tensorflow::gtl::ArraySlice; static StatusOr TransposeIndexVectorDimToLast( HloInstruction* start_indices, int64 index_vector_dim) { @@ -225,7 +224,7 @@ static StatusOr> GatherLoopBody( static StatusOr CreateGatherLoopAccumulatorInitValue( HloComputation* computation, PrimitiveType element_type, - ArraySlice slice_sizes, int64 gather_loop_trip_count, + absl::Span slice_sizes, int64 gather_loop_trip_count, const GatherDimensionNumbers& dim_numbers) { std::vector accumulator_state_shape_dims; accumulator_state_shape_dims.reserve(1 + slice_sizes.size()); @@ -244,7 +243,7 @@ static StatusOr CreateGatherLoopAccumulatorInitValue( // are the major dimensions and the offset dimensions are the minor dimensions. // Fix this up with a transpose. static StatusOr PermuteBatchAndOffsetDims( - HloInstruction* accumulator, ArraySlice offset_dims, + HloInstruction* accumulator, absl::Span offset_dims, int64 output_rank) { std::vector permutation; permutation.reserve(output_rank); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 0ce2db907b6..4ed91ef1876 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -42,8 +42,7 @@ se::Platform::Id GenericTransferManager::PlatformId() const { } Status GenericTransferManager::WriteSingleTupleIndexTable( - se::Stream* stream, - tensorflow::gtl::ArraySlice elements, + se::Stream* stream, absl::Span elements, const Shape& shape, se::DeviceMemoryBase* region) { TF_RET_CHECK(elements.size() == ShapeUtil::TupleElementCount(shape)); @@ -163,7 +162,7 @@ Status GenericTransferManager::TransferLiteralFromOutfeed( } Status GenericTransferManager::ResetDevices( - tensorflow::gtl::ArraySlice + absl::Span /*executors*/) { return Unimplemented( "Device reset is not yet supported on this platform (b/30481585)"); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 6c1a21587a7..86c8b1c145a 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -55,15 +55,13 @@ class GenericTransferManager : public TransferManager { const Shape& literal_shape, MutableBorrowingLiteral literal) override; - Status ResetDevices( - tensorflow::gtl::ArraySlice executors) override; + Status ResetDevices(absl::Span executors) override; int64 GetByteSizeRequirement(const Shape& shape) const override; protected: Status WriteSingleTupleIndexTable( - se::Stream* stream, - tensorflow::gtl::ArraySlice elements, + se::Stream* stream, absl::Span elements, const Shape& shape, se::DeviceMemoryBase* region) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index dbdf8e7a0e9..2af31a52f9c 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -204,9 +204,8 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( if (allocator_ != nullptr) { allocator = allocator_; } else { - se_allocator.emplace( - stream_exec_->platform(), - tensorflow::gtl::ArraySlice({stream_exec_})); + se_allocator.emplace(stream_exec_->platform(), + absl::Span({stream_exec_})); allocator = 
&*se_allocator; } diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 57a3a43a6fa..c1aaa4bf04d 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -74,10 +74,8 @@ GpuElementalIrEmitter::GpuElementalIrEmitter( compute_nested_(std::move(compute_nested)) {} StatusOr GpuElementalIrEmitter::EmitLibdeviceMathCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type) { + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type) { // The libdevice math functions differentiate between "double" and "float" by // appending an 'f' to the function's name. libdevice doesn't have f16 math // functions, so we convert the operands to f32 before calling the function @@ -119,10 +117,8 @@ StatusOr GpuElementalIrEmitter::EmitLibdeviceMathCall( } StatusOr GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type) { + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type) { // llvm intrinsics differentiate between half/float/double functions via // the suffixes ".f16", ".f32" and ".f64". string munged_callee = callee_name; @@ -144,10 +140,8 @@ StatusOr GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall( } StatusOr GpuElementalIrEmitter::EmitMathCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type) { + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type) { // Binary math functions transform are of type [T] -> T. for (PrimitiveType input_type : input_types) { if (output_type != input_type) { @@ -290,11 +284,9 @@ StatusOr GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, } llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type, - tensorflow::gtl::ArraySlice attributes) { + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type, + absl::Span attributes) { std::vector ir_input_types; for (PrimitiveType input_type : input_types) { ir_input_types.push_back( diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index 91942785d28..43f1f208bfa 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -38,9 +38,9 @@ namespace gpu { class GpuElementalIrEmitter : public ElementalIrEmitter { public: // A NestedComputer computes an element of the output of the given computation - // given an ArraySlice of its input elements. + // given a Span of its input elements. using NestedComputer = std::function( - const HloComputation&, tensorflow::gtl::ArraySlice)>; + const HloComputation&, absl::Span)>; GpuElementalIrEmitter(const HloModuleConfig& hlo_module_config, llvm::Module* module, llvm::IRBuilder<>* b, @@ -96,37 +96,29 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { // Emits IR to call a device function named "callee_name" on the given // operand. 
Returns the IR value that represents the return value. llvm::Value* EmitDeviceFunctionCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_type, - PrimitiveType output_type, - tensorflow::gtl::ArraySlice attributes); + const string& callee_name, absl::Span operands, + absl::Span input_type, PrimitiveType output_type, + absl::Span attributes); // Emits IR to call an LLVM intrinsic of type [T] -> T. Adjusts // callee_name according to T. Returns the IR value that represents the // return value of the function. StatusOr EmitLlvmIntrinsicMathCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type); + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type); // Emits IR to call a libdevice function of type [T] -> T. Adjusts // callee_name according to T. Returns the IR value that represents the // return value of the function. StatusOr EmitLibdeviceMathCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type); + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type); // Emits IR to call a function of type [T] -> T. Does not munge callee_name. // Returns the IR value that represents the return value of the function. StatusOr EmitMathCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type); + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type); const HloModuleConfig& hlo_module_config_; NestedComputer compute_nested_; diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc index 11549cdac53..ca4a605af5d 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc @@ -92,8 +92,7 @@ string FftTypeToString(se::fft::Type type) { } // namespace -FftThunk::FftThunk(FftType fft_type, - tensorflow::gtl::ArraySlice fft_length, +FftThunk::FftThunk(FftType fft_type, absl::Span fft_length, const BufferAllocation::Slice& input_buffer, const BufferAllocation::Slice& output_buffer, const Shape& input_shape, const Shape& output_shape, diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h index 4adec7ee544..2be50e08bd2 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h @@ -62,7 +62,7 @@ class FftThunk : public Thunk { public: // Constructs a thunk for launching an FFT on a stream. // Semantics of null hlo_instruction argument are as in Thunk. 
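The NestedComputer alias in the elemental IR emitter hunk above is a std::function whose inputs arrive as a Span. A stripped-down sketch of the same shape, assuming plain doubles in place of llvm::Value* and StatusOr:

    #include <functional>
    #include <numeric>

    #include "absl/types/span.h"

    // A callback alias in the spirit of NestedComputer: inputs arrive as a
    // non-owning span.
    using Computer = std::function<double(absl::Span<const double>)>;

    int main() {
      Computer sum = [](absl::Span<const double> xs) {
        return std::accumulate(xs.begin(), xs.end(), 0.0);
      };
      const double inputs[] = {1.0, 2.0, 3.0};
      return sum(inputs) == 6.0 ? 0 : 1;
    }
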
- FftThunk(FftType fft_type, tensorflow::gtl::ArraySlice fft_length, + FftThunk(FftType fft_type, absl::Span fft_length, const BufferAllocation::Slice& input_buffer, const BufferAllocation::Slice& output_buffer, const Shape& input_shape, const Shape& output_shape, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 71a02e70df7..31a9f9b1beb 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -234,7 +234,7 @@ GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) { StatusOr GpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) { DeviceMemoryAllocator* memory_allocator = run_options->allocator(); @@ -325,7 +325,7 @@ StatusOr GpuExecutable::ExecuteOnStream( StatusOr GpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { // TODO(b/30671675): Implement asynchronous execution mode. return Unimplemented( "Asynchronous execution on stream is not yet supported on GPU."); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 627a05e2401..b3765adf5e5 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -78,12 +78,12 @@ class GpuExecutable : public Executable { // match the compute capability passed to this object's constructor. StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) override; StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) override; + absl::Span arguments) override; private: // If `block_host_until_done` is false, execution will not block the host diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 0e205b9c028..51627402b45 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -35,8 +35,8 @@ using absl::StrAppend; using absl::StrCat; void HloToIrBindings::EmitBasePointersForHlos( - tensorflow::gtl::ArraySlice io_hlos, - tensorflow::gtl::ArraySlice non_io_hlos) { + absl::Span io_hlos, + absl::Span non_io_hlos) { // I/O HLOs are bound to the arguments of the current IR function. I.e., // // void IrFunction(io_0, io_1, ..., io_{m-1}, temp_buffer_base) { diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index eee40b0e91f..5b05ed812ed 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -45,8 +45,8 @@ class HloToIrBindings { alias_analysis_(module, *buffer_assignment_, &b_->getContext()) {} void EmitBasePointersForHlos( - tensorflow::gtl::ArraySlice io_hlos, - tensorflow::gtl::ArraySlice non_io_hlos); + absl::Span io_hlos, + absl::Span non_io_hlos); // Rebinds the given HLO to the LLVM IR value that represent its address. 
void BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value, diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index f544bcc9197..9c90f4d46b3 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -216,7 +216,7 @@ bool IsReductionToVector(const HloInstruction& reduce) { // "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see // http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls llvm::Value* EmitPrintf(absl::string_view fmt, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, llvm::IRBuilder<>* builder) { std::vector argument_types; for (auto argument : arguments) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index a35e250101c..d242897e16b 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -127,7 +127,7 @@ bool IsReductionToVector(const HloInstruction& reduce); // Emits call to "vprintf" with given format and arguments. llvm::Value* EmitPrintf(absl::string_view fmt, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, llvm::IRBuilder<>* builder); // Emits code to shuffle data between threads of a warp. This has the same diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index bdf6aadde67..ffca5d6549a 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -141,7 +141,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { Status IrEmitter::EmitCallToNestedComputation( const HloComputation& nested_computation, - tensorflow::gtl::ArraySlice operands, llvm::Value* output) { + absl::Span operands, llvm::Value* output) { TF_RET_CHECK(nested_computation.num_parameters() > 0); llvm::Function*& emitted_function = computation_to_ir_function_[&nested_computation]; @@ -633,7 +633,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { } auto arg = reduce->operand(0); auto init_value = reduce->operand(1); - tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); + absl::Span dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); return EmitTargetElementLoop( *reduce, @@ -748,7 +748,7 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) { StatusOr IrEmitter::ComputeNestedElement( const HloComputation& computation, - tensorflow::gtl::ArraySlice parameter_elements) { + absl::Span parameter_elements) { llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType( computation.root_instruction()->shape().element_type(), module_), diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 3673b9f58d6..bc2b04ace58 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -143,9 +143,9 @@ class IrEmitter : public DfsHloVisitorWithDefault, // Emits a call in IR to the given nested computation with the given operands // and output. If no IR function has been previously emitted for the // computation, also emits such a function. 
- Status EmitCallToNestedComputation( - const HloComputation& nested_computation, - tensorflow::gtl::ArraySlice operands, llvm::Value* output); + Status EmitCallToNestedComputation(const HloComputation& nested_computation, + absl::Span operands, + llvm::Value* output); // Emits an atomic operation that implements `nested_computation` in the // sequentially consistent memory model. `output_address` and `source_address` @@ -199,7 +199,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, StatusOr ComputeNestedElement( const HloComputation& computation, - tensorflow::gtl::ArraySlice parameter_elements); + absl::Span parameter_elements); // Emits an atomic operation that implements `nested_computation` in the // sequentially consistent memory model. `output_address` and `source_address` diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 860dd0b50f3..3ab79197e2a 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -94,7 +94,6 @@ using absl::optional; using absl::StrCat; using llvm_ir::IrArray; using llvm_ir::IrName; -using tensorflow::gtl::ArraySlice; // If a dimensions is smaller than this, untiled transposition may be more // efficient. @@ -176,7 +175,7 @@ Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) { llvm::Function* IrEmitterUnnested::BuildKernelPrototype( const HloInstruction& inst, - tensorflow::gtl::ArraySlice args) { + absl::Span args) { // Compute the kernel name. The opcode string may contain "-" which cannot be // in a PTX function name, so sanitize the name before uniquifying it. string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName( @@ -556,10 +555,10 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { } VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString(); std::vector> thunks; - ArraySlice output_instructions = + absl::Span output_instructions = root->opcode() == HloOpcode::kTuple ? root->operands() - : ArraySlice(&root, 1); + : absl::Span(&root, 1); // For multi-output fusion emit an initializer for each tuple element. // Otherwise it's sufficient to just initialize the single output. @@ -718,8 +717,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { Status IrEmitterUnnested::EmitExtraOutputsForReduce( const HloInstruction* reduce, const IrArray::Index& index, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span> extra_output_gens) { for (int i = 0; i != extra_output_gens.size(); ++i) { const HloInstruction* output = reduce->parent()->FusionInstruction(); @@ -736,12 +734,11 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce( Status IrEmitterUnnested::EmitReductionToScalar( HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens) { // Number of elements processed by a single thread. 
constexpr int64 kTileSize = 16; @@ -951,12 +948,11 @@ Status IrEmitterUnnested::EmitReductionToScalar( Status IrEmitterUnnested::EmitColumnReduction( int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens) { // Divide the input matrix into tiles of size KxL. For example, when the // input matrix is 4x4, K=2, and L=1 the tiled matrix looks like @@ -1240,12 +1236,11 @@ static std::pair ComputeTilingSchemeForReduction( Status IrEmitterUnnested::EmitRowReduction( int64 depth, int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens) { // A naive algorithm is: // 1. Divide the x dimension of the input tensor into tiles of size 1x1xX. @@ -1593,13 +1588,12 @@ Status IrEmitterUnnested::EmitRowReduction( // elementwise. Status IrEmitterUnnested::EmitReductionToVector( HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice dimensions_to_reduce, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span dimensions_to_reduce, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens) { // This emission requires "reduce" to have an input layout. It is either set // by LayoutAssignment (for a top-level kReduce) or by InstructionFusion (for @@ -1694,7 +1688,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { } auto input = reduce->operand(0); auto init_value = reduce->operand(1); - tensorflow::gtl::ArraySlice dimensions_to_reduce(reduce->dimensions()); + absl::Span dimensions_to_reduce(reduce->dimensions()); HloComputation* reducer = reduce->to_apply(); // HandleReduce specializes reduction from a multi-dimensional array to a 1D // array. The specialized version requires an initializer thunk that @@ -2570,7 +2564,7 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( // Are all the bytes of this scalar equal to 0? If so, we can create a // MemzeroThunk. 
- ArraySlice literal_bytes( + absl::Span literal_bytes( reinterpret_cast(literal.untyped_data()), num_bytes); if (absl::c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { return {absl::make_unique(GetAllocationSlice(*hlo, index), @@ -2880,7 +2874,7 @@ int IrEmitterUnnested::ConstructIrArrayForInputs( int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape( const HloInstruction& hlo, const std::vector& output_arrays, - tensorflow::gtl::ArraySlice reduced_output_dims, + absl::Span reduced_output_dims, std::vector* output_reduced_shapes, std::vector* output_in_reduced_shape_arrays) { int64 num_outputs = 1; @@ -2907,7 +2901,7 @@ int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape( int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape( const HloInstruction& hlo, const std::vector& param_arrays, const std::vector& param_buffers, - tensorflow::gtl::ArraySlice reduced_output_dims, + absl::Span reduced_output_dims, std::vector* param_reduced_shapes, std::vector* param_in_reduced_shape_arrays) { int64 num_params = hlo.operands().size(); @@ -3048,8 +3042,8 @@ void EmitTiledElementalCodeWithBoundsCheck( // TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient // to launch fewer blocks so each transposes many tiles. LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( - HloInstruction* hlo, tensorflow::gtl::ArraySlice reduced_output_dims, - tensorflow::gtl::ArraySlice tiled_param_ids) { + HloInstruction* hlo, absl::Span reduced_output_dims, + absl::Span tiled_param_ids) { // Parameters for the tiling algorithm. constexpr int64 kTileSize = 32; constexpr int64 kNumRows = 4; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 52544199079..084462330ed 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -105,13 +105,12 @@ class IrEmitterUnnested : public IrEmitter { // This kernel takes as arguments pointers to the given buffer allocations. llvm::Function* BuildKernelPrototype( const HloInstruction& inst, - tensorflow::gtl::ArraySlice args); + absl::Span args); // Helper for writing extra outputs from inside a reduce kernel. 
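The BuildInitializerThunk hunk above wraps a literal's raw bytes in a Span so absl::c_all_of can test for an all-zero value. A self-contained sketch of that check, with data and size standing in for literal.untyped_data() and the literal's byte count:

    #include <cstddef>
    #include <cstdint>

    #include "absl/algorithm/container.h"
    #include "absl/types/span.h"

    // Is every byte of the buffer zero?
    bool AllZero(const void* data, size_t size) {
      absl::Span<const uint8_t> bytes(
          reinterpret_cast<const uint8_t*>(data), size);
      return absl::c_all_of(bytes, [](uint8_t byte) { return byte == 0; });
    }

    int main() {
      const uint8_t zeros[4] = {0, 0, 0, 0};
      const uint8_t mixed[4] = {0, 1, 0, 0};
      return (AllZero(zeros, 4) && !AllZero(mixed, 4)) ? 0 : 1;
    }
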
Status EmitExtraOutputsForReduce( const HloInstruction* reduce, const llvm_ir::IrArray::Index& index, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span> extra_output_gens); // EmitColumnReduction and EmitRowReduction emit code for column and row @@ -127,12 +126,11 @@ class IrEmitterUnnested : public IrEmitter { Status EmitColumnReduction( int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens); // Emits code that reduces a 3D tensor of shape [depth x height x width] to a @@ -143,23 +141,21 @@ class IrEmitterUnnested : public IrEmitter { Status EmitRowReduction( int64 depth, int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens); // Emits code that reduces a tensor of arbitrary rank to a scalar. Status EmitReductionToScalar( HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens); // Figures out whether `reduce` is a row or column reduction, and which @@ -180,13 +176,12 @@ class IrEmitterUnnested : public IrEmitter { // Prerequisite: `IsReductionToVector(*reduce)` Status EmitReductionToVector( HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice dimensions_to_reduce, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span dimensions_to_reduce, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens); // Returns true if a 0-2-1 tiling algorithm is already used to emit the kernel @@ -195,10 +190,9 @@ class IrEmitterUnnested : public IrEmitter { // Emits a kernel for the hlo instruction using a 0-2-1 tiling algorithm and // returns the launch dimensions for the kernel. This is a helper to support // the implementation of CheckAndEmitHloWithTile021. - LaunchDimensions EmitHlo021Tile( - HloInstruction* hlo, - tensorflow::gtl::ArraySlice reduced_output_dims, - tensorflow::gtl::ArraySlice tiled_param_ids); + LaunchDimensions EmitHlo021Tile(HloInstruction* hlo, + absl::Span reduced_output_dims, + absl::Span tiled_param_ids); // Generates the IrArray for each output of hlo and returns the number of // outputs. 
int ConstructIrArrayForOutputs(const HloInstruction& hlo, @@ -214,7 +208,7 @@ class IrEmitterUnnested : public IrEmitter { int ConstructOutputReducedShapeAndCastOutputIrArrayToShape( const HloInstruction& hlo, const std::vector& output_arrays, - tensorflow::gtl::ArraySlice reduced_output_dims, + absl::Span reduced_output_dims, std::vector* output_reduced_shapes, std::vector* output_in_reduced_shape_arrays); // For each input of the `hlo` instruction, checks its value in @@ -226,7 +220,7 @@ class IrEmitterUnnested : public IrEmitter { const HloInstruction& hlo, const std::vector& param_arrays, const std::vector& param_buffers, - tensorflow::gtl::ArraySlice reduced_output_dims, + absl::Span reduced_output_dims, std::vector* param_reduced_shapes, std::vector* param_in_reduced_shape_arrays); diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index 3259eaa2a26..878b0b96a1d 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -27,10 +27,10 @@ limitations under the License. namespace xla { namespace gpu { -KernelThunk::KernelThunk( - tensorflow::gtl::ArraySlice args, - const string& kernel_name, const HloInstruction* hlo_instruction, - int unroll_factor) +KernelThunk::KernelThunk(absl::Span args, + const string& kernel_name, + const HloInstruction* hlo_instruction, + int unroll_factor) : Thunk(Kind::kKernel, hlo_instruction), args_(args.begin(), args.end()), kernel_name_(kernel_name), diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h index d751de50ad6..480f473037c 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h @@ -47,7 +47,7 @@ class KernelThunk : public Thunk { // Constructs a thunk for the given kernel. // // `hlo_instruction` is as in Thunk. Other arguments are as the class members. - KernelThunk(tensorflow::gtl::ArraySlice args, + KernelThunk(absl::Span args, const string& kernel_name, const HloInstruction* hlo_instruction, int unroll_factor); KernelThunk(const KernelThunk&) = delete; diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc index 79f7d31816b..fa84d772235 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc @@ -23,7 +23,6 @@ limitations under the License. namespace xla { namespace gpu { -using tensorflow::gtl::ArraySlice; // We want the input/output feature counts of an f16 conv to be factors of 8, // because without this cudnn can't use tensor cores on the conv. @@ -42,7 +41,7 @@ static constexpr double kMaxBytesTouchedIncrease = 1.2; // Pads the given dimensions in the given shape up to a multiple of // kDesiredNumFeaturesFactor. 
-static Shape PadShape(Shape s, ArraySlice dims) { +static Shape PadShape(Shape s, absl::Span dims) { for (int64 dim : dims) { int64 dim_to_pad_size = s.dimensions(dim); int64 new_dim_to_pad_size = diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index ca57cacb983..8154d75d23a 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -40,7 +40,7 @@ ParallelLoopEmitter::ParallelLoopEmitter( ParallelLoopEmitter::ParallelLoopEmitter( const llvm_ir::ElementGenerator& target_element_generator, - tensorflow::gtl::ArraySlice target_arrays, + absl::Span target_arrays, const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b, int unroll_factor) : LoopEmitter(target_element_generator, target_arrays, b), diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h index cc7da2e73b6..f32ea1ce4c4 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h @@ -47,11 +47,10 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { // // This is used in multi-output fusion. target_element_generator should // produce a struct with N elements, one for each of target_arrays. - ParallelLoopEmitter( - const llvm_ir::ElementGenerator& target_element_generator, - tensorflow::gtl::ArraySlice target_arrays, - const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b, - int unroll_factor = 1); + ParallelLoopEmitter(const llvm_ir::ElementGenerator& target_element_generator, + absl::Span target_arrays, + const LaunchDimensions& launch_dimensions, + llvm::IRBuilder<>* b, int unroll_factor = 1); ParallelLoopEmitter(const ParallelLoopEmitter&) = delete; ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete; diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h index 2d5735d6c40..a3a03b53f8a 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h @@ -34,8 +34,7 @@ namespace gpu { // issue (b/31336476). class TupleThunk : public Thunk { public: - TupleThunk(tensorflow::gtl::ArraySlice - tuple_element_buffers, + TupleThunk(absl::Span tuple_element_buffers, const BufferAllocation::Slice& dest_buffer, const HloInstruction* hlo_instruction) : Thunk(Kind::kTuple, hlo_instruction), diff --git a/tensorflow/compiler/xla/service/hlo_buffer.h b/tensorflow/compiler/xla/service/hlo_buffer.h index 4873463b2ea..a88c87e46c8 100644 --- a/tensorflow/compiler/xla/service/hlo_buffer.h +++ b/tensorflow/compiler/xla/service/hlo_buffer.h @@ -84,7 +84,7 @@ class HloBuffer { return a->id() == b->id(); } - HloBuffer(Id id, tensorflow::gtl::ArraySlice values) + HloBuffer(Id id, absl::Span values) : id_(id), values_(values.begin(), values.end()) {} // Return the unique identifier for this HloBuffer. 
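PadShape above iterates a Span of dimension indices and pads each selected dimension up to a multiple of kDesiredNumFeaturesFactor. A rough sketch of that arithmetic on a plain vector, leaving out the Shape plumbing; PadDims and its round-up formula are illustrative, not the actual helper:

    #include <cstdint>
    #include <vector>

    #include "absl/types/span.h"

    // Rounds the selected entries of `dims` up to a multiple of `factor`.
    void PadDims(std::vector<int64_t>* dims, absl::Span<const int64_t> to_pad,
                 int64_t factor) {
      for (int64_t i : to_pad) {
        int64_t size = (*dims)[i];
        (*dims)[i] = (size + factor - 1) / factor * factor;
      }
    }

    int main() {
      std::vector<int64_t> dims = {5, 16, 3};
      PadDims(&dims, {0, 2}, /*factor=*/8);  // dims becomes {8, 16, 8}
      return (dims[0] == 8 && dims[1] == 16 && dims[2] == 8) ? 0 : 1;
    }
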
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index c2d0673f491..fe7f2be888d 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -558,7 +558,7 @@ HloComputation::CreateFromProto( } void HloComputation::FuseInstructionsInto( - tensorflow::gtl::ArraySlice instructions_to_fuse, + absl::Span instructions_to_fuse, HloInstruction* fusion_instruction) { CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode()); HloInstruction* root = instructions_to_fuse.front(); @@ -577,7 +577,7 @@ void HloComputation::FuseInstructionsInto( } HloInstruction* HloComputation::CreateFusionInstruction( - tensorflow::gtl::ArraySlice instructions_to_fuse, + absl::Span instructions_to_fuse, HloInstruction::FusionKind fusion_kind) { HloInstruction* root = instructions_to_fuse.front(); HloInstruction* fusion_instruction = AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 59016624f76..daafb711fd5 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -237,7 +237,7 @@ class HloComputation { // removed if they have no uses after fusion (this is necessarily true for at // least the root). HloInstruction* CreateFusionInstruction( - tensorflow::gtl::ArraySlice instructions_to_fuse, + absl::Span instructions_to_fuse, HloInstruction::FusionKind fusion_kind); // Create a deep copy of the given instruction and return the instruction @@ -385,7 +385,7 @@ class HloComputation { // // Pre-condition: fusion_instruction's opcode is kFusion. void FuseInstructionsInto( - tensorflow::gtl::ArraySlice instructions_to_fuse, + absl::Span instructions_to_fuse, HloInstruction* fusion_instruction); // Internal helper for recursive copying of an instruction. Creates and diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 7cd1481a8ad..07cd1efc120 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -105,8 +105,8 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { TEST_F(HloConstantFoldingTest, Concatenate) { const struct TestConfig { int concat_dimension; - tensorflow::gtl::ArraySlice dimensions; - tensorflow::gtl::ArraySlice concat_sizes; + absl::Span dimensions; + absl::Span concat_sizes; } test_configs[] = { {1, {11, 0, 7, 5, 9}, {2, 5, 7, 11}}, {3, {1, 4, 17, 0, 8}, {1, 3, 9, 12}}, @@ -196,7 +196,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; bool matched = true; root->literal().EachCell( - [&](tensorflow::gtl::ArraySlice indices, NativeT value) { + [&](absl::Span indices, NativeT value) { std::vector rindexes = Permute(permutation, indices); matched = matched && (value == literal_clone->Get(rindexes)); }); diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 131846794d9..19ffb465c04 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -24,7 +24,6 @@ limitations under the License. 
namespace xla { using absl::StrCat; -using tensorflow::gtl::ArraySlice; StatusOr MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs, HloInstruction* rhs) { @@ -50,9 +49,9 @@ StatusOr MakePadHlo(HloInstruction* operand, } StatusOr MakeSliceHlo(HloInstruction* operand, - ArraySlice start_indices, - ArraySlice limit_indices, - ArraySlice strides) { + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) { HloComputation* computation = operand->parent(); TF_ASSIGN_OR_RETURN(Shape slice_shape, ShapeInference::InferSliceShape( operand->shape(), start_indices, @@ -74,7 +73,7 @@ StatusOr MakeConvolveHlo( } StatusOr MakeTransposeHlo(HloInstruction* operand, - ArraySlice dimensions) { + absl::Span dimensions) { HloComputation* computation = operand->parent(); TF_ASSIGN_OR_RETURN( Shape transpose_shape, @@ -91,15 +90,15 @@ StatusOr MakeReshapeHlo(const Shape& result_shape, } StatusOr MakeReshapeHlo( - ArraySlice result_shape_dim_bounds, HloInstruction* operand) { + absl::Span result_shape_dim_bounds, HloInstruction* operand) { Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(), result_shape_dim_bounds); return MakeReshapeHlo(new_shape, operand); } -StatusOr MakeDynamicSliceHlo(HloInstruction* operand, - HloInstruction* start_indices, - ArraySlice slice_sizes) { +StatusOr MakeDynamicSliceHlo( + HloInstruction* operand, HloInstruction* start_indices, + absl::Span slice_sizes) { HloComputation* computation = operand->parent(); CHECK_EQ(computation, start_indices->parent()); TF_ASSIGN_OR_RETURN( @@ -125,8 +124,8 @@ StatusOr MakeDynamicUpdateSliceHlo( } StatusOr MakeBroadcastHlo( - HloInstruction* operand, ArraySlice broadcast_dimensions, - ArraySlice result_shape_bounds) { + HloInstruction* operand, absl::Span broadcast_dimensions, + absl::Span result_shape_bounds) { HloComputation* computation = operand->parent(); Shape broadcast_shape = ShapeUtil::MakeShape(operand->shape().element_type(), result_shape_bounds); @@ -146,8 +145,8 @@ StatusOr MakeGetTupleElementHlo(HloInstruction* operand, HloInstruction::CreateGetTupleElement(gte_shape, operand, index)); } -StatusOr MakeConcatHlo(ArraySlice operands, - int64 dimension) { +StatusOr MakeConcatHlo( + absl::Span operands, int64 dimension) { CHECK_GT(operands.size(), 0); HloComputation* computation = operands[0]->parent(); @@ -176,9 +175,8 @@ StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers)); } -StatusOr MakeMapHlo( - tensorflow::gtl::ArraySlice operands, - HloComputation* map_computation) { +StatusOr MakeMapHlo(absl::Span operands, + HloComputation* map_computation) { CHECK(!operands.empty()) << "Map Hlo requires at least one operand."; HloComputation* computation = operands.front()->parent(); std::vector operand_shapes; @@ -235,7 +233,7 @@ StatusOr PrependDegenerateDims(HloInstruction* operand, } StatusOr ExpandFirstDimIntoNDims( - HloInstruction* operand, ArraySlice expanded_dims) { + HloInstruction* operand, absl::Span expanded_dims) { CHECK_GT(operand->shape().dimensions_size(), 0); CHECK_EQ(operand->shape().dimensions(0), Product(expanded_dims)); @@ -251,8 +249,8 @@ StatusOr ExpandFirstDimIntoNDims( return MakeReshapeHlo(new_shape, operand); } -StatusOr ElideDegenerateDims(HloInstruction* operand, - ArraySlice dims_to_elide) { +StatusOr ElideDegenerateDims( + HloInstruction* operand, absl::Span dims_to_elide) { CHECK(absl::c_is_sorted(dims_to_elide)); const Shape& input_shape = operand->shape(); @@ -277,7 +275,7 @@ StatusOr 
ElideDegenerateDims(HloInstruction* operand, } StatusOr InsertDegenerateDims( - HloInstruction* operand, ArraySlice dims_to_insert) { + HloInstruction* operand, absl::Span dims_to_insert) { CHECK(absl::c_is_sorted(dims_to_insert)); const Shape& operand_shape = operand->shape(); @@ -327,7 +325,7 @@ StatusOr PadVectorWithZeros(HloInstruction* operand, StatusOr BroadcastZeros( HloComputation* computation, PrimitiveType element_type, - ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { HloInstruction* zero = computation->AddInstruction(HloInstruction::CreateConstant( absl::make_unique(LiteralUtil::Zero(element_type)))); @@ -336,7 +334,7 @@ StatusOr BroadcastZeros( } StatusOr> CreateComputationWithSignature( - ArraySlice domain, const Shape& range, + absl::Span domain, const Shape& range, absl::string_view name) { HloComputation::Builder b{string(name)}; int64 param_idx = 0; diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 1bc6d09b450..a1c4b374d11 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -40,10 +40,10 @@ StatusOr MakePadHlo(HloInstruction* operand, // Creates a slice HLO instruction and adds it to the computation containing // `operand`. -StatusOr MakeSliceHlo( - HloInstruction* operand, tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); +StatusOr MakeSliceHlo(HloInstruction* operand, + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); // Creates a convolution HLO instruction and adds it to the computation // containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation). @@ -53,8 +53,8 @@ StatusOr MakeConvolveHlo( // Creates a transpose HLO instruction and adds it to the computation containing // `operand`. -StatusOr MakeTransposeHlo( - HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions); +StatusOr MakeTransposeHlo(HloInstruction* operand, + absl::Span dimensions); // Creates a reshape HLO instruction and adds it to the computation containing // `operand`. @@ -62,15 +62,14 @@ StatusOr MakeReshapeHlo(const Shape& result_shape, HloInstruction* operand); StatusOr MakeReshapeHlo( - tensorflow::gtl::ArraySlice result_shape_dim_bounds, - HloInstruction* operand); + absl::Span result_shape_dim_bounds, HloInstruction* operand); // Creates a dynamic-slice HLO instruction and adds it to the computation // containing `operand` and `start_indices` (`operand` and `start_indices` must // be in the same computation). StatusOr MakeDynamicSliceHlo( HloInstruction* operand, HloInstruction* start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); // Creates a dynamic-update-slice HLO instruction and adds it to the computation // containing `operand`, `update` and `start_indices` (`operand`, `update` and @@ -82,9 +81,8 @@ StatusOr MakeDynamicUpdateSliceHlo( // Creates a broadcast HLO instruction and adds it to the computation containing // `operand`. StatusOr MakeBroadcastHlo( - HloInstruction* operand, - tensorflow::gtl::ArraySlice broadcast_dimensions, - tensorflow::gtl::ArraySlice result_shape_bounds); + HloInstruction* operand, absl::Span broadcast_dimensions, + absl::Span result_shape_bounds); // Creates a GetTupleElement HLO instruction and adds it to the computation // containing `operand`. 
@@ -95,7 +93,7 @@ StatusOr MakeGetTupleElementHlo(HloInstruction* operand, // containing `operands` (`operands` must be non-empty and every element must be // contained in the same computation). StatusOr MakeConcatHlo( - tensorflow::gtl::ArraySlice operands, int64 dimension); + absl::Span operands, int64 dimension); // Creates a Dot HLO instruction and adds it to the computation containing `lhs` // and `rhs` (both must be in the same computation). @@ -104,9 +102,8 @@ StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, // Creates a Map HLO instruction and adds it to the computation containing the // operands. All operands must be in the same computation. -StatusOr MakeMapHlo( - tensorflow::gtl::ArraySlice operands, - HloComputation* map_computation); +StatusOr MakeMapHlo(absl::Span operands, + HloComputation* map_computation); // ----------------------------------------------------------------------------- // Some other miscellaneous helpers to generate common HLO patterns. All of @@ -138,7 +135,7 @@ StatusOr PrependDegenerateDims(HloInstruction* operand, // For instance if `operand` has shape f32[200,9,7] and expanded_dims is // {2,5,20} the result is `operand` reshaped to [2,5,20,9,7]. StatusOr ExpandFirstDimIntoNDims( - HloInstruction* operand, tensorflow::gtl::ArraySlice expanded_dims); + HloInstruction* operand, absl::Span expanded_dims); // Elides (via reshape) a set of degenerate dimensions (dimensions containing // exactly one element), `dims_to_elide` from `operand`. Every dimension in @@ -148,7 +145,7 @@ StatusOr ExpandFirstDimIntoNDims( // For example if `operand` is of shape f32[19,1,20,1,7,1,9] and dims_to_elide // is {1,5} then the result is `operand` reshaped to [19,20,1,7,9]. StatusOr ElideDegenerateDims( - HloInstruction* operand, tensorflow::gtl::ArraySlice dims_to_elide); + HloInstruction* operand, absl::Span dims_to_elide); // Inserts (via reshape) a set of degenerate dimensions (dimensions containing // exactly one element), `dims_to_insert` into `operand`. The dimensions in @@ -158,7 +155,7 @@ StatusOr ElideDegenerateDims( // For example, if `operand` is of shape f32[12,21,8,34] and dims_to_insert is // {0, 2}, then the result is `operand` reshaped to [1,12,1,21,8,34]. StatusOr InsertDegenerateDims( - HloInstruction* operand, tensorflow::gtl::ArraySlice dims_to_insert); + HloInstruction* operand, absl::Span dims_to_insert); // Pads `operand` (which must have rank 1) with `zeros_to_prepend` zeros in the // front and `zeros_to_append` zeros in the back. @@ -171,12 +168,12 @@ StatusOr PadVectorWithZeros(HloInstruction* operand, // broadcast instruction is emitted into `computation`. StatusOr BroadcastZeros( HloComputation* computation, PrimitiveType element_type, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); // Creates a HLO computation that takes arguments of type `domain` and produces // a value of type `range`. StatusOr> CreateComputationWithSignature( - tensorflow::gtl::ArraySlice domain, const Shape& range, + absl::Span domain, const Shape& range, absl::string_view name); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc index 662f0082053..eb6affadc80 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -24,15 +24,13 @@ limitations under the License. 
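The reshape helpers declared above document their behavior by example (f32[19,1,20,1,7,1,9] with dims_to_elide {1,5} becomes [19,20,1,7,9]). A shape-only sketch of that dimension arithmetic follows; it is not the XLA implementation, and the names are hypothetical:

#include <cstdint>
#include <set>
#include <vector>

// Drops the listed (size-1) dimensions and keeps the rest in order, matching
// the ElideDegenerateDims example in the header comment above. Illustrative only.
std::vector<int64_t> ElideDims(const std::vector<int64_t>& dims,
                               const std::vector<int64_t>& dims_to_elide) {
  std::set<int64_t> elide(dims_to_elide.begin(), dims_to_elide.end());
  std::vector<int64_t> result;
  for (int64_t i = 0; i < static_cast<int64_t>(dims.size()); ++i) {
    if (elide.count(i) > 0) continue;  // assumed to be a degenerate dimension
    result.push_back(dims[i]);
  }
  return result;
}

// ElideDims({19, 1, 20, 1, 7, 1, 9}, {1, 5}) == {19, 20, 1, 7, 9}.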
namespace xla { namespace { -using tensorflow::gtl::ArraySlice; class HloCreationUtilsTest : public HloVerifiedTestBase { protected: - HloModule* CreateModuleWithProgramShape(PrimitiveType primitive_type, - ArraySlice input_shape_dims, - ArraySlice output_shape_dims, - HloInstruction** param, - HloComputation** entry_computation) { + HloModule* CreateModuleWithProgramShape( + PrimitiveType primitive_type, absl::Span input_shape_dims, + absl::Span output_shape_dims, HloInstruction** param, + HloComputation** entry_computation) { Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims); Shape output_shape = ShapeUtil::MakeShape(primitive_type, output_shape_dims); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 3376d170e64..6a63681996b 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -46,8 +46,7 @@ namespace { // // In this case, we should be able to reuse p0 and output, although p0 has // multiple uses. -bool MultiDynamicSliceUseShareSameIndices( - tensorflow::gtl::ArraySlice uses) { +bool MultiDynamicSliceUseShareSameIndices(absl::Span uses) { if (uses.empty()) { return false; } @@ -221,7 +220,7 @@ string HloDataflowAnalysis::ToString() const { bool HloDataflowAnalysis::Phi( HloInstruction* instruction, - tensorflow::gtl::ArraySlice inputs) { + absl::Span inputs) { CHECK(ssa_form_); VLOG(4) << "Phi(" << instruction->name() << ")"; VLOG(5) << "instruction value set = " diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index a1678d4943c..6d5c375d6d7 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -202,7 +202,7 @@ class HloDataflowAnalysis { // the given instruction. If skip_top_level is true, then the top level of the // value set of 'instruction' is not modified. bool Phi(HloInstruction* instruction, - tensorflow::gtl::ArraySlice inputs); + absl::Span inputs); // Updates the positions of the HloValues in the output of the given // instruction. 
This should be called after the instruction value set of diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index c25869f87b8..d316645a0b2 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -53,7 +53,6 @@ namespace xla { namespace { -using tensorflow::gtl::ArraySlice; template StatusOr> Compare(const Shape& shape, HloOpcode opcode, @@ -97,10 +96,11 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, } auto result = absl::make_unique(shape); - TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { - return compare_op(lhs_literal.Get(multi_index), - rhs_literal.Get(multi_index)); - })); + TF_RETURN_IF_ERROR( + result->Populate([&](absl::Span multi_index) { + return compare_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); + })); return std::move(result); } @@ -127,10 +127,11 @@ StatusOr> Compare( } auto result = absl::make_unique(shape); - TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { - return compare_op(lhs_literal.Get(multi_index), - rhs_literal.Get(multi_index)); - })); + TF_RETURN_IF_ERROR( + result->Populate([&](absl::Span multi_index) { + return compare_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); + })); return std::move(result); } @@ -194,7 +195,7 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations) template StatusOr> HloEvaluator::Evaluate( - const HloModule& module, ArraySlice arg_literals) { + const HloModule& module, absl::Span arg_literals) { XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString()); evaluated_.clear(); @@ -211,7 +212,8 @@ StatusOr> HloEvaluator::Evaluate( template StatusOr> HloEvaluator::Evaluate( - const HloComputation& computation, ArraySlice arg_literals) { + const HloComputation& computation, + absl::Span arg_literals) { CHECK(computation.parent() != nullptr); XLA_VLOG_LINES( 2, "HloEvaluator::Evaluate computation:\n" + computation.ToString()); @@ -228,7 +230,7 @@ StatusOr> HloEvaluator::Evaluate( template StatusOr> HloEvaluator::Evaluate( - HloInstruction* instruction, ArraySlice arg_literals) { + HloInstruction* instruction, absl::Span arg_literals) { TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); evaluated_.clear(); @@ -390,7 +392,7 @@ Status HloEvaluator::HandleTranspose(HloInstruction* transpose) { } Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { - ArraySlice operands(concatenate->operands()); + absl::Span operands(concatenate->operands()); // The result concatenate dimension is going to be the sum of all // concatenate dimensions of the operands taking part of the operation. const Shape& reference_shape = operands[0]->shape(); @@ -588,7 +590,7 @@ ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices( // Return an ShapeUtil::IndexIterationSpace that iterates over the output slice // dimensions while keeping the rest of the output dimensions clamped to 0. ShapeUtil::IndexIterationSpace IterationSpaceForOutputOffsetIndices( - int64 output_rank, ArraySlice slice_sizes, + int64 output_rank, absl::Span slice_sizes, const GatherDimensionNumbers& dim_numbers) { std::vector index_base(output_rank, 0); std::vector index_count(output_rank, 1); @@ -661,11 +663,12 @@ class OutputBatchIndexToInputIndex { // same storage for all invocations. // // This returns an arrayslice into memory owned by the class. 
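The functor above advertises that operator() returns a view into memory owned by the class, and that caveat is unchanged by the switch to absl::Span: the returned span aliases member storage that the next call overwrites. A small sketch of the pattern under hypothetical names (this is not the gather index-mapping logic itself):

#include <cstdint>
#include <vector>

#include "absl/types/span.h"

// operator() fills a member buffer and returns a span that aliases it, so
// successive calls reuse (and overwrite) the same storage. Hypothetical example.
class ScaledIndex {
 public:
  explicit ScaledIndex(int64_t scale) : scale_(scale) {}

  absl::Span<const int64_t> operator()(absl::Span<const int64_t> index) {
    buffer_.assign(index.begin(), index.end());
    for (int64_t& v : buffer_) v *= scale_;
    return buffer_;  // valid only until the next call or until *this is destroyed
  }

 private:
  int64_t scale_;
  std::vector<int64_t> buffer_;  // the returned span points into this
};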
- StatusOr> operator()(ArraySlice output_index) { + StatusOr> operator()( + absl::Span output_index) { PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index); TF_RETURN_IF_ERROR(FetchIndexVector()); PropagateIndexVectorToInputIndex(); - return ArraySlice(input_index_); + return absl::Span(input_index_); } private: @@ -674,7 +677,7 @@ class OutputBatchIndexToInputIndex { // update the dim_numbers.index_vector_dim() dimension -- that's the dimension // we iterate over in FetchIndexVector. void PropagateOutputIndexGatherDimsToIndexVectorIndex( - ArraySlice output_index) { + absl::Span output_index) { int64 index_vector_index_i = 0; for (int64 i = 0, e = output_index.size(); i < e; i++) { if (!output_dim_is_batch_dims_[i]) { @@ -729,7 +732,7 @@ class OutputBatchIndexToInputIndex { // The index vector fetched from start_indices_. std::vector index_vector_; - // The result computed by this functor. operator() returns an ArraySlice into + // The result computed by this functor. operator() returns a Span into // this vector. std::vector input_index_; @@ -779,9 +782,10 @@ class OutputOffsetIndexToInputIndex { // result (input_index_), mutating it in place. // // This returns an arrayslice into memory owned by the class. - StatusOr> operator()(ArraySlice output_index) { + StatusOr> operator()( + absl::Span output_index) { PropagateOutputIndexWindowDimsToInputIndex(output_index); - return ArraySlice(input_index_); + return absl::Span(input_index_); } // Returns for a given 'input_dim' the corresponding output dimension index, @@ -794,7 +798,7 @@ class OutputOffsetIndexToInputIndex { // Propagates window dimensions from the output index to input_index_ by // mutating input_index_ in place. void PropagateOutputIndexWindowDimsToInputIndex( - ArraySlice output_index) { + absl::Span output_index) { for (int64 i = 0, e = input_index_.size(); i < e; i++) { if (input_dim_value_to_output_index_[i] != -1) { input_index_[i] = output_index[input_dim_value_to_output_index_[i]]; @@ -810,7 +814,7 @@ class OutputOffsetIndexToInputIndex { // PropagateOutputIndexWindowDimsToInputIndex. std::vector input_dim_value_to_output_index_; - // The result computed by this functor. operator() returns an ArraySlice into + // The result computed by this functor. operator() returns a Span into // this vector. 
std::vector input_index_; }; @@ -872,11 +876,11 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { const Shape& operand_shape = operand.shape(); auto gather_inner_loop_body = - [&](ArraySlice output_window_index, - ArraySlice input_gather_index, - ArraySlice output_gather_index) -> StatusOr { + [&](absl::Span output_window_index, + absl::Span input_gather_index, + absl::Span output_gather_index) -> StatusOr { TF_ASSIGN_OR_RETURN( - ArraySlice input_window_index, + absl::Span input_window_index, output_offset_index_to_input_index(output_window_index)); for (int i = 0, e = output_index.size(); i < e; i++) { output_index[i] = output_gather_index[i] + output_window_index[i]; @@ -909,8 +913,8 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { }; auto gather_outer_loop_body = - [&](ArraySlice output_gather_index) -> StatusOr { - TF_ASSIGN_OR_RETURN(ArraySlice input_gather_index, + [&](absl::Span output_gather_index) -> StatusOr { + TF_ASSIGN_OR_RETURN(absl::Span input_gather_index, output_batch_index_to_input_index(output_gather_index)); TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( shape, offset_indices_iteration_space, @@ -1170,12 +1174,11 @@ StatusOr> EvaluateSortInternal( result_values.push_back(key_value.second); } auto result_keys_literal = absl::make_unique(keys_literal.shape()); - result_keys_literal->PopulateR1( - tensorflow::gtl::ArraySlice(result_keys)); + result_keys_literal->PopulateR1(absl::Span(result_keys)); auto result_values_literal = absl::make_unique(values_literal.shape()); result_values_literal->PopulateR1( - tensorflow::gtl::ArraySlice(result_values)); + absl::Span(result_values)); return std::make_pair(std::move(result_keys_literal), std::move(result_values_literal)); }; @@ -1311,26 +1314,27 @@ Status HloEvaluator::Postprocess(HloInstruction* hlo) { // Explicit instantiation of templatized Evaluate* methods. // template StatusOr> -HloEvaluator::Evaluate(const HloModule& module, - ArraySlice arg_literals); +HloEvaluator::Evaluate( + const HloModule& module, absl::Span arg_literals); template StatusOr> HloEvaluator::Evaluate>( - const HloModule& module, ArraySlice> arg_literals); + const HloModule& module, + absl::Span> arg_literals); -template StatusOr> -HloEvaluator::Evaluate(const HloComputation& computation, - ArraySlice arg_literals); +template StatusOr> HloEvaluator::Evaluate< + const Literal*>(const HloComputation& computation, + absl::Span arg_literals); template StatusOr> HloEvaluator::Evaluate>( const HloComputation& computation, - ArraySlice> arg_literals); + absl::Span> arg_literals); template StatusOr> -HloEvaluator::Evaluate(HloInstruction* instruction, - ArraySlice arg_literals); +HloEvaluator::Evaluate( + HloInstruction* instruction, absl::Span arg_literals); template StatusOr> HloEvaluator::Evaluate>( HloInstruction* instruction, - ArraySlice> arg_literals); + absl::Span> arg_literals); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 980a7fb9fa5..3feb4e626f6 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -51,8 +51,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // type. template StatusOr> Evaluate( - const HloModule& module, - tensorflow::gtl::ArraySlice arg_literals); + const HloModule& module, absl::Span arg_literals); // Evaluates an HLO computation and an array of pointers to literals. // Returns the evaluated result as a literal if successful. 
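The gather loop bodies above follow a convention in which the body returns StatusOr<bool>: an error aborts the traversal, while a false value requests an early stop without an error. A simplified driver illustrating that convention, using absl::Status and absl::StatusOr as stand-ins for the XLA types and a plain loop rather than ShapeUtil::ForEachIndexWithStatus:

#include <cstdint>
#include <functional>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"

// Runs `body` over `values`; an error status propagates out, and a returned
// false stops the walk early. Illustrative only.
absl::Status WalkUntil(absl::Span<const int64_t> values,
                       const std::function<absl::StatusOr<bool>(int64_t)>& body) {
  for (int64_t v : values) {
    absl::StatusOr<bool> keep_going = body(v);
    if (!keep_going.ok()) return keep_going.status();
    if (!*keep_going) break;  // body asked to stop early
  }
  return absl::OkStatus();
}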
@@ -75,7 +74,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { template StatusOr> Evaluate( const HloComputation& computation, - tensorflow::gtl::ArraySlice arg_literals); + absl::Span arg_literals); // Evaluates a single HLO instruction and an array of pointers to literals. // Return the evaluated result as literal if successful. @@ -87,8 +86,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // type. template StatusOr> Evaluate( - HloInstruction* instruction, - tensorflow::gtl::ArraySlice arg_literals); + HloInstruction* instruction, absl::Span arg_literals); // Evaluates a single HLO instruction with constant operands. // Returns the evaluated result as literal if successful. @@ -229,8 +227,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { } auto result = absl::make_unique(shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](absl::Span multi_index) { return unary_op(operand_literal.Get(multi_index)); })); return std::move(result); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index e3eb60a8518..626daa527b9 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -60,7 +60,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, } std::unique_ptr Evaluate( - tensorflow::gtl::ArraySlice arg_literals = {}) { + absl::Span arg_literals = {}) { if (use_bfloat16_) { // In BF16 mode, we convert all F32 type to BF16 and evaluate the module. auto type_converter = HloElementTypeConverter(F32, BF16); @@ -344,7 +344,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; result->EachCell( - [&](tensorflow::gtl::ArraySlice indices, NativeT value) { + [&](absl::Span indices, NativeT value) { std::vector rindexes = Permute(permutation, indices); EXPECT_NEAR(value, literal_clone->Get(rindexes), 0.031250); }); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index d35163ebb8d..980e3430359 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -97,7 +97,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> double GetAsDouble(const Literal& literal, - tensorflow::gtl::ArraySlice input_index) { + absl::Span input_index) { return static_cast(literal.Get(input_index)); } @@ -109,7 +109,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> double GetAsDouble(const Literal& literal, - tensorflow::gtl::ArraySlice input_index) { + absl::Span input_index) { LOG(FATAL) << "Trying to get complex literal as double: " << literal.ToString(); } @@ -980,8 +980,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); auto result = absl::make_unique(result_shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice out_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](absl::Span out_index) { std::vector from_index(out_index.begin(), out_index.end()); for (const int64 dim : reverse_dimensions) { from_index[dim] = 
result_shape.dimensions(dim) - 1 - out_index[dim]; @@ -1048,8 +1048,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window, &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data, - rhs_literal_data]( - tensorflow::gtl::ArraySlice out_index) { + rhs_literal_data](absl::Span out_index) { // Dimension number applicable for input (lhs). const int64 input_batch_dim = dnums.input_batch_dimension(); const int64 input_z_dim = dnums.input_feature_dimension(); @@ -1211,8 +1210,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } auto result = absl::make_unique(dot->shape()); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice result_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](absl::Span result_index) { ElementwiseT result_val = static_cast(0); for (int64 i = 0; i < result_index.size(); i++) { @@ -1261,9 +1260,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get({}); auto result = absl::make_unique(pad->shape()); TF_RETURN_IF_ERROR(result->Populate( - [&scalar](tensorflow::gtl::ArraySlice multi_index) { - return scalar; - })); + [&scalar](absl::Span multi_index) { return scalar; })); const Literal& evaluated_operand = parent_->GetEvaluatedLiteralFor(pad->operand(0)); @@ -1276,7 +1273,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // corresponding index of the resulting padded literal. const PaddingConfig& pad_config = pad->padding_config(); - auto func = [&](tensorflow::gtl::ArraySlice input_index) { + auto func = [&](absl::Span input_index) { for (auto i = 0; i < input_index.size(); ++i) { // Interior padding occurs logically before edge padding, so in the case // of negative edge padding elements are removed from the @@ -1427,8 +1424,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto result = absl::make_unique(map->shape()); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](absl::Span multi_index) { std::vector> arg_literals; arg_literals.reserve(operands.size()); @@ -1539,8 +1536,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return SafeLess(a, b); }); auto result_literal = absl::make_unique(keys_literal.shape()); - result_literal->PopulateR1( - tensorflow::gtl::ArraySlice(result_data)); + result_literal->PopulateR1(absl::Span(result_data)); VLOG(3) << "HandleSort result_literal: " << result_literal->ToString(); return result_literal; }; @@ -1582,7 +1578,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { HloReduceInstruction* reduce = Cast(hlo); int64 num_args = reduce->inputs().size(); bool has_tuple_output = ShapeUtil::IsTuple(reduce->shape()); - tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); + absl::Span dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); absl::InlinedVector operand_shapes; @@ -1650,7 +1646,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { for (int64 input = 0; input < num_args; ++input) { TF_RETURN_IF_ERROR(results[input]->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + [&](absl::Span multi_index) { if (!eval_status.ok()) { return init_scalars[input]; } @@ -1668,7 +1664,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { 
IsScalarAdd(function)) { CHECK_EQ(num_args, 1); double computed_result = 0; - auto func = [&](tensorflow::gtl::ArraySlice input_index) { + auto func = [&](absl::Span input_index) { computed_result += GetAsDouble(*arg_literals[0], input_index); return true; @@ -1677,8 +1673,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { arg_dim_counts, arg_dim_steps, func); return static_cast(computed_result); } - auto func = [&](tensorflow::gtl::ArraySlice input_index) - -> StatusOr { + auto func = + [&](absl::Span input_index) -> StatusOr { absl::InlinedVector arg_values(num_args); for (int64 i = 0; i < num_args; ++i) { arg_values[i] = arg_literals[i]->Get(input_index); @@ -1767,9 +1763,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Initialize result array with the init value. TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice output_index) { - return init_scalar; - })); + [&](absl::Span output_index) { return init_scalar; })); std::vector window_dimension_sizes; for (const auto& window_dimension : window.dimensions()) { @@ -1902,8 +1896,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); auto result = absl::make_unique(reduce_window->shape()); // For each resulting dimension, calculate and assign computed value. - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice output_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](absl::Span output_index) { ReturnT result_val = init_scalar; std::fill(window_index.begin(), window_index.end(), 0); @@ -2049,12 +2043,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // same storage for all invocations. // // This returns an arrayslice into memory owned by the class. - StatusOr> operator()( - tensorflow::gtl::ArraySlice update_index) { + StatusOr> operator()( + absl::Span update_index) { PropagateUpdateIndexScatterDimsToIndexVectorIndex(update_index); TF_RETURN_IF_ERROR(FetchIndexVector()); PropagateIndexVectorToInputIndex(); - return tensorflow::gtl::ArraySlice(input_index_); + return absl::Span(input_index_); } private: @@ -2063,7 +2057,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // update the dim_numbers.index_vector_dim() dimension -- that's the // dimension we iterate over in FetchIndexVector. void PropagateUpdateIndexScatterDimsToIndexVectorIndex( - tensorflow::gtl::ArraySlice update_index) { + absl::Span update_index) { int64 index_vector_index_i = 0; for (int64 i = 0, e = update_index.size(); i < e; i++) { if (!update_dim_is_scatter_dims_[i]) { @@ -2118,7 +2112,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // The index vector fetched from scatter_indices_. std::vector index_vector_; - // The result computed by this functor. operator() returns an ArraySlice + // The result computed by this functor. operator() returns a Span // into this vector. std::vector input_index_; @@ -2172,10 +2166,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // result (input_index_), mutating it in place. // // This returns an arrayslice into memory owned by the class. 
- StatusOr> operator()( - tensorflow::gtl::ArraySlice update_index) { + StatusOr> operator()( + absl::Span update_index) { PropagateUpdateIndexWindowDimsToInputIndex(update_index); - return tensorflow::gtl::ArraySlice(input_index_); + return absl::Span(input_index_); } // Returns for a given 'input_dim' the corresponding update dimension index, @@ -2188,7 +2182,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Propagates window dimensions from the update index to input_index_ by // mutating input_index_ in place. void PropagateUpdateIndexWindowDimsToInputIndex( - tensorflow::gtl::ArraySlice update_index) { + absl::Span update_index) { for (int64 i = 0, e = input_index_.size(); i < e; i++) { if (input_dim_value_to_update_index_[i] != -1) { input_index_[i] = update_index[input_dim_value_to_update_index_[i]]; @@ -2204,7 +2198,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // PropagateUpdateIndexWindowDimsToInputIndex. std::vector input_dim_value_to_update_index_; - // The result computed by this functor. operator() returns an ArraySlice + // The result computed by this functor. operator() returns a Span // into this vector. std::vector input_index_; }; @@ -2247,12 +2241,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::unique_ptr result = operand.CloneToUnique(); HloEvaluator embedded_evaluator; auto scatter_inner_loop_body = - [&](tensorflow::gtl::ArraySlice update_window_index, - tensorflow::gtl::ArraySlice input_scatter_index, - tensorflow::gtl::ArraySlice update_scatter_index) - -> StatusOr { + [&](absl::Span update_window_index, + absl::Span input_scatter_index, + absl::Span update_scatter_index) -> StatusOr { TF_ASSIGN_OR_RETURN( - tensorflow::gtl::ArraySlice input_window_index, + absl::Span input_window_index, update_window_index_to_input_index(update_window_index)); for (int i = 0, e = update_index.size(); i < e; i++) { update_index[i] = update_scatter_index[i] + update_window_index[i]; @@ -2301,14 +2294,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { }; auto scatter_outer_loop_body = - [&](tensorflow::gtl::ArraySlice update_scatter_index) - -> StatusOr { + [&](absl::Span update_scatter_index) -> StatusOr { TF_ASSIGN_OR_RETURN( - tensorflow::gtl::ArraySlice input_scatter_index, + absl::Span input_scatter_index, update_scatter_index_to_input_index(update_scatter_index)); TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( updates_shape, window_indices_iteration_space, - [&](tensorflow::gtl::ArraySlice update_window_index) { + [&](absl::Span update_window_index) { return scatter_inner_loop_body( update_window_index, input_scatter_index, update_scatter_index); })); @@ -2336,7 +2328,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const int64 rank = ShapeUtil::Rank(operand->shape()); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto func = [&](tensorflow::gtl::ArraySlice out_index) { + auto func = [&](absl::Span out_index) { DimensionVector operand_index(rank); for (int64 i = 0; i < rank; ++i) { operand_index[i] = @@ -2607,7 +2599,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // bound, call `f` with the base index. 
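Several hunks above rewrite Populate-style lambdas so their multi-dimensional index parameter is passed as an absl::Span. A simplified row-major index walker showing how such a callback is typically driven; it is illustrative only and not Literal::Populate:

#include <cstdint>
#include <functional>
#include <vector>

#include "absl/types/span.h"

// Visits every index of a dense array in row-major order, handing the current
// multi-index to the callback as a span. Illustrative only.
void ForEachIndex(absl::Span<const int64_t> dims,
                  const std::function<void(absl::Span<const int64_t>)>& fn) {
  for (int64_t d : dims) {
    if (d == 0) return;  // empty array: nothing to visit
  }
  std::vector<int64_t> index(dims.size(), 0);
  while (true) {
    fn(index);  // std::vector converts implicitly to absl::Span<const int64_t>
    int64_t d = static_cast<int64_t>(dims.size()) - 1;
    for (; d >= 0; --d) {
      if (++index[d] < dims[d]) break;  // advance with carry into the next dimension
      index[d] = 0;
    }
    if (d < 0) return;  // wrapped past the most-significant dimension
  }
}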
static void IterateThroughWindow( const Shape& window_shape, const Window& window, const Shape& base_shape, - const tensorflow::gtl::ArraySlice& window_count_index, + const absl::Span& window_count_index, const std::function&)>& f) { const int64 rank = ShapeUtil::Rank(base_shape); DimensionVector window_index(rank); @@ -2647,8 +2639,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::vector operand_indices(start.size()); auto result = absl::make_unique(result_shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](absl::Span multi_index) { for (int64 i = 0; i < operand_indices.size(); ++i) { CHECK_GE(multi_index[i] + start[i], 0); operand_indices[i] = multi_index[i] + start[i]; @@ -2679,7 +2671,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } std::vector result_index(rank, 0); - auto func = [&](tensorflow::gtl::ArraySlice update_index) { + auto func = [&](absl::Span update_index) { std::transform(update_index.begin(), update_index.end(), start.begin(), result_index.begin(), std::plus()); result->Set(result_index, @@ -2733,8 +2725,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto result = absl::make_unique(shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](absl::Span multi_index) { return ConvertBinaryFunction(binary_op)( lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); @@ -2770,8 +2762,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto result = absl::make_unique(shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](absl::Span multi_index) { return ternary_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index), ehs_literal.Get(multi_index)); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index b747a4ea5f9..bd0b6af10d6 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -113,7 +113,7 @@ StatusOr> HloInstruction::CreateFromProto( std::vector fft_length(proto.fft_length().begin(), proto.fft_length().end()); instruction = CreateFft(proto.shape(), operands(0), proto.fft_type(), - tensorflow::gtl::ArraySlice(fft_length)); + absl::Span(fft_length)); break; } case HloOpcode::kSend: @@ -519,13 +519,13 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateRng( const Shape& shape, RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters) { + absl::Span parameters) { return absl::make_unique(shape, distribution, parameters); } /* static */ std::unique_ptr HloInstruction::CreateNary( const Shape& shape, HloOpcode opcode, - tensorflow::gtl::ArraySlice operands) { + absl::Span operands) { if (opcode == HloOpcode::kCopy) { // It is impossible to copy an opaque shape, we don't know how big it is. 
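Where the patch wraps a local std::vector explicitly before passing it on, as in the CreateFft call above, the absl helpers MakeSpan and MakeConstSpan do the same job with the element type deduced. A small sketch of those helpers; apart from the absl functions themselves, the names below are hypothetical:

#include <cstdint>
#include <vector>

#include "absl/types/span.h"

void Example() {
  std::vector<int64_t> fft_length = {16, 16};

  // Deduces absl::Span<int64_t>: a mutable view over the vector's storage.
  absl::Span<int64_t> mutable_view = absl::MakeSpan(fft_length);
  mutable_view[0] = 32;

  // Deduces absl::Span<const int64_t>: a read-only view over the same storage.
  absl::Span<const int64_t> const_view = absl::MakeConstSpan(fft_length);
  (void)const_view;
}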
CHECK(!ShapeUtil::IsOpaque(shape)); @@ -627,13 +627,13 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateVariadic( const Shape& shape, HloOpcode opcode, - tensorflow::gtl::ArraySlice operands) { + absl::Span operands) { CHECK_EQ(HloOpcode::kTuple, opcode); return CreateNary(shape, opcode, operands); } /* static */ std::unique_ptr HloInstruction::CreateMap( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* map_computation) { return absl::make_unique(shape, operands, map_computation); } @@ -648,7 +648,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateFft( const Shape& shape, HloInstruction* operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length) { + absl::Span fft_length) { return absl::make_unique(shape, operand, fft_type, fft_length); } @@ -692,7 +692,7 @@ HloInstruction::CreateReducePrecision(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateCrossReplicaSum( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* reduce_computation, const std::vector& replica_groups, absl::string_view barrier, const absl::optional& all_reduce_id) { @@ -702,7 +702,7 @@ HloInstruction::CreateCrossReplicaSum( } /* static */ std::unique_ptr HloInstruction::CreateAllToAll( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, const std::vector& replica_groups) { return absl::make_unique(shape, operands, replica_groups); @@ -764,12 +764,12 @@ HloInstruction::CreateCollectivePermute( /* static */ std::unique_ptr HloInstruction::CreateReverse( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions) { + absl::Span dimensions) { return absl::make_unique(shape, operand, dimensions); } /* static */ std::unique_ptr HloInstruction::CreateAfterAll( - tensorflow::gtl::ArraySlice operands) { + absl::Span operands) { CHECK(!operands.empty()); auto instruction = absl::WrapUnique( new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape())); @@ -815,16 +815,15 @@ HloInstruction::CreateCollectivePermute( /* static */ std::unique_ptr HloInstruction::CreateSlice( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides) { + absl::Span start_indices, + absl::Span limit_indices, absl::Span strides) { return absl::make_unique(shape, operand, start_indices, limit_indices, strides); } /* static */ std::unique_ptr HloInstruction::CreateDynamicSlice( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, - tensorflow::gtl::ArraySlice slice_sizes) { + absl::Span slice_sizes) { return absl::make_unique( shape, operand, start_indices, slice_sizes); } @@ -843,7 +842,7 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, } /* static */ std::unique_ptr HloInstruction::CreateConcatenate( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, int64 dimension) { return absl::make_unique(shape, operands, dimension); @@ -868,7 +867,7 @@ HloInstruction::CreateBitcastConvert(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateReduce( const Shape& shape, HloInstruction* operand, HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + 
absl::Span dimensions_to_reduce, HloComputation* reduce_computation) { auto instruction = absl::WrapUnique(new HloReduceInstruction( shape, {operand, init_value}, dimensions_to_reduce, reduce_computation)); @@ -876,9 +875,9 @@ HloInstruction::CreateBitcastConvert(const Shape& shape, } /* static */ std::unique_ptr HloInstruction::CreateReduce( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice init_values, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + const Shape& shape, absl::Span operands, + absl::Span init_values, + absl::Span dimensions_to_reduce, HloComputation* reduce_computation) { std::vector all_args; all_args.reserve(operands.size() * 2); @@ -936,7 +935,7 @@ HloInstruction::CreateSelectAndScatter( /* static */ std::unique_ptr HloInstruction::CreateBroadcast( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return absl::make_unique(shape, operand, broadcast_dimensions); } @@ -1014,7 +1013,7 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr HloInstruction::CreateTranspose( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions) { + absl::Span dimensions) { return absl::make_unique(shape, operand, dimensions); } @@ -1032,7 +1031,7 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, - tensorflow::gtl::ArraySlice operands, + absl::Span operands, HloComputation* fusion_computation) { return absl::make_unique(shape, fusion_kind, operands, fusion_computation); @@ -1090,7 +1089,7 @@ bool HloInstruction::HasSideEffect() const { } /* static */ std::unique_ptr HloInstruction::CreateCall( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* computation) { std::unique_ptr instruction = absl::WrapUnique(new HloInstruction(HloOpcode::kCall, shape)); @@ -1102,14 +1101,14 @@ bool HloInstruction::HasSideEffect() const { } /* static */ std::unique_ptr HloInstruction::CreateCustomCall( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, absl::string_view custom_call_target) { return absl::make_unique(shape, operands, custom_call_target); } /* static */ std::unique_ptr HloInstruction::CreateTuple( - tensorflow::gtl::ArraySlice elements) { + absl::Span elements) { std::vector element_shapes; for (auto element : elements) { element_shapes.push_back(element->shape()); @@ -1121,7 +1120,7 @@ bool HloInstruction::HasSideEffect() const { /* static */ std::unique_ptr HloInstruction::CreateGather( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice slice_sizes) { + absl::Span slice_sizes) { return absl::make_unique( shape, operand, start_indices, gather_dim_numbers, slice_sizes); } @@ -1149,8 +1148,7 @@ bool HloInstruction::HasSideEffect() const { } std::unique_ptr HloInstruction::CloneWithNewOperands( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { VLOG(3) << "CloneWithNewOperands:\n " << ToString(); VLOG(3) << " new operands:"; @@ -1501,7 +1499,7 @@ void HloInstruction::AppendOperand(HloInstruction* operand) { } void HloInstruction::RemoveOperandsAtAscendingIndices( - tensorflow::gtl::ArraySlice ascending_indices) { + 
absl::Span ascending_indices) { if (ascending_indices.empty()) { return; } @@ -1997,7 +1995,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap( const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const { string operands; - tensorflow::gtl::ArraySlice slice(operands_); + absl::Span slice(operands_); const int64 kMaxOperandsToShowIfCompact = 4; if (options.compact_operands() && slice.size() > kMaxOperandsToShowIfCompact) { @@ -3310,7 +3308,7 @@ const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const { return Cast(this)->gather_dimension_numbers(); } -tensorflow::gtl::ArraySlice HloInstruction::gather_slice_sizes() const { +absl::Span HloInstruction::gather_slice_sizes() const { return Cast(this)->gather_slice_sizes(); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index f3fd287d881..88cb5d8acfd 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -365,7 +365,7 @@ class HloInstruction { // random numbers from a given distribution. static std::unique_ptr CreateRng( const Shape& shape, RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters); + absl::Span parameters); // Creates a unary instruction (one operand). // Precondition: opcode must be a legitimate unary operation. @@ -392,13 +392,13 @@ class HloInstruction { // Precondition: opcode must be a legitimate variadic operation. static std::unique_ptr CreateVariadic( const Shape& shape, HloOpcode opcode, - tensorflow::gtl::ArraySlice operands); + absl::Span operands); // Creates a map instruction, where the computation (given by the handle) is // applied element-wise to every element in operands (across the operands, // at a given index) static std::unique_ptr CreateMap( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* map_computation); // Creates a convolution op, where rhs is the convolutional filter @@ -412,7 +412,7 @@ class HloInstruction { // Creates an FFT op, of the type indicated by fft_type. static std::unique_ptr CreateFft( const Shape& shape, HloInstruction* operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); + absl::Span fft_length); // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch // dimensions specified in 'dimension_numbers'. @@ -449,7 +449,7 @@ class HloInstruction { // // TODO(b/79737069): Rename this to AllReduce. static std::unique_ptr CreateCrossReplicaSum( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* reduce_computation, const std::vector& replica_groups, absl::string_view barrier, const absl::optional& all_reduce_id); @@ -468,7 +468,7 @@ class HloInstruction { // be concatenated in the order of 1, 2, 3; another Alltoall will be applied // within replica 4, 5, 0, and the concatenation order is 4, 5, 0. static std::unique_ptr CreateAllToAll( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, const std::vector& replica_groups); // Creates a communitation instructions that permutes data cross replicas. @@ -536,17 +536,15 @@ class HloInstruction { // start/limit indices. 
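For the slice factory declared just below, which takes start, limit, and stride spans, the output extent in each dimension is the number of selected indices, i.e. the ceiling of (limit - start) / stride. The sketch below shows that bookkeeping as an illustration of the intended semantics, not a quote of the ShapeInference code:

#include <cstdint>
#include <vector>

#include "absl/types/span.h"

// Per-dimension output sizes for a strided slice, assuming limit >= start and
// stride > 0. Illustrative only.
std::vector<int64_t> SliceOutputDims(absl::Span<const int64_t> start_indices,
                                     absl::Span<const int64_t> limit_indices,
                                     absl::Span<const int64_t> strides) {
  std::vector<int64_t> dims;
  for (size_t i = 0; i < start_indices.size(); ++i) {
    int64_t extent = limit_indices[i] - start_indices[i];
    dims.push_back((extent + strides[i] - 1) / strides[i]);  // ceiling division
  }
  return dims;
}

// Example: start {0, 2}, limit {10, 9}, strides {2, 3} gives dims {5, 3}.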
static std::unique_ptr CreateSlice( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); + absl::Span start_indices, + absl::Span limit_indices, absl::Span strides); // Creates a slice instruction, where the first operand is sliced by // start indices specified in the second operand, and by size specified in // 'slice_sizes'. static std::unique_ptr CreateDynamicSlice( const Shape& shape, HloInstruction* operand, - HloInstruction* start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + HloInstruction* start_indices, absl::Span slice_sizes); // Creates a dynamic update slice instruction, which updates a slice // of 'operand' with 'update' and 'start_indices'. @@ -557,7 +555,7 @@ class HloInstruction { // Creates a concatenate instruction, where the operands are concatenated on // the provided dimension. static std::unique_ptr CreateConcatenate( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, int64 dimension); // Creates a reduce instruction, where the computation (given by the handle) @@ -569,7 +567,7 @@ class HloInstruction { // f(f(init, value0), value1), ...) static std::unique_ptr CreateReduce( const Shape& shape, HloInstruction* operand, HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + absl::Span dimensions_to_reduce, HloComputation* reduce_computation); // A more general, multiple-argument version of the above. @@ -584,9 +582,9 @@ class HloInstruction { // ... // TODO(b/112040122): Add support to this in HLO passes and in backends. static std::unique_ptr CreateReduce( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice init_values, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + const Shape& shape, absl::Span operands, + absl::Span init_values, + absl::Span dimensions_to_reduce, HloComputation* reduce_computation); // Creates a reduce-window instruction, where the computation (given @@ -623,7 +621,7 @@ class HloInstruction { // Creates a broadcast instruction. static std::unique_ptr CreateBroadcast( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); // Creates a sequence of instructions that performs an explicit broadcast of // the operand to the target shape. @@ -653,7 +651,7 @@ class HloInstruction { // Creates a transpose instruction which permutes the operand dimensions. static std::unique_ptr CreateTranspose( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); // Creates a sort op, with a keys operand, and an optional values operand. static std::unique_ptr CreateSort( @@ -679,7 +677,7 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); static std::unique_ptr CreateScatter( const Shape& shape, HloInstruction* operand, @@ -703,37 +701,37 @@ class HloInstruction { static std::unique_ptr CreateFusion( const Shape& shape, FusionKind fusion_kind, - tensorflow::gtl::ArraySlice operands, + absl::Span operands, HloComputation* fusion_computation); // Creates a call instruction that applies the given computation on the given // operands. "shape" is the resultant shape. 
static std::unique_ptr CreateCall( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* computation); // Creates a custom call instruction that applies the given custom call target // to the given operands. "shape" is the resultant shape. static std::unique_ptr CreateCustomCall( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, absl::string_view custom_call_target); // Creates a tuple instruction with the given elements. This is a convenience // wrapper around CreateVariadic. static std::unique_ptr CreateTuple( - tensorflow::gtl::ArraySlice elements); + absl::Span elements); // Creates a reverse instruction, which reverses the order of the elements // in the specified dimensions. static std::unique_ptr CreateReverse( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); // Creates a Afterall instruction used for joining or creating new values of // token type which thread through side-effecting operations. Operands must // all be tokens, and there must be at least one operand. static std::unique_ptr CreateAfterAll( - tensorflow::gtl::ArraySlice operands); + absl::Span operands); // Creates an AfterAll instruction which creates a token type out of thin air // (no operands). This is a separate method from CreateAfterAll to facility @@ -1124,8 +1122,7 @@ class HloInstruction { // Clones the HLO instruction as above but with new shape and operands. std::unique_ptr CloneWithNewOperands( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context = nullptr) const; // Returns the computations this instruction directly calls (if any). @@ -1505,7 +1502,7 @@ class HloInstruction { // Delegates to HloGatherInstruction::gather_dimension_numbers. const GatherDimensionNumbers& gather_dimension_numbers() const; // Delegates to HloGatherInstruction::gather_slice_sizes. - tensorflow::gtl::ArraySlice gather_slice_sizes() const; + absl::Span gather_slice_sizes() const; // Delegates to HloScatterInstruction::scatter_dimension_numbers(). const ScatterDimensionNumbers& scatter_dimension_numbers() const; @@ -1531,7 +1528,7 @@ class HloInstruction { // Removes a list of operands with the given indices in ascending order. void RemoveOperandsAtAscendingIndices( - tensorflow::gtl::ArraySlice ascending_indices); + absl::Span ascending_indices); void AppendComputation(HloComputation* computation) { called_computations_.push_back(computation); @@ -1561,8 +1558,7 @@ class HloInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. virtual std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { // TODO(b/80131774): This should be pure virtual. LOG(FATAL) << "Unimplemented method."; @@ -1608,7 +1604,7 @@ class HloInstruction { // Creates an n-ary elementwise operation. static std::unique_ptr CreateNary( const Shape& shape, HloOpcode opcode, - tensorflow::gtl::ArraySlice operands); + absl::Span operands); // Adds a user for this instruction. 
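The factory and clone methods above take their operand lists as spans of pointers. With absl::Span the const belongs on the element type, so a read-only view of a pointer array has elements of type T* const while the pointed-to objects stay mutable. A hedged sketch with a hypothetical Node type standing in for HloInstruction:

#include <vector>

#include "absl/types/span.h"

struct Node {};  // hypothetical stand-in for HloInstruction

// The span of pointers is read-only, but the Nodes themselves are not const.
int CountNonNull(absl::Span<Node* const> nodes) {
  int non_null = 0;
  for (Node* node : nodes) {
    if (node != nullptr) ++non_null;
  }
  return non_null;
}

void Example() {
  Node a, b;
  std::vector<Node*> operands = {&a, nullptr, &b};
  int n = CountNonNull(operands);  // std::vector<Node*> converts to Span<Node* const>
  (void)n;
}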
void AddUser(HloInstruction* user); diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index e1c884d856d..68719537559 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -91,8 +91,7 @@ HloBatchNormTrainingInstruction::HloBatchNormTrainingInstruction( std::unique_ptr HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 3); return absl::make_unique( @@ -113,8 +112,7 @@ HloBatchNormInferenceInstruction::HloBatchNormInferenceInstruction( std::unique_ptr HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 5); return absl::make_unique( @@ -135,8 +133,7 @@ HloBatchNormGradInstruction::HloBatchNormGradInstruction( std::unique_ptr HloBatchNormGradInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 5); return absl::make_unique( @@ -144,9 +141,9 @@ HloBatchNormGradInstruction::CloneWithNewOperandsImpl( new_operands[4], epsilon(), feature_index()); } -HloFftInstruction::HloFftInstruction( - const Shape& shape, HloInstruction* operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length) +HloFftInstruction::HloFftInstruction(const Shape& shape, + HloInstruction* operand, FftType fft_type, + absl::Span fft_length) : HloInstruction(HloOpcode::kFft, shape), fft_type_(fft_type) { fft_length_.assign(fft_length.begin(), fft_length.end()); AppendOperand(operand); @@ -177,8 +174,7 @@ bool HloFftInstruction::IdenticalSlowPath( } std::unique_ptr HloFftInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return absl::make_unique(shape, new_operands[0], fft_type_, @@ -232,8 +228,7 @@ HloSendInstruction::HloSendInstruction(HloInstruction* operand, } std::unique_ptr HloSendInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); return absl::make_unique( @@ -250,8 +245,7 @@ HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand, std::unique_ptr HloSendDoneInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return absl::make_unique( @@ -271,8 +265,7 @@ HloRecvInstruction::HloRecvInstruction(const Shape& shape, } std::unique_ptr HloRecvInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return absl::make_unique( @@ -293,8 +286,7 @@ HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand, std::unique_ptr HloRecvDoneInstruction::CloneWithNewOperandsImpl( - 
const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return absl::make_unique( @@ -303,7 +295,7 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl( HloCollectiveInstruction::HloCollectiveInstruction( HloOpcode opcode, const Shape& shape, - tensorflow::gtl::ArraySlice operands, + absl::Span operands, const std::vector& replica_groups) : HloInstruction(opcode, shape), replica_groups_(replica_groups) { for (auto operand : operands) { @@ -344,7 +336,7 @@ bool HloCollectiveInstruction::IdenticalSlowPath( } HloAllReduceInstruction::HloAllReduceInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* reduce_computation, const std::vector& replica_groups, absl::string_view barrier, const absl::optional& all_reduce_id) @@ -392,8 +384,7 @@ bool HloAllReduceInstruction::IdenticalSlowPath( std::unique_ptr HloAllReduceInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* /*context*/) const { return absl::make_unique( shape, new_operands, to_apply(), replica_groups(), @@ -401,15 +392,14 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl( } HloAllToAllInstruction::HloAllToAllInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, const std::vector& replica_groups) : HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands, replica_groups) {} std::unique_ptr HloAllToAllInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* /*context*/) const { return absl::make_unique(shape, new_operands, replica_groups()); @@ -459,16 +449,15 @@ bool HloCollectivePermuteInstruction::IdenticalSlowPath( std::unique_ptr HloCollectivePermuteInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* /*context*/) const { return absl::make_unique( shape, new_operands[0], source_target_pairs()); } -HloReverseInstruction::HloReverseInstruction( - const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions) +HloReverseInstruction::HloReverseInstruction(const Shape& shape, + HloInstruction* operand, + absl::Span dimensions) : HloInstruction(HloOpcode::kReverse, shape), dimensions_(dimensions.begin(), dimensions.end()) { AppendOperand(operand); @@ -496,8 +485,7 @@ bool HloReverseInstruction::IdenticalSlowPath( } std::unique_ptr HloReverseInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return absl::make_unique(shape, new_operands[0], @@ -505,7 +493,7 @@ std::unique_ptr HloReverseInstruction::CloneWithNewOperandsImpl( } HloConcatenateInstruction::HloConcatenateInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, int64 dimension) : HloInstruction(HloOpcode::kConcatenate, shape), dimensions_({dimension}) { for (auto operand : operands) { @@ -537,16 +525,15 @@ bool HloConcatenateInstruction::IdenticalSlowPath( std::unique_ptr HloConcatenateInstruction::CloneWithNewOperandsImpl( - const 
Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { return absl::make_unique(shape, new_operands, dimensions(0)); } HloReduceInstruction::HloReduceInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice args, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + const Shape& shape, absl::Span args, + absl::Span dimensions_to_reduce, HloComputation* reduce_computation) : HloInstruction(HloOpcode::kReduce, shape), dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) { @@ -581,8 +568,7 @@ bool HloReduceInstruction::IdenticalSlowPath( } std::unique_ptr HloReduceInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size() % 2, 0); return absl::make_unique(shape, new_operands, @@ -621,8 +607,7 @@ bool HloSortInstruction::IdenticalSlowPath( } std::unique_ptr HloSortInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { HloInstruction* keys = new_operands[0]; HloInstruction* values = new_operands.size() == 2 ? new_operands[1] : nullptr; @@ -632,7 +617,7 @@ std::unique_ptr HloSortInstruction::CloneWithNewOperandsImpl( HloTransposeInstruction::HloTransposeInstruction( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions) + absl::Span dimensions) : HloInstruction(HloOpcode::kTranspose, shape), dimensions_(dimensions.begin(), dimensions.end()) { CHECK_EQ(shape.dimensions().size(), dimensions.size()); @@ -676,8 +661,7 @@ bool HloTransposeInstruction::IdenticalSlowPath( std::unique_ptr HloTransposeInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return absl::make_unique(shape, new_operands[0], @@ -686,7 +670,7 @@ HloTransposeInstruction::CloneWithNewOperandsImpl( HloBroadcastInstruction::HloBroadcastInstruction( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice broadcast_dimension) + absl::Span broadcast_dimension) : HloInstruction(HloOpcode::kBroadcast, shape), dimensions_(broadcast_dimension.begin(), broadcast_dimension.end()) { AppendOperand(operand); @@ -715,17 +699,16 @@ bool HloBroadcastInstruction::IdenticalSlowPath( std::unique_ptr HloBroadcastInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return absl::make_unique(shape, new_operands[0], dimensions()); } -HloMapInstruction::HloMapInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloComputation* map_computation) +HloMapInstruction::HloMapInstruction(const Shape& shape, + absl::Span operands, + HloComputation* map_computation) : HloInstruction(HloOpcode::kMap, shape) { for (auto operand : operands) { AppendOperand(operand); @@ -774,17 +757,16 @@ bool HloMapInstruction::IdenticalSlowPath( } std::unique_ptr HloMapInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { return absl::make_unique(shape, new_operands, 
to_apply()); } -HloSliceInstruction::HloSliceInstruction( - const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides) +HloSliceInstruction::HloSliceInstruction(const Shape& shape, + HloInstruction* operand, + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) : HloInstruction(HloOpcode::kSlice, shape), slice_starts_(start_indices.begin(), start_indices.end()), slice_limits_(limit_indices.begin(), limit_indices.end()), @@ -835,8 +817,7 @@ bool HloSliceInstruction::IdenticalSlowPath( } std::unique_ptr HloSliceInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return absl::make_unique( @@ -889,8 +870,7 @@ bool HloConstantInstruction::IdenticalSlowPath( std::unique_ptr HloConstantInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { return absl::make_unique(literal_->CloneToUnique()); } @@ -947,8 +927,7 @@ bool HloTraceInstruction::IdenticalSlowPath( } std::unique_ptr HloTraceInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode()); } @@ -966,7 +945,7 @@ HloFusionInstruction::HloFusionInstruction(const Shape& shape, HloFusionInstruction::HloFusionInstruction( const Shape& shape, FusionKind fusion_kind, - tensorflow::gtl::ArraySlice operands, + absl::Span operands, HloComputation* fusion_computation) : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) { for (auto operand : operands) { @@ -1373,8 +1352,7 @@ bool HloFusionInstruction::IdenticalSlowPath( } std::unique_ptr HloFusionInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { HloModule* module = context != nullptr ? 
context->module() : GetModule(); HloComputation* new_fused_computation = nullptr; @@ -1412,7 +1390,7 @@ Status HloFusionInstruction::DeduplicateFusionOperands() { HloRngInstruction::HloRngInstruction( const Shape& shape, RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters) + absl::Span parameters) : HloInstruction(HloOpcode::kRng, shape), distribution_(distribution) { for (HloInstruction* param : parameters) { AppendOperand(param); @@ -1443,8 +1421,7 @@ bool HloRngInstruction::IdenticalSlowPath( } std::unique_ptr HloRngInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { return absl::make_unique(shape, distribution_, new_operands); @@ -1480,8 +1457,7 @@ bool HloParameterInstruction::IdenticalSlowPath( std::unique_ptr HloParameterInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { return absl::make_unique(parameter_number_, shape, name()); @@ -1516,8 +1492,7 @@ bool HloGetTupleElementInstruction::IdenticalSlowPath( std::unique_ptr HloGetTupleElementInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return absl::make_unique( @@ -1559,8 +1534,7 @@ bool HloReducePrecisionInstruction::IdenticalSlowPath( std::unique_ptr HloReducePrecisionInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return absl::make_unique( @@ -1600,8 +1574,7 @@ bool HloInfeedInstruction::IdenticalSlowPath( } std::unique_ptr HloInfeedInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return absl::make_unique( @@ -1646,8 +1619,7 @@ bool HloOutfeedInstruction::IdenticalSlowPath( } std::unique_ptr HloOutfeedInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); return absl::make_unique( @@ -1717,8 +1689,7 @@ bool HloConvolutionInstruction::IdenticalSlowPath( std::unique_ptr HloConvolutionInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); return absl::make_unique( @@ -1762,8 +1733,7 @@ bool HloReduceWindowInstruction::IdenticalSlowPath( std::unique_ptr HloReduceWindowInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); return absl::make_unique( @@ -1811,8 +1781,7 @@ bool HloSelectAndScatterInstruction::IdenticalSlowPath( std::unique_ptr HloSelectAndScatterInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { 
CHECK_EQ(new_operands.size(), 3); return absl::make_unique( @@ -1821,7 +1790,7 @@ HloSelectAndScatterInstruction::CloneWithNewOperandsImpl( } HloCustomCallInstruction::HloCustomCallInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, absl::string_view custom_call_target) : HloInstruction(HloOpcode::kCustomCall, shape), custom_call_target_(custom_call_target.begin(), @@ -1887,8 +1856,7 @@ bool HloCustomCallInstruction::IdenticalSlowPath( std::unique_ptr HloCustomCallInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { auto cloned = absl::make_unique( shape, new_operands, custom_call_target()); @@ -1931,8 +1899,7 @@ bool HloPadInstruction::IdenticalSlowPath( } std::unique_ptr HloPadInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); return absl::make_unique(shape, new_operands[0], @@ -1941,7 +1908,7 @@ std::unique_ptr HloPadInstruction::CloneWithNewOperandsImpl( HloDynamicSliceInstruction::HloDynamicSliceInstruction( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, - tensorflow::gtl::ArraySlice slice_sizes) + absl::Span slice_sizes) : HloInstruction(HloOpcode::kDynamicSlice, shape), dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) { AppendOperand(operand); @@ -1971,8 +1938,7 @@ bool HloDynamicSliceInstruction::IdenticalSlowPath( std::unique_ptr HloDynamicSliceInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); return absl::make_unique( @@ -1982,7 +1948,7 @@ HloDynamicSliceInstruction::CloneWithNewOperandsImpl( HloGatherInstruction::HloGatherInstruction( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice slice_sizes) + absl::Span slice_sizes) : HloInstruction(HloOpcode::kGather, shape) { AppendOperand(operand); AppendOperand(start_indices); @@ -2011,10 +1977,9 @@ string HloGatherInstruction::GatherDimensionNumbersToString() const { } /* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers( - tensorflow::gtl::ArraySlice offset_dims, - tensorflow::gtl::ArraySlice collapsed_slice_dims, - tensorflow::gtl::ArraySlice start_index_map, - int64 index_vector_dim) { + absl::Span offset_dims, + absl::Span collapsed_slice_dims, + absl::Span start_index_map, int64 index_vector_dim) { GatherDimensionNumbers gather_dim_numbers; for (int64 output_window_dim : offset_dims) { gather_dim_numbers.add_offset_dims(output_window_dim); @@ -2057,8 +2022,7 @@ bool HloGatherInstruction::IdenticalSlowPath( } std::unique_ptr HloGatherInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); return absl::make_unique( @@ -2102,9 +2066,9 @@ string HloScatterInstruction::ScatterDimensionNumbersToString() const { /* static */ ScatterDimensionNumbers HloScatterInstruction::MakeScatterDimNumbers( - tensorflow::gtl::ArraySlice update_window_dims, - tensorflow::gtl::ArraySlice 
inserted_window_dims, - tensorflow::gtl::ArraySlice scatter_dims_to_operand_dims, + absl::Span update_window_dims, + absl::Span inserted_window_dims, + absl::Span scatter_dims_to_operand_dims, int64 index_vector_dim) { ScatterDimensionNumbers scatter_dim_numbers; for (int64 update_window_dim : update_window_dims) { @@ -2144,8 +2108,7 @@ bool HloScatterInstruction::IdenticalSlowPath( } std::unique_ptr HloScatterInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 3); return absl::make_unique( @@ -2177,8 +2140,7 @@ bool HloIotaInstruction::IdenticalSlowPath( } std::unique_ptr HloIotaInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { return absl::make_unique(shape, iota_dimension()); } diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 4fe5144aca4..45a648bbe4c 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -67,8 +67,7 @@ class HloBatchNormTrainingInstruction : public HloBatchNormInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; @@ -82,8 +81,7 @@ class HloBatchNormInferenceInstruction : public HloBatchNormInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; @@ -97,8 +95,7 @@ class HloBatchNormGradInstruction : public HloBatchNormInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; @@ -106,7 +103,7 @@ class HloFftInstruction : public HloInstruction { public: explicit HloFftInstruction(const Shape& shape, HloInstruction* operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); + absl::Span fft_length); FftType fft_type() const { return fft_type_; } const std::vector& fft_length() const { return fft_length_; } @@ -124,8 +121,7 @@ class HloFftInstruction : public HloInstruction { // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // Describes FFT type for an FFT instruction. @@ -174,8 +170,7 @@ class HloSendInstruction : public HloSendRecvInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; @@ -187,8 +182,7 @@ class HloSendDoneInstruction : public HloSendRecvInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; @@ -200,8 +194,7 @@ class HloRecvInstruction : public HloSendRecvInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; @@ -213,8 +206,7 @@ class HloRecvDoneInstruction : public HloSendRecvInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; @@ -227,7 +219,7 @@ class HloCollectiveInstruction : public HloInstruction { protected: explicit HloCollectiveInstruction( HloOpcode opcode, const Shape& shape, - tensorflow::gtl::ArraySlice operands, + absl::Span operands, const std::vector& replica_groups); HloInstructionProto ToProto() const override; @@ -245,7 +237,7 @@ class HloCollectiveInstruction : public HloInstruction { class HloAllReduceInstruction : public HloCollectiveInstruction { public: explicit HloAllReduceInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* reduce_computation, const std::vector& replica_groups, absl::string_view barrier, const absl::optional& all_reduce_id); @@ -274,8 +266,7 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // The string representation of the barrier config used for CrossReplicaSum. @@ -290,14 +281,13 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { class HloAllToAllInstruction : public HloCollectiveInstruction { public: explicit HloAllToAllInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, const std::vector& replica_groups); private: // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; @@ -324,8 +314,7 @@ class HloCollectivePermuteInstruction : public HloInstruction { // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; const std::vector> source_target_pairs_; @@ -334,7 +323,7 @@ class HloCollectivePermuteInstruction : public HloInstruction { class HloReverseInstruction : public HloInstruction { public: explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -350,8 +339,7 @@ class HloReverseInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -359,9 +347,9 @@ class HloReverseInstruction : public HloInstruction { class HloConcatenateInstruction : public HloInstruction { public: - explicit HloConcatenateInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - int64 dimension); + explicit HloConcatenateInstruction(const Shape& shape, + absl::Span operands, + int64 dimension); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -379,8 +367,7 @@ class HloConcatenateInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -388,10 +375,10 @@ class HloConcatenateInstruction : public HloInstruction { class HloReduceInstruction : public HloInstruction { public: - explicit HloReduceInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice args, - tensorflow::gtl::ArraySlice dimensions_to_reduce, - HloComputation* reduce_computation); + explicit HloReduceInstruction(const Shape& shape, + absl::Span args, + absl::Span dimensions_to_reduce, + HloComputation* reduce_computation); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -403,12 +390,12 @@ class HloReduceInstruction : public HloInstruction { int64 input_count() const { return operand_count() / 2; } // Returns the input tensors to be reduced. - tensorflow::gtl::ArraySlice inputs() const { + absl::Span inputs() const { return absl::MakeSpan(operands()).subspan(0, input_count()); } // Returns the init values of the reduction. - tensorflow::gtl::ArraySlice init_values() const { + absl::Span init_values() const { return absl::MakeSpan(operands()).subspan(input_count(), operand_count()); } @@ -421,8 +408,7 @@ class HloReduceInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -450,8 +436,7 @@ class HloSortInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -459,9 +444,8 @@ class HloSortInstruction : public HloInstruction { class HloTransposeInstruction : public HloInstruction { public: - explicit HloTransposeInstruction( - const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions); + explicit HloTransposeInstruction(const Shape& shape, HloInstruction* operand, + absl::Span dimensions); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -479,8 +463,7 @@ class HloTransposeInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -488,9 +471,8 @@ class HloTransposeInstruction : public HloInstruction { class HloBroadcastInstruction : public HloInstruction { public: - explicit HloBroadcastInstruction( - const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice broadcast_dimension); + explicit HloBroadcastInstruction(const Shape& shape, HloInstruction* operand, + absl::Span broadcast_dimension); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -506,8 +488,7 @@ class HloBroadcastInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -515,9 +496,9 @@ class HloBroadcastInstruction : public HloInstruction { class HloMapInstruction : public HloInstruction { public: - explicit HloMapInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloComputation* map_computation); + explicit HloMapInstruction(const Shape& shape, + absl::Span operands, + HloComputation* map_computation); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -535,8 +516,7 @@ class HloMapInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -545,9 +525,9 @@ class HloMapInstruction : public HloInstruction { class HloSliceInstruction : public HloInstruction { public: explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); HloInstructionProto ToProto() const override; @@ -586,8 +566,7 @@ class HloSliceInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // Describes the [begin, end) index range for a slice. @@ -629,8 +608,7 @@ class HloConstantInstruction : public HloInstruction { CanonicalNameMap* canonical_name_map) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // TODO(b/36360764): Remove unique_ptr wrapping. std::unique_ptr literal_; @@ -651,8 +629,7 @@ class HloTraceInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // TODO(b/36360764): Remove unique_ptr wrapping. std::unique_ptr literal_; @@ -663,10 +640,9 @@ class HloFusionInstruction : public HloInstruction { explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root); - explicit HloFusionInstruction( - const Shape& shape, FusionKind fusion_kind, - tensorflow::gtl::ArraySlice operands, - HloComputation* fusion_computation); + explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, + absl::Span operands, + HloComputation* fusion_computation); string ToCategory() const override; // Returns a serialized representation of this instruction. @@ -779,8 +755,7 @@ class HloFusionInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // The type of the fusion. Used by kFusion only. @@ -789,9 +764,9 @@ class HloFusionInstruction : public HloInstruction { class HloRngInstruction : public HloInstruction { public: - explicit HloRngInstruction( - const Shape& shape, RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters); + explicit HloRngInstruction(const Shape& shape, + RandomDistribution distribution, + absl::Span parameters); // Returns the random distribution for this rng node. RandomDistribution random_distribution() const { return distribution_; } // Returns a serialized representation of this instruction. 
@@ -808,8 +783,7 @@ class HloRngInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // The distribution requested for random number generation. @@ -834,8 +808,7 @@ class HloParameterInstruction : public HloInstruction { CanonicalNameMap* canonical_name_map) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; int64 parameter_number_ = 0; @@ -859,8 +832,7 @@ class HloGetTupleElementInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; int64 tuple_index_ = -1; @@ -888,8 +860,7 @@ class HloReducePrecisionInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // The bit sizes for a reduce-precision operation. @@ -926,8 +897,7 @@ class HloInfeedInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // The string representation of the infeed configuration. @@ -959,8 +929,7 @@ class HloOutfeedInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // Shape of outfeed request. @@ -1001,8 +970,7 @@ class HloConvolutionInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; Window window_; // Describes the dimension numbers used for a convolution. @@ -1033,8 +1001,7 @@ class HloReduceWindowInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; Window window_; }; @@ -1082,17 +1049,16 @@ class HloSelectAndScatterInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; Window window_; }; class HloCustomCallInstruction : public HloInstruction { public: - explicit HloCustomCallInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - absl::string_view custom_call_target); + explicit HloCustomCallInstruction(const Shape& shape, + absl::Span operands, + absl::string_view custom_call_target); const Window& window() const override { CHECK(window_ != nullptr); return *window_; @@ -1125,8 +1091,7 @@ class HloCustomCallInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // Name of a global symbol to call, only present for kCustomCall. string custom_call_target_; @@ -1155,8 +1120,7 @@ class HloPadInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // The padding configuration that describes the edge padding and interior @@ -1166,10 +1130,10 @@ class HloPadInstruction : public HloInstruction { class HloDynamicSliceInstruction : public HloInstruction { public: - explicit HloDynamicSliceInstruction( - const Shape& shape, HloInstruction* operand, - HloInstruction* start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + explicit HloDynamicSliceInstruction(const Shape& shape, + HloInstruction* operand, + HloInstruction* start_indices, + absl::Span slice_sizes); // Old methods kept for smooth subclassing transition END. // Returns the size of the slice in the given dimension for a dynamic // slice node. @@ -1191,8 +1155,7 @@ class HloDynamicSliceInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // Describes the [start, start + size) range size for a dynamic slice @@ -1206,12 +1169,12 @@ class HloGatherInstruction : public HloInstruction { const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); const GatherDimensionNumbers& gather_dimension_numbers() const { CHECK(gather_dimension_numbers_ != nullptr); return *gather_dimension_numbers_; } - tensorflow::gtl::ArraySlice gather_slice_sizes() const { + absl::Span gather_slice_sizes() const { return gather_slice_sizes_; } // Returns the dump string of the gather dimension numbers. @@ -1221,10 +1184,9 @@ class HloGatherInstruction : public HloInstruction { // Creates an instance of GatherDimensionNumbers. 
static GatherDimensionNumbers MakeGatherDimNumbers( - tensorflow::gtl::ArraySlice offset_dims, - tensorflow::gtl::ArraySlice collapsed_slice_dims, - tensorflow::gtl::ArraySlice start_index_map, - int64 index_vector_dim); + absl::Span offset_dims, + absl::Span collapsed_slice_dims, + absl::Span start_index_map, int64 index_vector_dim); private: std::vector ExtraAttributesToStringImpl( @@ -1234,8 +1196,7 @@ class HloGatherInstruction : public HloInstruction { const std::function& eq_computations) const override; std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::unique_ptr gather_dimension_numbers_; @@ -1260,9 +1221,9 @@ class HloScatterInstruction : public HloInstruction { // Creates an instance of ScatterDimensionNumbers. static ScatterDimensionNumbers MakeScatterDimNumbers( - tensorflow::gtl::ArraySlice update_window_dims, - tensorflow::gtl::ArraySlice inserted_window_dims, - tensorflow::gtl::ArraySlice scatter_dims_to_operand_dims, + absl::Span update_window_dims, + absl::Span inserted_window_dims, + absl::Span scatter_dims_to_operand_dims, int64 index_vector_dim); private: @@ -1274,8 +1235,7 @@ class HloScatterInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::unique_ptr scatter_dimension_numbers_; @@ -1298,8 +1258,7 @@ class HloIotaInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; const int64 iota_dimension_; diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 78167335c8e..3a1bc4e328b 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -353,7 +353,7 @@ bool IsUsedOutsideSubcomputation( } // anonymous namespace HloInstruction* HloModule::OutlineExpressionFromComputation( - tensorflow::gtl::ArraySlice instructions_to_outline, + absl::Span instructions_to_outline, const string& outlined_computation_name, HloComputation* computation) { auto builder = HloComputation::Builder(outlined_computation_name); diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index cf129b835db..ee5601beec2 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -192,7 +192,7 @@ class HloModule { // order (root of outlined instructions last). TODO(jingyue): takes a set of // instructions and topologically sorts them. HloInstruction* OutlineExpressionFromComputation( - tensorflow::gtl::ArraySlice instructions_to_outline, + absl::Span instructions_to_outline, const string& outlined_computation_name, HloComputation* computation); // Returns a randomly generated uint64. 
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
index d70328c8a3d..d83ee714905 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
@@ -193,7 +193,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
 }
 
 std::vector<HloInstruction*> HloModuleGroupUtil::RootInstructions(
-    tensorflow::gtl::ArraySlice<HloComputation*> computations) {
+    absl::Span<HloComputation* const> computations) {
   std::vector<HloInstruction*> roots;
   for (HloComputation* computation : computations) {
     for (HloInstruction* instruction : computation->instructions()) {
@@ -293,7 +293,7 @@ Status HloModuleGroupUtil::VisitTopologicalOrder(
 }
 
 Status HloModuleGroupUtil::VerifyComputations(
-    tensorflow::gtl::ArraySlice<HloComputation*> computations) {
+    absl::Span<HloComputation* const> computations) {
   auto visit_function =
       [&](HloInstruction* instruction,
           const std::vector<HloInstruction*>& instruction_group) {
@@ -324,7 +324,7 @@ Status HloModuleGroupUtil::VerifyComputations(
 
 StatusOr<std::unique_ptr<HloReachabilityMap>>
 HloModuleGroupUtil::ComputeReachability(
-    tensorflow::gtl::ArraySlice<HloComputation*> computations) {
+    absl::Span<HloComputation* const> computations) {
   std::vector<HloInstruction*> post_order;
   auto visit_function =
       [&](HloInstruction* instruction,
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.h b/tensorflow/compiler/xla/service/hlo_module_group_util.h
index c25ca1aff50..fe11fe18180 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_util.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_util.h
@@ -56,7 +56,7 @@ class HloModuleGroupUtil {
 
   // Returns the root instructions of the computations.
   std::vector<HloInstruction*> RootInstructions(
-      tensorflow::gtl::ArraySlice<HloComputation*> computations);
+      absl::Span<HloComputation* const> computations);
 
   // Visit state of each instruction during DFS traversal.
   enum VisitState {
@@ -93,15 +93,14 @@ class HloModuleGroupUtil {
       HloInstruction* root);
 
   // Verifies that the computations are well-formed (e.g., no cycles).
-  Status VerifyComputations(
-      tensorflow::gtl::ArraySlice<HloComputation*> computations);
+  Status VerifyComputations(absl::Span<HloComputation* const> computations);
 
   // Below Reachability utils resemble those in HloComputation, except that
   // they can handle instructions across multiple computations.
   //
   // Creates the reachability map for the instructions in the computations.
   StatusOr<std::unique_ptr<HloReachabilityMap>> ComputeReachability(
-      tensorflow::gtl::ArraySlice<HloComputation*> computations);
+      absl::Span<HloComputation* const> computations);
 
   // Updates the reachability of the given instruction, taking the global
   // predecessors and successors into account.
diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc
index 209ad5e58c9..80009c7f7ed 100644
--- a/tensorflow/compiler/xla/service/hlo_module_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_test.cc
@@ -44,7 +44,7 @@ class HloModuleTest : public HloTestBase {
 
   // Creates a computation which calls the given zero-parameter computations.
   std::unique_ptr<HloComputation> CreateCallComputation(
-      tensorflow::gtl::ArraySlice<HloComputation*> computations) {
+      absl::Span<HloComputation* const> computations) {
     auto builder = HloComputation::Builder("Call");
     for (auto computation : computations) {
       builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index b93e4f24f64..02201d45421 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -306,7 +306,7 @@ bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64>* out) {
 // Creates replica groups from the provided nested array. groups[i] represents
 // the replica ids for group 'i'.
std::vector CreateReplicaGroups( - tensorflow::gtl::ArraySlice> groups) { + absl::Span> groups) { std::vector replica_groups; absl::c_transform(groups, std::back_inserter(replica_groups), [](const std::vector& ids) { @@ -997,10 +997,10 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } instruction = builder->AddInstruction(HloInstruction::CreateReduce( shape, /*operands=*/ - tensorflow::gtl::ArraySlice(operands).subspan( + absl::Span(operands).subspan( 0, operands.size() / 2), /*init_values=*/ - tensorflow::gtl::ArraySlice(operands).subspan( + absl::Span(operands).subspan( operands.size() / 2, operands.size()), *dimensions_to_reduce, *reduce_computation)); break; diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index 01b088a9575..961930f0a88 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -18,7 +18,7 @@ limitations under the License. namespace xla { HloReachabilityMap::HloReachabilityMap( - tensorflow::gtl::ArraySlice instructions) + absl::Span instructions) : size_(instructions.size()) { bit_vectors_.reserve(size_); for (const HloInstruction* hlo : instructions) { @@ -29,7 +29,7 @@ HloReachabilityMap::HloReachabilityMap( } bool HloReachabilityMap::SetReachabilityToUnion( - tensorflow::gtl::ArraySlice inputs, + absl::Span inputs, const HloInstruction* instruction) { BitVector& bit_vector = GetBitVector(instruction); tmp_bit_vector_ = bit_vector; @@ -38,13 +38,13 @@ bool HloReachabilityMap::SetReachabilityToUnion( } void HloReachabilityMap::FastSetReachabilityToUnion( - tensorflow::gtl::ArraySlice inputs, + absl::Span inputs, const HloInstruction* instruction) { SetReachabilityToUnionHelper(inputs, instruction, &GetBitVector(instruction)); } void HloReachabilityMap::SetReachabilityToUnionHelper( - tensorflow::gtl::ArraySlice inputs, + absl::Span inputs, const HloInstruction* instruction, BitVector* bit_vector) { // If instruction is part of inputs, don't reset the bit_vector. if (std::find(inputs.begin(), inputs.end(), instruction) == inputs.end()) { diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index 48215d32a82..2c8ebc8e6c4 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -42,7 +42,7 @@ class HloReachabilityMap { // Sets up a graph with no edges and where the nodes correspond to the given // instructions. explicit HloReachabilityMap( - tensorflow::gtl::ArraySlice instructions); + absl::Span instructions); // Set the reachability set of 'instruction' to the union of the reachability // sets of 'inputs'. Upon return, IsReachable(x, instruction) where @@ -54,13 +54,12 @@ class HloReachabilityMap { // vector in the internal graph of this HloReachabilityMap for the given // instruction and does not transitively update any other part of the // adjacency matrix. - bool SetReachabilityToUnion( - tensorflow::gtl::ArraySlice inputs, - const HloInstruction* instruction); + bool SetReachabilityToUnion(absl::Span inputs, + const HloInstruction* instruction); // As above, but faster because it does not check if the reachability changed. 
void FastSetReachabilityToUnion( - tensorflow::gtl::ArraySlice inputs, + absl::Span inputs, const HloInstruction* instruction); // Sets entry so that IsReachable(a, b) will return true @@ -141,7 +140,7 @@ class HloReachabilityMap { // Helper for SetReachabilityToUnion/FastSetReachabilityToUnion. void SetReachabilityToUnionHelper( - tensorflow::gtl::ArraySlice inputs, + absl::Span inputs, const HloInstruction* instruction, BitVector* bit_vector); // Return the index of the given instruction. The value is used to index into diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 569d2e5d2d9..c9629926eae 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -202,8 +202,8 @@ class InstructionList { // On object construction this ordinal is precisely the instruction's index // in the list. Later, instructions inserted via InsertBefore receive // duplicate values. However, monotonicity is preserved. - void InsertBeforeInstructions( - Item* to_insert, tensorflow::gtl::ArraySlice before_instructions) { + void InsertBeforeInstructions(Item* to_insert, + absl::Span before_instructions) { VLOG(3) << "InsertBeforeInstructions: " << to_insert->instruction->name() << " before {" << absl::StrJoin(before_instructions, ", ", diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 7bd8a4a544b..66ac1f66fd0 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -106,7 +106,7 @@ StatusOr HloRunner::TransferLiteralToDevice( } StatusOr> HloRunner::TransferLiteralsToDevice( - const tensorflow::gtl::ArraySlice literals) { + const absl::Span literals) { std::vector buffers; for (const Literal* literal : literals) { CHECK(literal != nullptr); @@ -118,7 +118,7 @@ StatusOr> HloRunner::TransferLiteralsToDevice( } StatusOr> HloRunner::TransferLiteralsToDevice( - const tensorflow::gtl::ArraySlice> literals) { + const absl::Span> literals) { std::vector literal_pointers; literal_pointers.reserve(literals.size()); for (const auto& literal : literals) { @@ -137,8 +137,8 @@ StatusOr> HloRunner::TransferLiteralFromDevice( StatusOr> HloRunner::Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes, ExecutionProfile* profile) { + const absl::Span arguments, bool run_hlo_passes, + ExecutionProfile* profile) { TF_ASSIGN_OR_RETURN(std::vector argument_buffers, TransferLiteralsToDevice(arguments)); TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, @@ -152,7 +152,7 @@ StatusOr> HloRunner::Execute( StatusOr> HloRunner::Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice> arguments, + const absl::Span> arguments, bool run_hlo_passes, ExecutionProfile* profile) { // Construct a vector of plain pointers for the arguments. std::vector argument_pointers; @@ -169,8 +169,8 @@ StatusOr> HloRunner::Execute( StatusOr HloRunner::ExecuteWithDeviceBuffers( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes, ExecutionProfile* profile) { + const absl::Span arguments, bool run_hlo_passes, + ExecutionProfile* profile) { // Get service run options. 
se::Stream stream(backend().default_stream_executor()); stream.Init(); @@ -190,8 +190,8 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( StatusOr HloRunner::ExecuteWithDeviceBuffers( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes, ExecutionProfile* profile) { + const absl::Span arguments, bool run_hlo_passes, + ExecutionProfile* profile) { std::vector argument_pointers; argument_pointers.reserve(arguments.size()); for (const auto& argument : arguments) { @@ -226,8 +226,7 @@ StatusOr>> HloRunner::ExecuteReplicated( // no arguments. std::vector argument_buffer_ptrs( options.num_replicas * options.arguments.size() + 1); - std::vector> - argument_buffer_slices; + std::vector> argument_buffer_slices; int64 index = 0; for (int64 i = 0; i < options.num_replicas; ++i) { int64 device = device_assignment(i, 0); diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index cfc519063e8..547b5fc1bb7 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -104,9 +104,9 @@ class HloRunner { // Transfers data between the host and device. StatusOr TransferLiteralToDevice(const Literal& literal); StatusOr> TransferLiteralsToDevice( - const tensorflow::gtl::ArraySlice literals); + const absl::Span literals); StatusOr> TransferLiteralsToDevice( - const tensorflow::gtl::ArraySlice> literals); + const absl::Span> literals); StatusOr> TransferLiteralFromDevice( const ShapedBuffer& buffer); @@ -117,24 +117,24 @@ class HloRunner { // optimization. StatusOr> Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); StatusOr> Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice> arguments, + const absl::Span> arguments, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); // As Execute(), but accepts and returns device buffers instead of host // buffers. 
StatusOr ExecuteWithDeviceBuffers( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); StatusOr ExecuteWithDeviceBuffers( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); // Executes a given HLO module into a set of replicas, and returns a map diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index 930801288a0..d49d09d4597 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -269,7 +269,7 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { auto abs_abs1 = builder.AddInstruction( HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const)); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple( - tensorflow::gtl::ArraySlice({abs_abs1}))); + absl::Span({abs_abs1}))); auto tuple_elm = builder.AddInstruction( HloInstruction::CreateGetTupleElement(r1f32, tuple, 0)); diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 12352597647..de7e6b53d4d 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -54,9 +54,8 @@ HloSharding HloSharding::Tuple(const ShapeTree& sub_shardings) { return HloSharding(flattened_list); } -HloSharding HloSharding::Tuple( - const Shape& tuple_shape, - tensorflow::gtl::ArraySlice shardings) { +HloSharding HloSharding::Tuple(const Shape& tuple_shape, + absl::Span shardings) { CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); for (auto& sharding : shardings) { CHECK(!sharding.IsTuple()) << sharding.ToString(); @@ -142,7 +141,7 @@ std::vector HloSharding::TileIndexForDevice(int64 device) const { CHECK(!maximal_); CHECK(!IsTuple()); std::vector ret_index; - tile_assignment_.Each([&](tensorflow::gtl::ArraySlice index, int64 d) { + tile_assignment_.Each([&](absl::Span index, int64 d) { if (d == device) { ret_index = {index.begin(), index.end()}; } @@ -151,8 +150,7 @@ std::vector HloSharding::TileIndexForDevice(int64 device) const { return ret_index; } -int64 HloSharding::DeviceForTileIndex( - tensorflow::gtl::ArraySlice index) const { +int64 HloSharding::DeviceForTileIndex(absl::Span index) const { CHECK(!replicated_); CHECK(!IsTuple()); if (maximal_) { @@ -319,7 +317,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, Status status = Status::OK(); std::set seen_cores; tile_assignment_.Each( - [&](tensorflow::gtl::ArraySlice indices, int32 core) { + [&](absl::Span indices, int32 core) { // Don't overwrite a bad status, so we report the first error. if (status.ok()) { if (core >= num_devices) { diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index be51c3f55b5..01fd9f215df 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -66,7 +66,7 @@ class HloSharding { // shardings must match the number of leaf nodes in tuple_shape. For // empty tuples, the shardings array must have one element. 
static HloSharding Tuple(const Shape& tuple_shape, - tensorflow::gtl::ArraySlice shardings); + absl::Span shardings); // Creates a new sharding for a tuple type, with a single input sharding // repeated on each leaf. @@ -132,7 +132,7 @@ class HloSharding { // Returns the device that should execute the given tile. // It is an error to call this if is_replicated() is true. // REQUIRES: !IsTuple() - int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice index) const; + int64 DeviceForTileIndex(absl::Span index) const; // Given a device ID, returns the offset within the specified shape of the // tile that should be executed on the given core. This returns the lower diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index 6e9b96488cf..34cba6136ff 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -372,7 +372,7 @@ Status ApplyDomainSharding(const DomainMetadata::Domain& domain, } StatusOr> ExtractOriginalCommonSharding( - tensorflow::gtl::ArraySlice instructions) { + absl::Span instructions) { // If we are here, all the instructions being passed had the same sharding // (or no sharding), by the means of the ShardingMatches() API. // As such, no kDomain was inserted, and here we are asked to extract the diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 2341f8ada0d..80634677e78 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -29,8 +29,8 @@ limitations under the License. namespace xla { namespace { -Array MakeArray(tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice contents) { +Array MakeArray(absl::Span dimensions, + absl::Span contents) { Array a(dimensions); std::copy(contents.begin(), contents.end(), a.begin()); return a; diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index e0c13261772..773fc7d2253 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -149,7 +149,7 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index, } // namespace void HloValue::SetPositionsAndComputeUses( - tensorflow::gtl::ArraySlice positions) { + absl::Span positions) { CHECK_EQ(positions_.size(), 1) << "SetPositions should only be called once."; // The positions must be unique and should not contain the defining position @@ -222,8 +222,7 @@ string HloValueSet::ToString() const { })); } -bool HloValueSet::AssignUnionOf( - tensorflow::gtl::ArraySlice inputs) { +bool HloValueSet::AssignUnionOf(absl::Span inputs) { HloValueSet union_set; for (const HloValueSet* input : inputs) { for (const HloValue* value : input->values()) { @@ -254,7 +253,7 @@ std::ostream& operator<<(std::ostream& out, const HloValueSet& value_set) { } bool InstructionValueSet::AssignUnionOf( - tensorflow::gtl::ArraySlice inputs) { + absl::Span inputs) { CHECK_GT(inputs.size(), 0); for (int i = 1; i < inputs.size(); ++i) { DCHECK(ShapeUtil::Compatible(inputs[0]->shape(), inputs[i]->shape())); diff --git a/tensorflow/compiler/xla/service/hlo_value.h b/tensorflow/compiler/xla/service/hlo_value.h index a1151f65e07..6f2ad214f6b 100644 --- a/tensorflow/compiler/xla/service/hlo_value.h +++ b/tensorflow/compiler/xla/service/hlo_value.h @@ -108,8 +108,7 @@ class HloValue : public BufferValue { // 
Sets the positions in the module at which the HloValue appears. Updates // uses. Should be called once and only once. The defining position should not // be included in 'positions' as this is set at construction time. - void SetPositionsAndComputeUses( - tensorflow::gtl::ArraySlice positions); + void SetPositionsAndComputeUses(absl::Span positions); // Returns whether this value is a phi value. bool is_phi() const { return is_phi_; } @@ -186,14 +185,14 @@ class HloValueSet { public: HloValueSet() = default; - explicit HloValueSet(tensorflow::gtl::ArraySlice values) + explicit HloValueSet(absl::Span values) : values_(values.begin(), values.end()) { SortAndUniquifyValues(); } // Sets this value set to the union of the given value sets. Returns whether // this value set changed. - bool AssignUnionOf(tensorflow::gtl::ArraySlice inputs); + bool AssignUnionOf(absl::Span inputs); // Return the vector of HloValues in the set. Values in the vector are unique // and stably sorted by value id. @@ -247,8 +246,7 @@ class InstructionValueSet : public ShapeTree { // Sets this value set to the union of the given value sets. Returns whether // this value set changed. - bool AssignUnionOf( - tensorflow::gtl::ArraySlice inputs); + bool AssignUnionOf(absl::Span inputs); string ToString() const; }; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 744cd64bc5a..95516dec74b 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -699,8 +699,7 @@ Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) { instruction->opcode(), instruction->operands())); } -string ComputationsToString( - tensorflow::gtl::ArraySlice computations) { +string ComputationsToString(absl::Span computations) { return absl::StrJoin(computations, ",", [](string* s, const HloComputation* computation) { s->append(computation->name()); diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 4d4f681c8a2..a4de02a8903 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -35,7 +35,6 @@ using ConstantArray = Analysis::ConstantArray; using ReshapedArray = Analysis::ReshapedArray; using ScalarIndexedArray = Analysis::ScalarIndexedArray; using absl::StrJoin; -using tensorflow::gtl::ArraySlice; } // namespace string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) { @@ -186,7 +185,7 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForConstant( StatusOr IndexedArrayAnalysis::FoldGatherOfGather( ScalarIndexedArray* source, Array* indices, int64 source_dim, - tensorflow::gtl::ArraySlice output_dims, Shape shape) { + absl::Span output_dims, Shape shape) { // We want to transform Gather(Gather(A, X), Y) => Gather(A, Gather(X, Y)). // `source` is the inner Gather(A, X). 
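[Illustrative note, not part of the patch] The hunks above and below all apply the same mechanical substitution: a read-only tensorflow::gtl::ArraySlice<T> parameter becomes absl::Span<const T>, and call sites keep working because the span converts implicitly from the same argument forms. A minimal standalone sketch of that call-site behavior, using a hypothetical Product() helper and plain int64_t rather than XLA's int64 typedef:

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>
#include "absl/types/span.h"

// Hypothetical helper: consumes a read-only view of dimension sizes.
int64_t Product(absl::Span<const int64_t> dims) {
  return std::accumulate(dims.begin(), dims.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

void Example() {
  std::vector<int64_t> shape = {2, 3, 4};
  Product(shape);   // implicit conversion from std::vector, no copy
  Product({5, 6});  // a braced list binds to Span<const int64_t> directly
}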
@@ -252,8 +251,7 @@ StatusOr IndexedArrayAnalysis::FoldGatherOfGather( StatusOr IndexedArrayAnalysis::ComputeArrayForGather( const Shape& shape, const GatherDimensionNumbers& dim_numbers, - tensorflow::gtl::ArraySlice slice_sizes, Array* source, - Array* indices) { + absl::Span slice_sizes, Array* source, Array* indices) { if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) { VLOG(3) << "ComputeArrayForGather: indices are not scalar"; return nullptr; @@ -314,7 +312,7 @@ namespace { // Returns an index into `values` such that the product of the range // [values.begin()+index, values.end()) is equal to `product`. If there is no // such index, return -1. All integers in `values` must be positive. -int64 FindSuffixWithProduct(ArraySlice values, int64 product) { +int64 FindSuffixWithProduct(absl::Span values, int64 product) { DCHECK(absl::c_all_of(values, [](int64 value) { return value > 0; })); int64 current_product = 1; @@ -343,7 +341,8 @@ struct ReshapePassthroughDimPair { // The returned vector of pairs is sorted in both the result_dim and the // operand_dim components. std::vector ComputeReshapePassthroughDimPairs( - ArraySlice operand_shape, ArraySlice result_shape) { + absl::Span operand_shape, + absl::Span result_shape) { // A reshape can be seen as an index mapping from output index to input index: // // (i_0, ..., i_n) = f(o_0, ..., o_m) @@ -420,7 +419,7 @@ std::vector ComputeReshapePassthroughDimPairs( // Return true if `dim` is stated as an passthrough operand dim in // `passthrough_dims`. bool IsReshapePassthroughOperandDim( - ArraySlice passthrough_dims, int64 dim) { + absl::Span passthrough_dims, int64 dim) { return absl::c_any_of(passthrough_dims, [&](ReshapePassthroughDimPair passthrough_dim_pair) { return passthrough_dim_pair.operand_dim == dim; @@ -430,7 +429,8 @@ bool IsReshapePassthroughOperandDim( // Maps `operand_dim` which must be an passthrough operand dimension to its // corresponding passthrough result dimension based on `passthrough_dims`. 
int64 MapPassthroughOperandDimToResultDim( - ArraySlice passthrough_dims, int64 operand_dim) { + absl::Span passthrough_dims, + int64 operand_dim) { auto it = absl::c_find_if( passthrough_dims, [&](ReshapePassthroughDimPair passthrough_dim_pair) { return passthrough_dim_pair.operand_dim == operand_dim; @@ -439,9 +439,9 @@ int64 MapPassthroughOperandDimToResultDim( return it->result_dim; } -int64 FindSourcePositionForPassthroughResultDim(ArraySlice operand_shape, - ArraySlice result_shape, - int64 source_passthrough_dim) { +int64 FindSourcePositionForPassthroughResultDim( + absl::Span operand_shape, absl::Span result_shape, + int64 source_passthrough_dim) { VLOG(3) << "FindSourcePositionForPassthroughResultDim([" << StrJoin(operand_shape, ",") << "], [" << StrJoin(result_shape, ",") << "], " << source_passthrough_dim << ")"; @@ -519,8 +519,7 @@ IndexedArrayAnalysis::ReshapeToRemoveDegenerateDims( } StatusOr IndexedArrayAnalysis::ReshapeToAddDegenerateDims( - ScalarIndexedArray* operand, - tensorflow::gtl::ArraySlice degenerate_dims) { + ScalarIndexedArray* operand, absl::Span degenerate_dims) { if (degenerate_dims.empty()) { return operand; } @@ -873,7 +872,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, return nullptr; } - ArraySlice broadcast_dims = broadcast_instr->dimensions(); + absl::Span broadcast_dims = broadcast_instr->dimensions(); auto is_broadcasted_dim = [&](int64 output_dim) { return absl::c_find(broadcast_dims, output_dim) == broadcast_dims.end(); }; @@ -896,7 +895,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, // The scalar-indexed node "removes" the source dim and "inserts" the output // dims. We do the opposite here to undo the scalar-indexed operation. - ArraySlice output_dims = scalar_indexed_const->output_dims(); + absl::Span output_dims = scalar_indexed_const->output_dims(); for (int64 i = output_dims.size() - 1; i >= 0; --i) { CHECK(simulated_index[output_dims[i]] == IndexComponent::Broadcasted); EraseAt(&simulated_index, output_dims[i]); @@ -973,8 +972,8 @@ namespace { // Returns the non-contracting non-batch dimension (as per `contracting_dims` // and `batch_dims`) if there is exactly one, otherwise returns nullopt. absl::optional GetOnlyNonContractingNonBatchDim( - int64 rank, ArraySlice contracting_dims, - ArraySlice batch_dims) { + int64 rank, absl::Span contracting_dims, + absl::Span batch_dims) { absl::optional result; for (int64 dim = 0; dim < rank; dim++) { if (!absl::c_linear_search(contracting_dims, dim) && @@ -998,7 +997,8 @@ absl::optional GetOnlyNonContractingNonBatchDim( // of whatever operand `indexed_array` is to the dot (LHS or RHS). bool CanFoldDotIntoIndexedArray( absl::string_view tag, Analysis::ScalarIndexedConstantArray* indexed_array, - ArraySlice contracting_dims, ArraySlice batch_dims) { + absl::Span contracting_dims, + absl::Span batch_dims) { absl::optional non_contracting_non_batch_dim = GetOnlyNonContractingNonBatchDim(ShapeUtil::Rank(indexed_array->shape()), contracting_dims, batch_dims); diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index 3fa7d749e19..dcfb7255358 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -188,9 +188,7 @@ class IndexedArrayAnalysis { // `output_dims` are the dimensions in the output array that are being used // to compute an index into the `indices` array. 
See the class // documentation and the overview for more details. - tensorflow::gtl::ArraySlice output_dims() const { - return output_dims_; - } + absl::Span output_dims() const { return output_dims_; } private: explicit ScalarIndexedArray(Array* source, Array* indices, int64 source_dim, @@ -265,8 +263,7 @@ class IndexedArrayAnalysis { StatusOr ComputeArrayForGather( const Shape& shape, const GatherDimensionNumbers& dim_numbers, - tensorflow::gtl::ArraySlice slice_sizes, Array* source, - Array* indices); + absl::Span slice_sizes, Array* source, Array* indices); StatusOr ComputeArrayForDotWithIndexedLhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, @@ -303,7 +300,7 @@ class IndexedArrayAnalysis { // G1 = [Arr[i] for i in I2] StatusOr FoldGatherOfGather( ScalarIndexedArray* source, Array* indices, int64 source_dim, - tensorflow::gtl::ArraySlice output_dims, Shape shape); + absl::Span output_dims, Shape shape); // Reshapes a scalar-indexed node to remove the degenerate dimensions in its // output. The result is always a scalar-indexed node. @@ -313,8 +310,7 @@ class IndexedArrayAnalysis { // Reshapes a scalar-indexed node such that the result has the degenerate // dimensions `degenerate_dims`. The result is always a scalar-indexed node. StatusOr ReshapeToAddDegenerateDims( - ScalarIndexedArray* operand, - tensorflow::gtl::ArraySlice degenerate_dims); + ScalarIndexedArray* operand, absl::Span degenerate_dims); StatusOr FoldReshapeOfGather( const Shape& shape, ScalarIndexedConstantArray* operand); diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 4b5285031bb..8c907eae0cb 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -219,7 +219,7 @@ bool InstructionFusion::CanFuseOnAllPaths( InstructionFusion::HloInstructionSet InstructionFusion::ComputeGloballyUnfusible( - tensorflow::gtl::ArraySlice post_order) { + absl::Span post_order) { // Forbid fusion of producers that: // a) Need to be duplicated, unless they can be fused into all consumers // via all paths. diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index 9802d4cfc1b..00b658959a2 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -123,7 +123,7 @@ class InstructionFusion : public HloPassInterface { // Computes the set of nodes that we do not want to fuse into any of their // consumers based on a global analysis of the HLO graph. HloInstructionSet ComputeGloballyUnfusible( - tensorflow::gtl::ArraySlice post_order); + absl::Span post_order); // Used to determine if an HLO is expensive. Expensive operations will not be // duplicated. 
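[Illustrative note, not part of the patch] The commit title's "(Mutable)ArraySlice" distinction maps onto two span forms: the old read-only ArraySlice<T> becomes absl::Span<const T>, while the writable MutableArraySlice<T> becomes absl::Span<T>. The template arguments are elided in this rendering of the patch, so the sketch below uses hypothetical names purely to show the difference:

#include <vector>
#include "absl/types/span.h"

// Writable view: elements can be modified through the span.
void ScaleInPlace(absl::Span<double> values, double factor) {
  for (double& v : values) v *= factor;
}

// Read-only view: elements cannot be modified through the span.
double Sum(absl::Span<const double> values) {
  double total = 0;
  for (double v : values) total += v;
  return total;
}

void Example() {
  std::vector<double> data = {1.0, 2.0, 3.0};
  ScaleInPlace(absl::MakeSpan(data), 2.0);  // a mutable view requires MakeSpan
  Sum(data);                                // a const view converts implicitly
}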
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 2259dc1083e..5dea1247684 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -47,7 +47,7 @@ InterpreterExecutable::~InterpreterExecutable() {} StatusOr InterpreterExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) { se::Stream* stream = run_options->stream(); se::StreamExecutor* executor = stream->parent(); @@ -111,7 +111,7 @@ StatusOr InterpreterExecutable::ExecuteOnStream( StatusOr InterpreterExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { return tensorflow::errors::Unimplemented( "ExecuteAsyncOnStream is not yet supported on Interpreter."); } diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h index 91d8148d26d..588787d445f 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.h +++ b/tensorflow/compiler/xla/service/interpreter/executable.h @@ -48,13 +48,13 @@ class InterpreterExecutable : public Executable { StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) override LOCKS_EXCLUDED(evaluator_lock_); StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) override; + absl::Span arguments) override; static int64 ShapeSizeBytes(const Shape& shape); diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h index db6b910b32f..f600b14c6c5 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.h +++ b/tensorflow/compiler/xla/service/interpreter/executor.h @@ -47,7 +47,7 @@ limitations under the License. namespace stream_executor { namespace interpreter { -using Args = tensorflow::gtl::ArraySlice; +using Args = absl::Span; class XlaInterpreterExecutor : public internal::StreamExecutorInterface { public: diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc index ad350613dd2..cc2e862f2eb 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc @@ -99,9 +99,10 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( return LoopEmitter(loop_body_emitter, update_shape, b).EmitLoop(name); } -Status EmitDynamicUpdateSliceInPlace( - tensorflow::gtl::ArraySlice operand_arrays, - const IrArray& output_array, absl::string_view name, llvm::IRBuilder<>* b) { +Status EmitDynamicUpdateSliceInPlace(absl::Span operand_arrays, + const IrArray& output_array, + absl::string_view name, + llvm::IRBuilder<>* b) { VLOG(2) << "EmitDynamicUpdateSliceInPlace for " << name; // No need to use operand_arrays[0], the input array of the @@ -129,8 +130,7 @@ Status EmitDynamicUpdateSliceInPlace( // // Emits a sequential loop if launch_dimensions is null. 
static Status EmitFusedDynamicUpdateSliceInPlaceImpl( - HloInstruction* fusion, - tensorflow::gtl::ArraySlice fusion_operand_arrays, + HloInstruction* fusion, absl::Span fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, const gpu::LaunchDimensions* launch_dimensions, llvm::IRBuilder<>* b) { CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); @@ -173,8 +173,7 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl( } Status EmitFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, - tensorflow::gtl::ArraySlice fusion_operand_arrays, + HloInstruction* fusion, absl::Span fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, llvm::IRBuilder<>* b) { return EmitFusedDynamicUpdateSliceInPlaceImpl( @@ -183,8 +182,7 @@ Status EmitFusedDynamicUpdateSliceInPlace( } Status EmitParallelFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, - tensorflow::gtl::ArraySlice fusion_operand_arrays, + HloInstruction* fusion, absl::Span fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b) { return EmitFusedDynamicUpdateSliceInPlaceImpl( diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h index e1631a62ae8..fb3e4eb97ca 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h @@ -63,25 +63,24 @@ inline bool CanEmitFusedDynamicUpdateSliceInPlace( // Emits IR for running the given dynamic-update-slice op in-place -- that is, // where the input and output buffers share the same slice, so we can simply // modify the input/output buffer without touching any of the other elements. -Status EmitDynamicUpdateSliceInPlace( - tensorflow::gtl::ArraySlice operand_arrays, - const IrArray& output_array, absl::string_view name, llvm::IRBuilder<>* b); +Status EmitDynamicUpdateSliceInPlace(absl::Span operand_arrays, + const IrArray& output_array, + absl::string_view name, + llvm::IRBuilder<>* b); // Given a loop-fusion node whose root is a dynamic-update-slice op whose // array-to-be-updated and output share the same buffer slice, emits // (sequential) code for a fusion node that does the dynamic-update-slice in // place. Status EmitFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, - tensorflow::gtl::ArraySlice fusion_operand_arrays, + HloInstruction* fusion, absl::Span fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, llvm::IRBuilder<>* b); // Same as EmitFusedDynamicUpdateSliceInPlace, except emits a parallel loop with // the given launch dimensions. 
Status EmitParallelFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, - tensorflow::gtl::ArraySlice fusion_operand_arrays, + HloInstruction* fusion, absl::Span fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b); diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index 6d637cad6df..b606c993a2d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -147,7 +147,7 @@ Status FusedIrEmitter::HandleParameter(HloInstruction* parameter) { } Status FusedIrEmitter::HandleTuple(HloInstruction* tuple) { - tensorflow::gtl::ArraySlice operands(tuple->operands()); + absl::Span operands(tuple->operands()); std::vector operand_elemental_ir_types; for (HloInstruction* operand : operands) { operand_elemental_ir_types.push_back(llvm_ir::PrimitiveTypeToIrType( diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h index 30471480c4f..25ec458160b 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -54,7 +54,7 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { public: using Generator = llvm_ir::ElementGenerator; - FusedIrEmitter(tensorflow::gtl::ArraySlice parameter_arrays, + FusedIrEmitter(absl::Span parameter_arrays, ElementalIrEmitter* elemental_emitter) : parameter_arrays_(parameter_arrays), tiled_parameter_info_(nullptr), @@ -94,7 +94,7 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { private: // Arrays of parameters of fusion instruction - tensorflow::gtl::ArraySlice parameter_arrays_; + absl::Span parameter_arrays_; const llvm_ir::TiledParameterInfo* tiled_parameter_info_; ElementalIrEmitter* elemental_emitter_; diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 36e713d1ac8..67f74231211 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -73,7 +73,7 @@ IrArray::Index::Index(llvm::Value* linear, const Shape& shape, Delinearize(&multidim_, linear, shape, b); } -IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, +IrArray::Index::Index(absl::Span multidim, llvm::Value* linear, const Shape& shape) : multidim_(multidim.begin(), multidim.end()), linear_(linear), @@ -92,7 +92,7 @@ IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, << " should have a layout."; } -IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, +IrArray::Index::Index(absl::Span multidim, const Shape& shape, llvm::IRBuilder<>* b) : multidim_(multidim.begin(), multidim.end()), layout_(shape.layout()), @@ -147,7 +147,7 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( // indices in the same common factor. 
for (ssize_t k = common_factors.size() - 2; k >= 0; --k) { llvm::Value* logical_linear_index = - Index(tensorflow::gtl::ArraySlice(multidim_).subspan( + Index(absl::Span(multidim_).subspan( common_factors[k].second, common_factors[k + 1].second - common_factors[k].second), index_type_) @@ -184,9 +184,8 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( } IrArray::Index IrArray::Index::SourceIndexOfSlice( - const Shape& shape, tensorflow::gtl::ArraySlice starts, - tensorflow::gtl::ArraySlice strides, - llvm::IRBuilder<>* builder) const { + const Shape& shape, absl::Span starts, + absl::Span strides, llvm::IRBuilder<>* builder) const { Index source_index(index_type_, multidim_.size()); for (int i = 0; i < multidim_.size(); ++i) { int64 stride = strides[i]; @@ -207,7 +206,7 @@ IrArray::Index IrArray::Index::SourceIndexOfSlice( IrArray::Index IrArray::Index::SourceIndexOfTranspose( const Shape& shape, const Shape& operand_shape, - tensorflow::gtl::ArraySlice dimension_mapping, + absl::Span dimension_mapping, llvm::IRBuilder<>* builder) const { std::vector operand_multidim_index = Permute(dimension_mapping, multidim()); @@ -256,7 +255,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBitcast( IrArray::Index IrArray::Index::SourceIndexOfBroadcast( const Shape& shape, const Shape& operand_shape, - tensorflow::gtl::ArraySlice dimension_mapping, + absl::Span dimension_mapping, llvm::IRBuilder<>* builder) const { int64 rank = ShapeUtil::Rank(operand_shape); std::vector source_index(rank); @@ -321,9 +320,8 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( return Index(source_index, linear, operand_shape); } -llvm::Value* IrArray::Index::Linearize( - tensorflow::gtl::ArraySlice dimensions, - llvm::IRBuilder<>* builder) const { +llvm::Value* IrArray::Index::Linearize(absl::Span dimensions, + llvm::IRBuilder<>* builder) const { // Each dimension is multiplied by the product of the sizes of all // earlier dimensions and added to the accumulator logical_linear_index. CHECK_EQ(size(), dimensions.size()); diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index e913c109b3f..7629806a365 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -70,7 +70,7 @@ class IrArray { // Constructs an index from multi-dimensional index "multidim". The linear // index is set to nullptr. - explicit Index(tensorflow::gtl::ArraySlice multidim, + explicit Index(absl::Span multidim, llvm::Type* index_ty = nullptr) : multidim_(multidim.begin(), multidim.end()) { if (size() == 0) { @@ -99,14 +99,14 @@ class IrArray { // that it indexes into. // // Precondition: "shape" has a layout. - Index(tensorflow::gtl::ArraySlice multidim, - const Shape& shape, llvm::IRBuilder<>* b); + Index(absl::Span multidim, const Shape& shape, + llvm::IRBuilder<>* b); // Constructs an index from both a multi-dimensional index and a linear // index. "shape" has the same meaning as that in the constructor that takes // only a linear index. - Index(tensorflow::gtl::ArraySlice multidim, - llvm::Value* linear, const Shape& shape); + Index(absl::Span multidim, llvm::Value* linear, + const Shape& shape); const std::vector& multidim() const { return multidim_; } llvm::Value* linear() const { return linear_; } @@ -145,17 +145,15 @@ class IrArray { // by starting indices `starts` and stride values `strides`. // // Precondition: "this" is an index into a slice whose shape is `shape`. 
- Index SourceIndexOfSlice(const Shape& shape, - tensorflow::gtl::ArraySlice starts, - tensorflow::gtl::ArraySlice strides, + Index SourceIndexOfSlice(const Shape& shape, absl::Span starts, + absl::Span strides, llvm::IRBuilder<>* builder) const; // Given that "this" is the target index of a transpose from `operand_shape` // to `shape` with the given dimension mapping, returns the source index. - Index SourceIndexOfTranspose( - const Shape& shape, const Shape& operand_shape, - tensorflow::gtl::ArraySlice dimension_mapping, - llvm::IRBuilder<>* builder) const; + Index SourceIndexOfTranspose(const Shape& shape, const Shape& operand_shape, + absl::Span dimension_mapping, + llvm::IRBuilder<>* builder) const; // Given that "this" is the target index of a bitcast from `operand_shape` // to `shape`, returns the source index. @@ -164,14 +162,13 @@ class IrArray { // Given that "this" is the target index of a broadcast from `operand_shape` // to `shape` with the given dimension mapping, returns the source index. - Index SourceIndexOfBroadcast( - const Shape& shape, const Shape& operand_shape, - tensorflow::gtl::ArraySlice dimension_mapping, - llvm::IRBuilder<>* builder) const; + Index SourceIndexOfBroadcast(const Shape& shape, const Shape& operand_shape, + absl::Span dimension_mapping, + llvm::IRBuilder<>* builder) const; // Linearizes the index into the given shape, i.e. reshapes it to rank-1 and // returns the index into the sole dimension 0 of the new shape. - llvm::Value* Linearize(tensorflow::gtl::ArraySlice dimensions, + llvm::Value* Linearize(absl::Span dimensions, llvm::IRBuilder<>* builder) const; llvm::Type* GetType() const { return index_type_; } diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index b152cf9275c..43fec311f15 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -235,7 +235,7 @@ class KernelSupportLibrary { })); } - using ArgumentVector = tensorflow::gtl::ArraySlice; + using ArgumentVector = absl::Span; // Generates the following control flow structure: // diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc index cb4d1db997c..e5fbdbd51b8 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc @@ -28,7 +28,7 @@ namespace { // Returns the indices of the first elements of all consecutive subarrays of the // given array. For example: // ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4} -std::vector ConsecutiveSegments(tensorflow::gtl::ArraySlice xs) { +std::vector ConsecutiveSegments(absl::Span xs) { std::vector is = {0}; for (size_t i = 1; i < xs.size(); ++i) { if (1 != xs[i] - xs[i - 1]) { @@ -40,8 +40,7 @@ std::vector ConsecutiveSegments(tensorflow::gtl::ArraySlice xs) { // Merges the sequences of dimensions of the given shape which start at the // given indices `segs`. 
-Shape MergeDimensions(tensorflow::gtl::ArraySlice segs, - const Shape& shape) { +Shape MergeDimensions(absl::Span segs, const Shape& shape) { std::vector dimensions; for (size_t i = 1; i <= segs.size(); ++i) { dimensions.push_back(std::accumulate( diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h index 8bd06c42c3c..5ea05b3188a 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h @@ -50,7 +50,7 @@ IrArray::Index GetUnreducedOutputIndex( // for 021 transpose. class TiledParameterInfo { public: - TiledParameterInfo(tensorflow::gtl::ArraySlice param_buffers, + TiledParameterInfo(absl::Span param_buffers, llvm::Value* y, llvm::Value* x) : param_buffers_(param_buffers), y_(y), x_(x) {} @@ -67,7 +67,7 @@ class TiledParameterInfo { private: // Param_buffers_[i] stores the tile buffer for the ith parameter or nullptr // if the parameter is not tiled. - tensorflow::gtl::ArraySlice param_buffers_; + absl::Span param_buffers_; // The y coordinate within a tile. llvm::Value* y_; // The x coordinate within a tile. diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index 9f3329e7f0e..219a9f221fb 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -241,7 +241,7 @@ IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, } IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions( - const Shape& shape, tensorflow::gtl::ArraySlice dimensions, + const Shape& shape, absl::Span dimensions, absl::string_view suffix) { llvm_ir::IrArray::Index index(index_type_, shape.dimensions_size()); for (int64 dimension : dimensions) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h index 0a406bd90b9..2be7bbd0de9 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h @@ -242,7 +242,7 @@ class ForLoopNest { // size equals the rank of shape and there is a null for each // dimension that is not in "dimensions". IrArray::Index AddLoopsForShapeOnDimensions( - const Shape& shape, tensorflow::gtl::ArraySlice dimensions, + const Shape& shape, absl::Span dimensions, absl::string_view suffix); // Emits a series of nested loops for iterating over an operand array. 
Loops diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index f0db2a3761a..1a53c026be3 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -83,11 +83,10 @@ string DumpModuleToString(const llvm::Module& module) { return AsString(buffer_string); } -llvm::Value* EmitCallToIntrinsic( - llvm::Intrinsic::ID intrinsic_id, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice overloaded_types, - llvm::IRBuilder<>* b) { +llvm::Value* EmitCallToIntrinsic(llvm::Intrinsic::ID intrinsic_id, + absl::Span operands, + absl::Span overloaded_types, + llvm::IRBuilder<>* b) { llvm::Module* module = ModuleFromIRBuilder(b); llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration( module, intrinsic_id, AsArrayRef(overloaded_types)); diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index dde50e19d1c..61b029eb082 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -59,7 +59,7 @@ llvm::ArrayRef AsArrayRef(const std::vector& vec) { } template -llvm::ArrayRef AsArrayRef(const tensorflow::gtl::ArraySlice& slice) { +llvm::ArrayRef AsArrayRef(const absl::Span& slice) { return llvm::ArrayRef(slice.data(), slice.size()); } @@ -101,11 +101,10 @@ string SanitizeFunctionName(string function_name); // intrinsics (for example, "minnum") must include a type in overloaded_types // for each overloaded type. Typically, overloaded intrinsics have only a single // overloaded type. -llvm::Value* EmitCallToIntrinsic( - llvm::Intrinsic::ID intrinsic_id, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice overloaded_types, - llvm::IRBuilder<>* b); +llvm::Value* EmitCallToIntrinsic(llvm::Intrinsic::ID intrinsic_id, + absl::Span operands, + absl::Span overloaded_types, + llvm::IRBuilder<>* b); // Emit float max. Emit maxnum intrinsic is fast math is disabled, or // fcmp+select otherwise diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index 1553b4fc91e..0dc120e0b0d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -69,7 +69,7 @@ static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutputFusion( } LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, - tensorflow::gtl::ArraySlice target_arrays, + absl::Span target_arrays, llvm::IRBuilder<>* b) : body_emitter_(MakeBodyEmitterForMultiOutputFusion( target_element_generator, diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index 57d9d8bbc61..a537c00066b 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -53,8 +53,7 @@ class LoopEmitter { // This is used for multi-output fusion. target_element_generator must // produce an LLVM struct with N elements. 
LoopEmitter(const ElementGenerator& target_element_generator, - tensorflow::gtl::ArraySlice target_arrays, - llvm::IRBuilder<>* b); + absl::Span target_arrays, llvm::IRBuilder<>* b); LoopEmitter(const LoopEmitter&) = delete; LoopEmitter& operator=(const LoopEmitter&) = delete; diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc index 11ed6ee59f1..7d49b8d6c2c 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc @@ -64,8 +64,7 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred, } } -void EmitTuple(const IrArray& tuple, - tensorflow::gtl::ArraySlice operands, +void EmitTuple(const IrArray& tuple, absl::Span operands, llvm::IRBuilder<>* b, llvm::Module* module) { for (size_t i = 0; i < operands.size(); ++i) { auto* store = b->CreateStore( diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h index cf6bf5d0b14..cee211d66f3 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h @@ -65,8 +65,7 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred, // A tuple is an array of pointers, one for each operand. Each pointer points to // the output buffer of its corresponding operand. -void EmitTuple(const IrArray& tuple, - tensorflow::gtl::ArraySlice operands, +void EmitTuple(const IrArray& tuple, absl::Span operands, llvm::IRBuilder<>* b, llvm::Module* module); // A tuple is an array of pointers, one for each operand. Each pointer points to diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 768105d9e11..0d0fb7946ae 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -141,7 +141,7 @@ ExecutionOptions CreateExecutionOptions( StatusOr> LocalService::CompileExecutable( const XlaComputation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, + const absl::Span argument_layouts, const ExecutableBuildOptions& build_options) { const HloModuleProto& proto = computation.proto(); TF_RET_CHECK(proto.has_program_shape()); diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 8f707ea9046..acc8c6d2e05 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -48,7 +48,7 @@ class LocalService : public Service { // compiler is responsible for freeing any memory it allocates this way. StatusOr> CompileExecutable( const XlaComputation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, + const absl::Span argument_layouts, const ExecutableBuildOptions& build_options); // Returns the device ordinal that corresponds to the given replica number. 
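[Illustrative note, not part of the patch] Like ArraySlice before it, absl::Span is a non-owning view: the functions rewritten above receive a pointer/length pair into storage the caller still owns, so that storage must outlive the span and must not be resized while the span is in use. A small sketch with a hypothetical MinorDims() helper:

#include <cstdint>
#include <vector>
#include "absl/types/span.h"

// Hypothetical helper: returns a view of the last `count` dimensions without copying.
absl::Span<const int64_t> MinorDims(absl::Span<const int64_t> dims, size_t count) {
  return dims.subspan(dims.size() - count);  // still points into the caller's storage
}

void Example() {
  std::vector<int64_t> dims = {8, 16, 32, 64};
  absl::Span<const int64_t> minor = MinorDims(dims, 2);  // views {32, 64}
  // `minor` is only valid while `dims` is alive and unmodified.
  (void)minor;
}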
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index 4166ef5baf9..b9ec31c4977 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -262,7 +262,7 @@ void MultiOutputFusion::RecomputeReachability() { void MultiOutputFusion::UpdateReachability( HloInstruction* instr1, HloInstruction* instr2, - tensorflow::gtl::ArraySlice instrs_to_update, + absl::Span instrs_to_update, const std::function& skip) { for (auto instr : instrs_to_update) { if (skip != nullptr && skip(instr)) { diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h index 4c8cb7d379d..d2c52651c4f 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -92,7 +92,7 @@ class MultiOutputFusion : public HloPassInterface { // Update the reachability map after fusing instr1 and instr2. void UpdateReachability( HloInstruction* instr1, HloInstruction* instr2, - tensorflow::gtl::ArraySlice instrs_to_update, + absl::Span instrs_to_update, const std::function& skip = nullptr); // Hook for multi-output fusion along producer-consumer edges. diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc index 2077b57c05e..2f4b2667c40 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.cc +++ b/tensorflow/compiler/xla/service/scatter_expander.cc @@ -26,7 +26,6 @@ limitations under the License. namespace xla { -using tensorflow::gtl::ArraySlice; // Transposes the given scatter_indices such that the index_vector_dim becomes // the most-minor dimension. @@ -87,7 +86,7 @@ static StatusOr CanonicalizeScatterIndices( // major dimensions and all the window dimensions appear in the minor // dimensions. static StatusOr PermuteScatterAndWindowDims( - HloInstruction* updates, ArraySlice update_window_dims) { + HloInstruction* updates, absl::Span update_window_dims) { std::vector permutation; const int64 updates_rank = ShapeUtil::Rank(updates->shape()); permutation.reserve(updates_rank); diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index e10c1d9927e..f0e2566a3f9 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -62,10 +62,9 @@ using absl::StrCat; using absl::StrFormat; // Records the arguments used to invoke a computation in an HloSnapshot proto. 
-Status RecordArguments( - const tensorflow::gtl::ArraySlice arguments, - se::Stream* stream, TransferManager* transfer_manager, - HloSnapshot* module) { +Status RecordArguments(const absl::Span arguments, + se::Stream* stream, TransferManager* transfer_manager, + HloSnapshot* module) { module->clear_arguments(); for (const ShapedBuffer* argument : arguments) { TF_ASSIGN_OR_RETURN( @@ -207,8 +206,8 @@ Status Service::ValidateResultShape(const Shape& client_shape, StatusOr>> Service::ResolveAndValidateArguments( - tensorflow::gtl::ArraySlice arguments, - tensorflow::gtl::ArraySlice stream_executors) { + absl::Span arguments, + absl::Span stream_executors) { CHECK_EQ(options_.number_of_replicas(), stream_executors.size()); std::vector> replicated_arguments; replicated_arguments.resize(options_.number_of_replicas()); @@ -242,7 +241,7 @@ Service::ResolveAndValidateArguments( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, - tensorflow::gtl::ArraySlice argument_shapes, + absl::Span argument_shapes, const ExecutionOptions* execution_options) { auto config = absl::make_unique(program_shape); ComputationLayout* computation_layout = @@ -299,7 +298,7 @@ StatusOr> Service::CreateModuleConfig( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutionOptions& execution_options) { std::vector argument_shapes; for (const auto* arg : arguments) { @@ -367,12 +366,10 @@ StatusOr>> Service::BuildExecutables( StatusOr> Service::ExecuteParallelAndRegisterResult( - tensorflow::gtl::ArraySlice executables, - tensorflow::gtl::ArraySlice>> - arguments, - Backend* backend, tensorflow::gtl::ArraySlice device_handles, - tensorflow::gtl::ArraySlice result_tags, - ExecutionProfile* profile) { + absl::Span executables, + absl::Span>> arguments, + Backend* backend, absl::Span device_handles, + absl::Span result_tags, ExecutionProfile* profile) { // Streams where the computation are launched, so we can wait on the streams // to complete. std::vector streams; @@ -511,8 +508,7 @@ Service::ExecuteParallelAndRegisterResult( StatusOr Service::ExecuteAndRegisterResult( Executable* executable, - const tensorflow::gtl::ArraySlice> - arguments, + const absl::Span> arguments, Backend* backend, const string& result_tag, ExecutionProfile* profile) { // Set up streams. std::vector streams; @@ -555,8 +551,7 @@ StatusOr Service::ExecuteAndRegisterResult( // TODO(b/69985541): Support profiling also on this path. - std::vector> - replicated_arguments; + std::vector> replicated_arguments; for (const auto& arg : arguments) { replicated_arguments.push_back(arg); } @@ -595,7 +590,7 @@ StatusOr> Service::GetExecutors( StatusOr>> Service::GetArguments( const ExecutionOptions& execution_options, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { // Resolve the allocations for the arguments of the computation, and create // a vector of device memory offsets for the arguments from the allocations. // In the case of partitioned computations, assume all arguments go on the diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 47d196fb2aa..173300d8b6c 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -176,7 +176,7 @@ class Service : public ServiceInterface { // class. 
StatusOr> CreateModuleConfig( const ProgramShape& program_shape, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutionOptions& execution_options); // Picks a parallel response and fills the result. @@ -191,7 +191,7 @@ class Service : public ServiceInterface { // Prepare the arguments for executing parallel. StatusOr>> GetArguments( const ExecutionOptions& execution_options, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); protected: friend class LocalExecutable; @@ -207,14 +207,14 @@ class Service : public ServiceInterface { // the corresponding replica. StatusOr>> ResolveAndValidateArguments( - tensorflow::gtl::ArraySlice arguments, - tensorflow::gtl::ArraySlice stream_executors); + absl::Span arguments, + absl::Span stream_executors); // Create a Hlo module config for the given program shape and arguments. // execution_options is optional; if not given a default is used. StatusOr> CreateModuleConfig( const ProgramShape& program_shape, - tensorflow::gtl::ArraySlice argument_shapes, + absl::Span argument_shapes, const ExecutionOptions* execution_options); // Builds an Executable for the given parameters. @@ -242,21 +242,17 @@ class Service : public ServiceInterface { // ExecutionProfile object which will be filled in with profile data. StatusOr ExecuteAndRegisterResult( Executable* executable, - const tensorflow::gtl::ArraySlice> - arguments, + const absl::Span> arguments, Backend* backend, const string& result_tag, ExecutionProfile* profile); // Runs the given executables with the given arguments and register the result // from each executable in the allocation tracker. The handles of the result // from the tracker are returned. StatusOr> ExecuteParallelAndRegisterResult( - tensorflow::gtl::ArraySlice executables, - tensorflow::gtl::ArraySlice>> - arguments, - Backend* backend, - tensorflow::gtl::ArraySlice device_handles, - tensorflow::gtl::ArraySlice result_tags, - ExecutionProfile* profile); + absl::Span executables, + absl::Span>> arguments, + Backend* backend, absl::Span device_handles, + absl::Span result_tags, ExecutionProfile* profile); // Executes a single computation which has more than one target device. // The N devices are expected to all return an empty tuple, but one, which diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 45427bba256..26117498621 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -45,7 +45,7 @@ using absl::StrFormat; using absl::StrJoin; // Returns true if no element is present in slice more than once. 
-bool AllUnique(tensorflow::gtl::ArraySlice slice) { +bool AllUnique(absl::Span slice) { return std::set(slice.begin(), slice.end()).size() == slice.size(); } @@ -57,11 +57,10 @@ Status ExpectArray(const Shape& shape, absl::string_view op_type) { return Status::OK(); } -Status VerifyReducerShape( - const ProgramShape& reducer_shape, - tensorflow::gtl::ArraySlice init_value_shapes, - tensorflow::gtl::ArraySlice input_element_types, - int64 inputs) { +Status VerifyReducerShape(const ProgramShape& reducer_shape, + absl::Span init_value_shapes, + absl::Span input_element_types, + int64 inputs) { if (reducer_shape.parameters_size() != inputs * 2) { return InvalidArgument( "Reduction function must take %d parameters, but " @@ -335,8 +334,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, } /* static */ StatusOr ShapeInference::InferConcatOpShape( - tensorflow::gtl::ArraySlice arg_shapes, - const int64 dimension) { + absl::Span arg_shapes, const int64 dimension) { if (arg_shapes.empty()) { return InvalidArgument("Concatenate expects at least one argument."); } @@ -394,7 +392,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, } /* static */ StatusOr ShapeInference::InferAfterAllShape( - tensorflow::gtl::ArraySlice arg_shapes) { + absl::Span arg_shapes) { for (const Shape* arg_shape : arg_shapes) { if (arg_shape->element_type() != TOKEN) { return InvalidArgument( @@ -550,22 +548,22 @@ Status ValidateDotDimensionNumbers( const Shape& lhs, const Shape& rhs, const DotDimensionNumbers& dimension_numbers) { // Check that dimension numbers are in range. - auto dims_in_range = - [](const int64 rank, tensorflow::gtl::ArraySlice contracting_dims, - tensorflow::gtl::ArraySlice batch_dims) -> bool { + auto dims_in_range = [](const int64 rank, + absl::Span contracting_dims, + absl::Span batch_dims) -> bool { auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; }; return std::all_of(contracting_dims.begin(), contracting_dims.end(), in_range) && std::all_of(batch_dims.begin(), batch_dims.end(), in_range); }; - tensorflow::gtl::ArraySlice lhs_contracting_dimensions = + absl::Span lhs_contracting_dimensions = AsInt64Slice(dimension_numbers.lhs_contracting_dimensions()); - tensorflow::gtl::ArraySlice rhs_contracting_dimensions = + absl::Span rhs_contracting_dimensions = AsInt64Slice(dimension_numbers.rhs_contracting_dimensions()); - tensorflow::gtl::ArraySlice lhs_batch_dimensions = + absl::Span lhs_batch_dimensions = AsInt64Slice(dimension_numbers.lhs_batch_dimensions()); - tensorflow::gtl::ArraySlice rhs_batch_dimensions = + absl::Span rhs_batch_dimensions = AsInt64Slice(dimension_numbers.rhs_batch_dimensions()); if (!dims_in_range(ShapeUtil::Rank(lhs), lhs_contracting_dimensions, @@ -577,8 +575,8 @@ Status ValidateDotDimensionNumbers( } // Check that dimension numbers are unique. 
- auto dims_unique = [](tensorflow::gtl::ArraySlice contracting_dims, - tensorflow::gtl::ArraySlice batch_dims) -> bool { + auto dims_unique = [](absl::Span contracting_dims, + absl::Span batch_dims) -> bool { tensorflow::gtl::FlatSet dim_set; auto is_unique = [&dim_set](int64 i) -> bool { return dim_set.insert(i).second; @@ -748,7 +746,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferInDimBroadcastShape( const Shape& smaller_shape, const Shape& larger_shape, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) { // Reject "magic" inference for binops on different shapes, requiring // the user to provide an explicit broadcast dimension in this case. @@ -849,7 +847,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferElementwiseBinaryOpShape( HloOpcode operation, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation")); TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation")); @@ -906,7 +904,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferBinaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { VLOG(2) << StrFormat( "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}", HloOpcodeString(opcode), ShapeUtil::HumanString(lhs), @@ -1005,8 +1003,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferVariadicOpShape( - HloOpcode opcode, - tensorflow::gtl::ArraySlice operands) { + HloOpcode opcode, absl::Span operands) { std::vector operand_shapes; operand_shapes.reserve(operands.size()); for (const HloInstruction* operand : operands) { @@ -1016,8 +1013,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferVariadicOpShape( - HloOpcode opcode, - tensorflow::gtl::ArraySlice operand_shapes) { + HloOpcode opcode, absl::Span operand_shapes) { for (const Shape* shape : operand_shapes) { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape)); } @@ -1053,9 +1049,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferMapShape( - tensorflow::gtl::ArraySlice arg_shapes, - const ProgramShape& to_apply, - tensorflow::gtl::ArraySlice dimensions) { + absl::Span arg_shapes, const ProgramShape& to_apply, + absl::Span dimensions) { if (arg_shapes.empty()) { return InvalidArgument("Map expects at least one argument."); } @@ -1711,7 +1706,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferFftShape( const Shape& in, const FftType fft_type, - const tensorflow::gtl::ArraySlice fft_length) { + const absl::Span fft_length) { const int64 fft_rank = fft_length.size(); if (fft_rank < 1 || fft_rank > 3) { return InvalidArgument("FFT only supports ranks 1-3; got %d.", fft_rank); @@ -1792,7 +1787,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr 
ShapeInference::InferCrossReplicaSumShape( - tensorflow::gtl::ArraySlice operand_shapes) { + absl::Span operand_shapes) { for (const Shape* operand_shape : operand_shapes) { TF_RETURN_IF_ERROR( ExpectArray(*operand_shape, "operand of cross replica sum")); @@ -1835,7 +1830,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferAllToAllTupleShape( - tensorflow::gtl::ArraySlice operand_shapes) { + absl::Span operand_shapes) { // An Alltoall HLO instruction receives N operands (with the same shape) and // returns a tuple that contains N array shapes. TF_RET_CHECK(!operand_shapes.empty()); @@ -1859,8 +1854,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferReduceShape( - tensorflow::gtl::ArraySlice arg_shapes, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + absl::Span arg_shapes, + absl::Span dimensions_to_reduce, const ProgramShape& to_apply) { if (arg_shapes.empty()) { return InvalidArgument("Reduce must have at least 2 arguments, has 0"); @@ -1998,9 +1993,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferSliceShape( - const Shape& arg, tensorflow::gtl::ArraySlice starts, - tensorflow::gtl::ArraySlice limits, - tensorflow::gtl::ArraySlice strides) { + const Shape& arg, absl::Span starts, + absl::Span limits, absl::Span strides) { auto error = [&](const string& message) { return InvalidArgument( "%s in slice operation; argument shape: %s; starts: {%s}; limits: " @@ -2062,7 +2056,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferDynamicSliceShape( const Shape& operand_shape, const Shape& start_indices_shape, - tensorflow::gtl::ArraySlice slice_sizes) { + absl::Span slice_sizes) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice")); TF_RETURN_IF_ERROR( ExpectArray(start_indices_shape, "start indices of dynamic slice")); @@ -2189,7 +2183,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /*static */ StatusOr ShapeInference::InferReverseShape( - const Shape& operand_shape, tensorflow::gtl::ArraySlice dimensions) { + const Shape& operand_shape, absl::Span dimensions) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse")); if (!AllUnique(dimensions)) { return InvalidArgument("a dimension number is duplicated in reverse"); @@ -2315,7 +2309,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferBroadcastShape( - const Shape& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { + const Shape& operand, absl::Span broadcast_sizes) { TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast")); for (int64 size : broadcast_sizes) { if (size < 0) { @@ -2333,8 +2327,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferReshapeShape( - const Shape& operand, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes) { + const Shape& operand, absl::Span dimensions, + absl::Span new_sizes) { TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape")); Shape inferred_shape = @@ -2366,7 +2360,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferTransposeShape( - const Shape& operand, tensorflow::gtl::ArraySlice 
dimensions) { + const Shape& operand, absl::Span dimensions) { TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose")); std::vector indices(ShapeUtil::Rank(operand)); @@ -2471,8 +2465,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferCallShape( - tensorflow::gtl::ArraySlice arg_shapes, - const ProgramShape& to_apply) { + absl::Span arg_shapes, const ProgramShape& to_apply) { // The applied function's arity equals the number of arguments. if (arg_shapes.size() != to_apply.parameters_size()) { string computation_signature = ShapeUtil::HumanString(to_apply); @@ -2505,8 +2498,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } static Status ValidateGatherDimensionNumbers( - const Shape& input_shape, - tensorflow::gtl::ArraySlice start_indices_shape, + const Shape& input_shape, absl::Span start_indices_shape, const GatherDimensionNumbers& dim_numbers) { if (!absl::c_is_sorted(dim_numbers.offset_dims())) { return InvalidArgument( @@ -2599,7 +2591,7 @@ static Status ValidateGatherDimensionNumbers( /*static*/ StatusOr ShapeInference::InferGatherShape( const Shape& input_shape, const Shape& start_indices_shape, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice slice_sizes) { + absl::Span slice_sizes) { TF_RETURN_IF_ERROR( ExpectArray(input_shape, "input tensor operand gather op")); TF_RETURN_IF_ERROR( @@ -2709,8 +2701,7 @@ static Status ValidateGatherDimensionNumbers( namespace { Status ValidateScatterDimensionNumbers( - const Shape& operand_shape, - tensorflow::gtl::ArraySlice scatter_indices_shape, + const Shape& operand_shape, absl::Span scatter_indices_shape, const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { // Validate update_window_dims in ScatterDimensionNumbers. if (!absl::c_is_sorted(dim_numbers.update_window_dims())) { diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 235b1a4cf3f..072ada2d8f7 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -55,7 +55,7 @@ class ShapeInference { // given input shapes. static StatusOr InferBinaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); static StatusOr InferBinaryOpShape(HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs); @@ -73,18 +73,15 @@ class ShapeInference { // Infers the shape produced by applying the given variadic operation to the // given input operand shapes. static StatusOr InferVariadicOpShape( - HloOpcode opcode, - tensorflow::gtl::ArraySlice operand_shapes); + HloOpcode opcode, absl::Span operand_shapes); static StatusOr InferVariadicOpShape( - HloOpcode opcode, - tensorflow::gtl::ArraySlice operands); + HloOpcode opcode, absl::Span operands); // Infers the shape produced by applying the given mapping computation shape // to the given operand shapes. static StatusOr InferMapShape( - tensorflow::gtl::ArraySlice arg_shapes, - const ProgramShape& to_apply, - tensorflow::gtl::ArraySlice dimensions); + absl::Span arg_shapes, const ProgramShape& to_apply, + absl::Span dimensions); // Infers the shape produced by InferBatchNormTraining with the given // operands. @@ -116,14 +113,13 @@ class ShapeInference { int64 feature_group_count = 1); // Infers the shape produced by the given FFT type on the given operand. 
- static StatusOr InferFftShape( - const Shape& in, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); + static StatusOr InferFftShape(const Shape& in, FftType fft_type, + absl::Span fft_length); // Infers the shape produced by a cross replica sum with the given operand // shapes. static StatusOr InferCrossReplicaSumShape( - tensorflow::gtl::ArraySlice operand_shapes); + absl::Span operand_shapes); // Infers final shape of an Alltoall operation that is created by the xla // builder. @@ -134,7 +130,7 @@ class ShapeInference { // Infers the shape of an HLO all-to-all instruction. static StatusOr InferAllToAllTupleShape( - tensorflow::gtl::ArraySlice operand_shapes); + absl::Span operand_shapes); // Infers the shape of a collective permute operation. static StatusOr InferCollectivePermuteShape(const Shape& shape); @@ -146,8 +142,8 @@ class ShapeInference { // index as the leading parameter, and the program shape should match // accordingly (or an error will result). static StatusOr InferReduceShape( - tensorflow::gtl::ArraySlice arg_shapes, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + absl::Span arg_shapes, + absl::Span dimensions_to_reduce, const ProgramShape& to_apply); // Infers the shape produced by applying the given computation to the operand @@ -165,24 +161,23 @@ class ShapeInference { // Infers the shape produced by a reverse operation that reverses the order // of the elements in the given dimensions. - static StatusOr InferReverseShape( - const Shape& operand_shape, - tensorflow::gtl::ArraySlice dimensions); + static StatusOr InferReverseShape(const Shape& operand_shape, + absl::Span dimensions); // Infers the shape produced by a slice operation spanning from the starts to // the limits in the original shape's dimensions. // // e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16] - static StatusOr InferSliceShape( - const Shape& arg, tensorflow::gtl::ArraySlice starts, - tensorflow::gtl::ArraySlice limits, - tensorflow::gtl::ArraySlice strides); + static StatusOr InferSliceShape(const Shape& arg, + absl::Span starts, + absl::Span limits, + absl::Span strides); // Infers the shape produced by a dynamic slice operation of size specified // in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'. static StatusOr InferDynamicSliceShape( const Shape& operand_shape, const Shape& start_indices_shape, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); // Infers the shape produced by a dynamic update slice operation based // on the shape of operand and update. @@ -213,30 +208,30 @@ class ShapeInference { // Infers the shape produced by a broadcast operation. static StatusOr InferBroadcastShape( - const Shape& operand, tensorflow::gtl::ArraySlice broadcast_sizes); + const Shape& operand, absl::Span broadcast_sizes); // Infers the shape produced by a reshape operation from the element type of // its operand and the new dimension sizes specified. - static StatusOr InferReshapeShape( - const Shape& operand, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); + static StatusOr InferReshapeShape(const Shape& operand, + absl::Span dimensions, + absl::Span new_sizes); // Infers the shape produced by a transpose operation from the element type of // its operand and its dimensions field. 
static StatusOr InferTransposeShape( - const Shape& operand, tensorflow::gtl::ArraySlice dimensions); + const Shape& operand, absl::Span dimensions); // Helper that infers the shape produced by performing a concatenate operation // with the given operand shapes. static StatusOr InferConcatOpShape( - tensorflow::gtl::ArraySlice arg_shapes, int64 dimension); + absl::Span arg_shapes, int64 dimension); // Infers the shape produced by a kAfterAll. Trivially this shape is always a // TOKEN shape. However, ShapeInference serves two purposes: inferring shapes // and checking operand shapes. This method verifies that the operand shapes // are all TOKENs. static StatusOr InferAfterAllShape( - tensorflow::gtl::ArraySlice arg_shapes); + absl::Span arg_shapes); // Helper that validates the given operand shape can be converted to the // target output_shape via a convert instruction -- the requirement is that @@ -266,8 +261,7 @@ class ShapeInference { // Helper that validates the given arg_shapes are compatible with the shape of // the to_apply parameters, and returns the to_apply result shape. static StatusOr InferCallShape( - tensorflow::gtl::ArraySlice arg_shapes, - const ProgramShape& to_apply); + absl::Span arg_shapes, const ProgramShape& to_apply); // Helper that infers the shape produced by performing a dot operation with // the given LHS and RHS shapes. @@ -281,7 +275,7 @@ class ShapeInference { static StatusOr InferGatherShape( const Shape& input_shape, const Shape& start_indices_shape, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); // Helper that validates the given input shape, scatter indices shape, updates // shape, and scatter dimension numbers that constitute a scatter operation, @@ -299,7 +293,7 @@ class ShapeInference { // even in the presence of broadcasting of one of the operands over the other. static StatusOr InferElementwiseBinaryOpShape( HloOpcode operation, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); // Helper for inferring the shape of Clamp ops. static StatusOr InferClampShape(const Shape& min, const Shape& operand, @@ -327,7 +321,7 @@ class ShapeInference { // smaller_shape is broadcast to. static StatusOr InferInDimBroadcastShape( const Shape& smaller_shape, const Shape& larger_shape, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference); }; diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 4ed8fc6b865..5dbe5a1611e 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -28,7 +28,6 @@ limitations under the License. namespace xla { namespace { -using ::tensorflow::gtl::ArraySlice; using ::testing::ContainsRegex; using ::testing::HasSubstr; @@ -58,9 +57,9 @@ class ReduceShapeInferenceTest : public ShapeInferenceTest { // Helper that runs reduce shape inference with the input 'arg' and given // dimensions to reduce, and checks the inferred shape is as expected. The // element type here is hard-coded to F32. 
- void ExpectInferredReduceShape( - const Shape& expected_inferred_shape, const Shape& arg, - tensorflow::gtl::ArraySlice dimensions_to_reduce) { + void ExpectInferredReduceShape(const Shape& expected_inferred_shape, + const Shape& arg, + absl::Span dimensions_to_reduce) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); auto inferred_status = ShapeInference::InferReduceShape( {&arg, &f32_}, dimensions_to_reduce, to_apply); @@ -252,7 +251,7 @@ TEST_F(ShapeInferenceTest, ClampBadShapes) { TEST_F(ShapeInferenceTest, Complex) { auto complex_shape = [&](const Shape& lhs, const Shape& rhs, - const tensorflow::gtl::ArraySlice& bcast) { + const absl::Span& bcast) { return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs, bcast); }; diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index f77690a4621..0c393c53a16 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -130,7 +130,7 @@ class TransferManager { // Resets the devices associated with this transfer manager. virtual Status ResetDevices( - tensorflow::gtl::ArraySlice executor) = 0; + absl::Span executor) = 0; // Given an allocated ShapedBuffer, constructs the tuple index table(s) in // each buffer of the given ShapedBuffer corresponding to tuple shapes. If the @@ -211,8 +211,7 @@ class TransferManager { // to construct a tuple index table in the platform-specific tuple // representation. virtual Status WriteSingleTupleIndexTable( - se::Stream* stream, - tensorflow::gtl::ArraySlice elements, + se::Stream* stream, absl::Span elements, const Shape& shape, se::DeviceMemoryBase* region) = 0; private: diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index cf00ca102b1..6fed7c76d04 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -360,7 +360,7 @@ Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) { } Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) { - tensorflow::gtl::ArraySlice operands(tuple->operands()); + absl::Span operands(tuple->operands()); PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple); points_to_set.AddPointedToBuffer( logical_buffer_analysis_->GetBuffer(tuple, /*index=*/{}), diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 10d382e8abc..a32d1f9026e 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -72,9 +72,8 @@ class TuplePointsToAnalysisTest : public HloTestBase { // Checks that the given points-to set contains exactly (unordered) the given // LogicalBuffers. - void ExpectHasBuffers( - const PointsToSet::BufferList& points_to_set, - tensorflow::gtl::ArraySlice buffers) { + void ExpectHasBuffers(const PointsToSet::BufferList& points_to_set, + absl::Span buffers) { std::vector vec(buffers.begin(), buffers.end()); EXPECT_THAT(points_to_set, UnorderedElementsAreArray(vec)); } @@ -83,7 +82,7 @@ class TuplePointsToAnalysisTest : public HloTestBase { // top-level buffers of the given instructions. 
void ExpectHasTopLevelBuffers( const PointsToSet::BufferList& points_to_set, - tensorflow::gtl::ArraySlice instructions) { + absl::Span instructions) { PointsToSet::BufferList buffers; for (auto instruction : instructions) { buffers.push_back(GetBuffer(instruction, /*index=*/{})); @@ -94,7 +93,7 @@ class TuplePointsToAnalysisTest : public HloTestBase { // Overload which takes a set instead of a vector. void ExpectHasTopLevelBuffers( const PointsToSet::BufferSet& points_to_set, - tensorflow::gtl::ArraySlice instructions) { + absl::Span instructions) { ExpectHasTopLevelBuffers( PointsToSet::BufferList(points_to_set.begin(), points_to_set.end()), instructions); @@ -104,8 +103,7 @@ class TuplePointsToAnalysisTest : public HloTestBase { // aliases which are exactly (unordered) the given instruction/index pairs. void ExpectHasBufferAliases( const HloInstruction* instruction, const ShapeIndex& index, - tensorflow::gtl::ArraySlice> - expected) { + absl::Span> expected) { const LogicalBuffer* buffer = points_to_analysis_->GetBufferDefinedAt(instruction, index) .ValueOrDie(); diff --git a/tensorflow/compiler/xla/service/tuple_util.cc b/tensorflow/compiler/xla/service/tuple_util.cc index 4a530bb0b20..9ba01ef7a60 100644 --- a/tensorflow/compiler/xla/service/tuple_util.cc +++ b/tensorflow/compiler/xla/service/tuple_util.cc @@ -40,7 +40,7 @@ namespace xla { /*static*/ HloInstruction* TupleUtil::AppendSuffix( HloInstruction* input_tuple, - tensorflow::gtl::ArraySlice trailing_values) { + absl::Span trailing_values) { CHECK(ShapeUtil::IsTuple(input_tuple->shape())); HloComputation* computation = input_tuple->parent(); diff --git a/tensorflow/compiler/xla/service/tuple_util.h b/tensorflow/compiler/xla/service/tuple_util.h index e5ff9aaa835..bc5aac09f27 100644 --- a/tensorflow/compiler/xla/service/tuple_util.h +++ b/tensorflow/compiler/xla/service/tuple_util.h @@ -38,7 +38,7 @@ class TupleUtil { // `input_tuple`. 
static HloInstruction* AppendSuffix( HloInstruction* input_tuple, - tensorflow::gtl::ArraySlice trailing_values); + absl::Span trailing_values); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc index 7e4ac92a7c5..c3c2603c7eb 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.cc +++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc @@ -211,8 +211,7 @@ optional ComputeWhileLoopTripCount(HloInstruction* while_op, VLOG(2) << "Couldn't evaluate while cond: " << result.status(); return nullopt; } - if (result.ValueOrDie()->data() == - tensorflow::gtl::ArraySlice{false}) { + if (result.ValueOrDie()->data() == absl::Span{false}) { VLOG(2) << "Loop has static trip count of " << trip_count; return trip_count; } diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index e8f76ff745a..f90ac91f9d0 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -94,7 +94,7 @@ WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) { /*static*/ StatusOr WhileUtil::MakeInstructionsLiveIn( HloInstruction* while_instr, - tensorflow::gtl::ArraySlice instructions) { + absl::Span instructions) { CHECK(ShapeUtil::IsTuple(while_instr->shape())); int64 elements_in_old_while_shape = while_instr->shape().tuple_shapes_size(); diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h index e67636d80f4..b1c4486887a 100644 --- a/tensorflow/compiler/xla/service/while_util.h +++ b/tensorflow/compiler/xla/service/while_util.h @@ -55,7 +55,7 @@ class WhileUtil { // that contains `while_instr`. static StatusOr MakeInstructionsLiveIn( HloInstruction* while_instr, - tensorflow::gtl::ArraySlice instructions); + absl::Span instructions); using LoopStateTy = std::vector; using LoopBodyGeneratorTy = std::function( diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 6d016abfde5..9772c06bce3 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -139,8 +139,8 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts, // Constructs and returns the new shape with the given minor_to_major order in // its Layout. 
StatusOr MakeShapeWithLayoutInternal( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice minor_to_major) { + PrimitiveType element_type, absl::Span dimensions, + absl::Span minor_to_major) { if (dimensions.size() != minor_to_major.size()) { return InvalidArgument("Dimensions size is %ld, but layout size is %ld.", dimensions.size(), minor_to_major.size()); @@ -214,8 +214,8 @@ StatusOr MakeShapeWithLayoutInternal( return program_shape; } -/* static */ Shape ShapeUtil::MakeShape( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions) { +/* static */ Shape ShapeUtil::MakeShape(PrimitiveType element_type, + absl::Span dimensions) { CHECK(IsArrayPrimitiveType(element_type)); Shape result; PopulateShape(element_type, dimensions, &result); @@ -223,21 +223,21 @@ StatusOr MakeShapeWithLayoutInternal( } /* static */ Shape ShapeUtil::MakeShapeWithLayout( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice minor_to_major) { + PrimitiveType element_type, absl::Span dimensions, + absl::Span minor_to_major) { return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major) .ValueOrDie(); } /* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions) { + PrimitiveType element_type, absl::Span dimensions) { std::vector layout(dimensions.size()); std::iota(layout.rbegin(), layout.rend(), static_cast(0)); return MakeShapeWithLayout(element_type, dimensions, layout); } /* static */ Shape ShapeUtil::MakeShapeWithSparseLayout( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, + PrimitiveType element_type, absl::Span dimensions, int64 max_sparse_elements) { CHECK(IsArrayPrimitiveType(element_type)); Shape shape = ShapeUtil::MakeShape(element_type, dimensions); @@ -256,9 +256,9 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return MakeShapeWithDescendingLayout(shape.element_type(), dims); } -/* static */ void ShapeUtil::PopulateShape( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, - Shape* shape) { +/* static */ void ShapeUtil::PopulateShape(PrimitiveType element_type, + absl::Span dimensions, + Shape* shape) { shape->Clear(); shape->set_element_type(element_type); for (int64 dimension : dimensions) { @@ -268,8 +268,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( TF_DCHECK_OK(ValidateShape(*shape)); } -/* static */ Shape ShapeUtil::MakeTupleShape( - tensorflow::gtl::ArraySlice shapes) { +/* static */ Shape ShapeUtil::MakeTupleShape(absl::Span shapes) { Shape result; result.set_element_type(TUPLE); result.mutable_tuple_shapes()->Reserve(shapes.size()); @@ -791,7 +790,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout()); } else { CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString(); - tensorflow::gtl::ArraySlice padded_dimensions = + absl::Span padded_dimensions = LayoutUtil::PaddedDimensions(shape); if (!padded_dimensions.empty()) { CHECK_EQ(Rank(shape), padded_dimensions.size()); @@ -1115,7 +1114,7 @@ Status ForEachMutableSubshapeHelper( } /* static */ Shape ShapeUtil::PermuteDimensions( - tensorflow::gtl::ArraySlice permutation, const Shape& shape) { + absl::Span permutation, const Shape& shape) { Shape new_shape = shape; new_shape.clear_dimensions(); for (auto dim : Permute(permutation, shape.dimensions())) { @@ -1259,7 +1258,7 @@ 
ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ bool ShapeUtil::TransposeIsBitcast( const Shape& input_shape, const Shape& output_shape, - tensorflow::gtl::ArraySlice dimension_mapping) { + absl::Span dimension_mapping) { CHECK(LayoutUtil::HasLayout(input_shape) && LayoutUtil::HasLayout(output_shape)); diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 83e58545bf9..8e2b2cb3318 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -147,7 +147,7 @@ class ShapeIndexView { string ToString() const; private: - tensorflow::gtl::ArraySlice indices_; + absl::Span indices_; }; std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index); @@ -328,7 +328,7 @@ class ShapeUtil { static Shape ChangeElementType(const Shape& original, PrimitiveType type); // Creates a tuple shape from a slice of element shapes within the tuple. - static Shape MakeTupleShape(tensorflow::gtl::ArraySlice shapes); + static Shape MakeTupleShape(absl::Span shapes); // Creates an opaque shape. These are generally used for threading a context // into a custom operation. @@ -355,31 +355,29 @@ class ShapeUtil { // Constructs a new shape with the given element type and sequence of // dimensions. static Shape MakeShape(PrimitiveType element_type, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); // Creates a Shape with element type corresponding to T and the given // dimensions template - static Shape MakeShapeWithType( - tensorflow::gtl::ArraySlice dimensions) { + static Shape MakeShapeWithType(absl::Span dimensions) { return ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), dimensions); } // Constructs a new shape with the given minor_to_major order in its Layout. // Returns a value shape such that shape.has_layout(). - static Shape MakeShapeWithLayout( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice minor_to_major); + static Shape MakeShapeWithLayout(PrimitiveType element_type, + absl::Span dimensions, + absl::Span minor_to_major); - static Shape MakeShapeWithSparseLayout( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, - int64 max_sparse_elements); + static Shape MakeShapeWithSparseLayout(PrimitiveType element_type, + absl::Span dimensions, + int64 max_sparse_elements); // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}). static Shape MakeShapeWithDescendingLayout( - PrimitiveType element_type, - tensorflow::gtl::ArraySlice dimensions); + PrimitiveType element_type, absl::Span dimensions); // Returns a new Shape based on the given Shape with low-dimension-major // layout (i.e. {n, n-1, ..., 0}, like Fortran), and with the dimensions @@ -391,8 +389,7 @@ class ShapeUtil { // As MakeShape, but the object to write to is passed in. static void PopulateShape(PrimitiveType element_type, - tensorflow::gtl::ArraySlice dimensions, - Shape* shape); + absl::Span dimensions, Shape* shape); // Validates that the provided shape satisfies invariants. static Status ValidateShape(const Shape& shape); @@ -539,7 +536,7 @@ class ShapeUtil { // !HasLayout(shape) || // TransposeIsBitcast(shape, PermuteDimensions(permutation, shape), // InversePermutation(permutation)). 
- static Shape PermuteDimensions(tensorflow::gtl::ArraySlice permutation, + static Shape PermuteDimensions(absl::Span permutation, const Shape& shape); // If we can go from `shape_pre` to `shape_post` by merely inserting or @@ -580,9 +577,9 @@ class ShapeUtil { // to its input and thus may be replaced with a bitcast. // // Precondition: Both input_shape and output_shape have explicit layouts. - static bool TransposeIsBitcast( - const Shape& input_shape, const Shape& output_shape, - tensorflow::gtl::ArraySlice dimension_mapping); + static bool TransposeIsBitcast(const Shape& input_shape, + const Shape& output_shape, + absl::Span dimension_mapping); // Returns whether a reshape from "input_shape" to "output_shape" is a // bitcast. @@ -621,12 +618,12 @@ class ShapeUtil { // continue, or false otherwise. // // visitor_function must be a callable of type - // StatusOr(ArraySlice) or compatible. + // StatusOr(Span) or compatible. template static Status ForEachIndexWithStatus(const Shape& shape, - tensorflow::gtl::ArraySlice base, - tensorflow::gtl::ArraySlice count, - tensorflow::gtl::ArraySlice incr, + absl::Span base, + absl::Span count, + absl::Span incr, const FnType& visitor_function) { return ForEachIndexInternal(shape, base, count, incr, visitor_function); } @@ -648,13 +645,12 @@ class ShapeUtil { } template - static void ForEachIndex(const Shape& shape, - tensorflow::gtl::ArraySlice base, - tensorflow::gtl::ArraySlice count, - tensorflow::gtl::ArraySlice incr, + static void ForEachIndex(const Shape& shape, absl::Span base, + absl::Span count, + absl::Span incr, const FnType& visitor_function) { ForEachIndexWithStatus(shape, base, count, incr, - [&](tensorflow::gtl::ArraySlice indices) { + [&](absl::Span indices) { return StatusOr(visitor_function(indices)); }) .IgnoreError(); @@ -676,7 +672,7 @@ class ShapeUtil { template static void ForEachIndex(const Shape& shape, const FnType& visitor_function) { ForEachIndexWithStatus(shape, - [&](tensorflow::gtl::ArraySlice indices) { + [&](absl::Span indices) { return StatusOr(visitor_function(indices)); }) .IgnoreError(); @@ -687,18 +683,18 @@ class ShapeUtil { // matter. // // visitor_function must be a callable of type - // void(ArraySlice) or compatible. + // void(Span) or compatible. template static void ForEachIndexParallel(const Shape& shape, - tensorflow::gtl::ArraySlice base, - tensorflow::gtl::ArraySlice count, - tensorflow::gtl::ArraySlice incr, + absl::Span base, + absl::Span count, + absl::Span incr, const FnType& visitor_function) { // The parallel version of ForEachIndexInternal can never fail. 
CHECK(ForEachIndexInternal( shape, base, count, incr, - [&visitor_function](tensorflow::gtl::ArraySlice indexes) - -> StatusOr { + [&visitor_function]( + absl::Span indexes) -> StatusOr { visitor_function(indexes); return true; }, @@ -720,9 +716,9 @@ class ShapeUtil { template static Status ForEachIndexInternal(const Shape& shape, - tensorflow::gtl::ArraySlice base, - tensorflow::gtl::ArraySlice count, - tensorflow::gtl::ArraySlice incr, + absl::Span base, + absl::Span count, + absl::Span incr, const FnType& visitor_function, bool parallel = false) { if (ShapeUtil::IsZeroElementArray(shape)) { diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 7549ba9c780..6ca4085aaf3 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -705,11 +705,10 @@ TEST(ShapeUtilTest, ForEachIndex) { Shape shape = ShapeUtil::MakeShape(F32, data.dimensions); // Increments at every invocation. int invocations = 0; - auto increment_func = - [&invocations](tensorflow::gtl::ArraySlice indexes) { - invocations++; - return true; - }; + auto increment_func = [&invocations](absl::Span indexes) { + invocations++; + return true; + }; std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); @@ -726,8 +725,7 @@ TEST(ShapeUtilTest, ForEachIndexWithStatus) { // Increments at every invocation. int invocations = 0; auto increment_func = - [&invocations]( - tensorflow::gtl::ArraySlice indexes) -> StatusOr { + [&invocations](absl::Span indexes) -> StatusOr { if (++invocations == 5) { return Unimplemented("Cannot increment beyond 5."); } @@ -748,7 +746,7 @@ TEST(ShapeUtilTest, ForEachIndexParallel) { Shape shape = ShapeUtil::MakeShape(F32, {10, 10}); int64 output[10][10]; int init = 5; - auto set_func = [&](tensorflow::gtl::ArraySlice indexes) { + auto set_func = [&](absl::Span indexes) { output[indexes[0]][indexes[1]] = init + indexes[0] + indexes[1]; }; diff --git a/tensorflow/compiler/xla/sparse_index_array.cc b/tensorflow/compiler/xla/sparse_index_array.cc index 31844abd89a..1c135dda864 100644 --- a/tensorflow/compiler/xla/sparse_index_array.cc +++ b/tensorflow/compiler/xla/sparse_index_array.cc @@ -33,7 +33,7 @@ SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank, } SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank, - tensorflow::gtl::ArraySlice indices) + absl::Span indices) : SparseIndexArray(max_indices, rank, std::vector(indices.begin(), indices.end())) {} @@ -48,25 +48,24 @@ int64 SparseIndexArray::index_count() const { return indices_.size() / rank_; } -tensorflow::gtl::ArraySlice SparseIndexArray::At( +absl::Span SparseIndexArray::At( int64 sparse_element_number) const { CHECK_GT(rank_, 0); CHECK_GE(sparse_element_number, 0); CHECK_LE(rank_ * sparse_element_number + rank_, indices_.size()); - return tensorflow::gtl::ArraySlice( + return absl::Span( indices_.data() + rank_ * sparse_element_number, rank_); } -tensorflow::gtl::MutableArraySlice SparseIndexArray::At( - int64 sparse_element_number) { +absl::Span SparseIndexArray::At(int64 sparse_element_number) { CHECK_GT(rank_, 0); CHECK_GE(sparse_element_number, 0); CHECK_LE(rank_ * sparse_element_number + rank_, indices_.size()); - return tensorflow::gtl::MutableArraySlice( - indices_.data() + rank_ * sparse_element_number, rank_); + return absl::Span(indices_.data() + rank_ * sparse_element_number, + rank_); } -void SparseIndexArray::Append(tensorflow::gtl::ArraySlice index) { +void 
SparseIndexArray::Append(absl::Span index) { CHECK_GT(rank_, 0); CHECK_EQ(index.size(), rank_); indices_.insert(indices_.end(), index.begin(), index.end()); @@ -90,12 +89,12 @@ bool SparseIndexArray::Validate(const Shape& shape) const { if (num_indices < 2) { return true; } - tensorflow::gtl::ArraySlice last = At(0); + absl::Span last = At(0); if (!IndexUtil::IndexInBounds(shape, last)) { return false; } for (int64 n = 1; n < num_indices; ++n) { - tensorflow::gtl::ArraySlice next = At(n); + absl::Span next = At(n); if (!IndexUtil::IndexInBounds(shape, next)) { return false; } diff --git a/tensorflow/compiler/xla/sparse_index_array.h b/tensorflow/compiler/xla/sparse_index_array.h index 7291705b61b..6c70fd0a06a 100644 --- a/tensorflow/compiler/xla/sparse_index_array.h +++ b/tensorflow/compiler/xla/sparse_index_array.h @@ -65,7 +65,7 @@ class SparseIndexArray { SparseIndexArray(int64 max_indices, int64 rank, std::vector indices = {}); SparseIndexArray(int64 max_indices, int64 rank, - tensorflow::gtl::ArraySlice indices); + absl::Span indices); // Returns the number of elements represented by the indices stored in the // array. @@ -73,12 +73,12 @@ class SparseIndexArray { // Returns a slice that refers to the given sparse index number. The argument // must be in the range [0, element_count()). - tensorflow::gtl::ArraySlice At(int64 sparse_element_number) const; - tensorflow::gtl::MutableArraySlice At(int64 sparse_element_number); + absl::Span At(int64 sparse_element_number) const; + absl::Span At(int64 sparse_element_number); // Adds the given index at the end of the array. The new size of the // SparseIndexArray must not exceed `max_indices`. - void Append(tensorflow::gtl::ArraySlice index); + void Append(absl::Span index); // Removes all indices from the array. void Clear(); @@ -96,10 +96,8 @@ class SparseIndexArray { int64 max_indices() const { return max_indices_; } // Returns a pointer to the int64 array that holds the sparse indices. - tensorflow::gtl::MutableArraySlice mutable_data() { - return absl::MakeSpan(indices_); - } - tensorflow::gtl::ArraySlice data() const { return indices_; } + absl::Span mutable_data() { return absl::MakeSpan(indices_); } + absl::Span data() const { return indices_; } // Sorts this sparse index array along with the set of corresponding values. // The indices and values are sorted in the lexicographic order of the @@ -117,7 +115,7 @@ class SparseIndexArray { // std::cout << v[0] << ", " << v[1] << ", " << v[2] << std::endl; // template - void SortWithValues(tensorflow::gtl::MutableArraySlice values); + void SortWithValues(absl::Span values); private: std::vector indices_; @@ -126,8 +124,7 @@ class SparseIndexArray { }; template -void SparseIndexArray::SortWithValues( - tensorflow::gtl::MutableArraySlice values) { +void SparseIndexArray::SortWithValues(absl::Span values) { int64 num_elements = index_count(); CHECK_EQ(values.size(), num_elements); std::vector sort_order; diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 577fd1ab3b9..55dcf2817b6 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -41,7 +41,6 @@ limitations under the License. 
namespace xla { namespace { -using tensorflow::gtl::ArraySlice; class ArrayElementwiseOpTest : public ClientLibraryTestBase { public: @@ -433,8 +432,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) { class IntegerDivideOpTest : public ArrayElementwiseOpTest { protected: template - void TestDivRem(ArraySlice dividends, ArraySlice divisors, - ArraySlice quotients, ArraySlice remainders) { + void TestDivRem(absl::Span dividends, absl::Span divisors, + absl::Span quotients, + absl::Span remainders) { { XlaBuilder builder(TestName()); XlaOp dividend; diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 1d28e85b165..fe4267c73bd 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -53,10 +53,11 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { } } - std::unique_ptr MakeR3Data( - tensorflow::gtl::ArraySlice bounds, - tensorflow::gtl::ArraySlice minor_to_major, Shape* r3_shape, - Array3D* r3_array, float start, float end, int seed) { + std::unique_ptr MakeR3Data(absl::Span bounds, + absl::Span minor_to_major, + Shape* r3_shape, + Array3D* r3_array, float start, + float end, int seed) { *r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r3_array->FillRandom(start, end, seed); auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array)->Relayout( @@ -66,10 +67,11 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { return r3_global_data; } - std::unique_ptr MakeR2Data( - tensorflow::gtl::ArraySlice bounds, - tensorflow::gtl::ArraySlice minor_to_major, Shape* r2_shape, - Array2D* r2_array, float start, float end, int seed) { + std::unique_ptr MakeR2Data(absl::Span bounds, + absl::Span minor_to_major, + Shape* r2_shape, + Array2D* r2_array, float start, + float end, int seed) { *r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r2_array->FillRandom(start, end, seed); auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array)->Relayout( @@ -348,7 +350,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { Array3D expected_array(spec.output_bounds[0], spec.output_bounds[1], spec.output_bounds[2]); - auto Each = ([&](tensorflow::gtl::ArraySlice indices, float* value) { + auto Each = ([&](absl::Span indices, float* value) { float r3_implicit = r3_implicit_array(indices[0] % spec.input_bounds[0], indices[1] % spec.input_bounds[1], indices[2] % spec.input_bounds[2]); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index f1ab83df821..8a236db0ff2 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -95,15 +95,14 @@ string ClientLibraryTestBase::TestName() const { } StatusOr> ClientLibraryTestBase::Execute( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::Span arguments) { // Build the computation, as a convenience. 
TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); return client_->Execute(computation, arguments, &execution_options_); } StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( - const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + const XlaComputation& computation, absl::Span arguments, const Shape* shape_with_output_layout) { ExecutionOptions execution_options = execution_options_; if (shape_with_output_layout != nullptr) { @@ -115,7 +114,7 @@ StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( } StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments, + XlaBuilder* builder, absl::Span arguments, const Shape* shape_with_output_layout) { // Build the computation, as a convenience. TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); @@ -124,8 +123,7 @@ StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( StatusOr> ClientLibraryTestBase::ExecuteAndTransferReference( - const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + const XlaComputation& computation, absl::Span arguments, const Shape* shape_with_output_layout) { ExecutionOptions execution_options = execution_options_; if (shape_with_output_layout != nullptr) { @@ -138,7 +136,7 @@ ClientLibraryTestBase::ExecuteAndTransferReference( } string ClientLibraryTestBase::ExecuteToString( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::Span arguments) { auto computation_status = builder->Build(); if (!computation_status.ok()) { return computation_status.status().ToString(); @@ -156,7 +154,7 @@ string ClientLibraryTestBase::ExecuteToString( void ClientLibraryTestBase::ComputeAndCompareR1( XlaBuilder* builder, const tensorflow::core::Bitmap& expected, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { std::unique_ptr expected_literal = LiteralUtil::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); @@ -164,15 +162,14 @@ void ClientLibraryTestBase::ComputeAndCompareR1( void ClientLibraryTestBase::ComputeAndCompareLiteral( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, - const Shape* shape_with_layout) { + absl::Span arguments, const Shape* shape_with_layout) { EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, shape_with_layout)); } void ClientLibraryTestBase::ComputeAndCompareLiteral( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error, + absl::Span arguments, ErrorSpec error, const Shape* shape_with_layout) { EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, error, shape_with_layout)); @@ -180,7 +177,7 @@ void ClientLibraryTestBase::ComputeAndCompareLiteral( Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( const xla::XlaComputation& computation, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const std::function& verify_output) { // Try with no layout requirement. 
@@ -205,7 +202,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( const xla::XlaComputation& computation, const Literal& /*expected*/, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const std::function& verify_output, const Shape* output_with_layout) { @@ -252,10 +249,9 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( // Every argument has an assigned layout. TF_ASSIGN_OR_RETURN( auto actual, - ExecuteAndTransfer( - computation, - tensorflow::gtl::ArraySlice(arguments_with_layout), - output_with_layout)); + ExecuteAndTransfer(computation, + absl::Span(arguments_with_layout), + output_with_layout)); string error_message = "Test with input layouts: "; for (const auto& str : layout_strings) { absl::StrAppend(&error_message, str, " "); @@ -269,7 +265,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments_passed_in, + absl::Span arguments_passed_in, const Shape* shape_with_layout) { std::vector arguments(arguments_passed_in.begin(), arguments_passed_in.end()); @@ -329,8 +325,8 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments_passed_in, - ErrorSpec error, const Shape* shape_with_layout) { + absl::Span arguments_passed_in, ErrorSpec error, + const Shape* shape_with_layout) { std::vector arguments(arguments_passed_in.begin(), arguments_passed_in.end()); @@ -386,7 +382,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( void ClientLibraryTestBase::ComputeAndCompareR1U8( XlaBuilder* builder, absl::string_view expected, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { auto actual_status = ExecuteAndTransfer(builder, arguments); EXPECT_IS_OK(actual_status.status()); if (!actual_status.ok()) { @@ -405,7 +401,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8( void ClientLibraryTestBase::ComputeAndCompareTuple( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { auto actual_status = ExecuteAndTransfer(builder, arguments); EXPECT_IS_OK(actual_status.status()); if (!actual_status.ok()) { @@ -417,7 +413,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( void ClientLibraryTestBase::ComputeAndCompareTuple( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + absl::Span arguments, ErrorSpec error) { auto actual_status = ExecuteAndTransfer(builder, arguments); EXPECT_IS_OK(actual_status.status()); if (!actual_status.ok()) { @@ -428,7 +424,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( } void ClientLibraryTestBase::ComputeAndCompare( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::Span arguments) { auto status_or_data = ComputeValueAndReference(builder, arguments); EXPECT_IS_OK(status_or_data); if (!status_or_data.ok()) { @@ -440,8 +436,7 @@ void ClientLibraryTestBase::ComputeAndCompare( } void ClientLibraryTestBase::ComputeAndCompare( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments, - ErrorSpec error) { + XlaBuilder* builder, absl::Span arguments, ErrorSpec error) { auto status_or_data = 
ComputeValueAndReference(builder, arguments); EXPECT_IS_OK(status_or_data); if (!status_or_data.ok()) { @@ -454,7 +449,7 @@ void ClientLibraryTestBase::ComputeAndCompare( StatusOr, std::unique_ptr>> ClientLibraryTestBase::ComputeValueAndReference( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::Span arguments) { // Transfer the arguments to the executor service. We put the unique_ptr's // into a vector to keep the data alive on the service until the end of this // function. diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index ac96d3e325b..77954dd7f75 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -49,8 +49,8 @@ namespace xla { // use_bfloat16_params with that value. Returns the result. template std::vector ExpandUseBfloat16( - tensorflow::gtl::ArraySlice use_bfloat16_params, - tensorflow::gtl::ArraySlice specs) { + absl::Span use_bfloat16_params, + absl::Span specs) { std::vector expanded; for (bool use_bfloat16 : use_bfloat16_params) { for (const auto& spec : specs) { @@ -93,15 +93,15 @@ class ClientLibraryTestBase : public ::testing::Test { // execution options. Modify execution_options_ in your test if you want to // customize the options. StatusOr> Execute( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments); + XlaBuilder* builder, absl::Span arguments); StatusOr> ExecuteAndTransfer( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments, + XlaBuilder* builder, absl::Span arguments, const Shape* shape_with_output_layout = nullptr); StatusOr> ExecuteAndTransfer( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const Shape* shape_with_output_layout = nullptr); // This executes the computation via the reference client (which connects a @@ -109,13 +109,13 @@ class ClientLibraryTestBase : public ::testing::Test { // computation. StatusOr> ExecuteAndTransferReference( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const Shape* shape_with_output_layout = nullptr); // Run a computation and return its value as a string. If an error // occurs, then instead return the error as a string. string ExecuteToString(XlaBuilder* builder, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); // Convenience methods for building and running a computation, transferring // the result, and comparing it to the expected value(s). Methods are @@ -125,102 +125,98 @@ class ClientLibraryTestBase : public ::testing::Test { // for integral types without the ErrorSpec parameter. template void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); template void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, ErrorSpec error); template void ComputeAndCompareR1(XlaBuilder* builder, - tensorflow::gtl::ArraySlice expected, - tensorflow::gtl::ArraySlice arguments); + absl::Span expected, + absl::Span arguments); template void ComputeAndCompareR1(XlaBuilder* builder, - tensorflow::gtl::ArraySlice expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span expected, + absl::Span arguments, ErrorSpec error); // As above, but uses a bitmap to hold the predicate vector to avoid // deficiencies of vector. 
void ComputeAndCompareR1(XlaBuilder* builder, const tensorflow::core::Bitmap& expected, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); template void ComputeAndCompareR2(XlaBuilder* builder, const Array2D& expected, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); template void ComputeAndCompareR2(XlaBuilder* builder, const Array2D& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, ErrorSpec error); template void ComputeAndCompareR3(XlaBuilder* builder, const Array3D& expected, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); template void ComputeAndCompareR3(XlaBuilder* builder, const Array3D& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, ErrorSpec error); template void ComputeAndCompareR4(XlaBuilder* builder, const Array4D& expected, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); template void ComputeAndCompareR4(XlaBuilder* builder, const Array4D& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, ErrorSpec error); // Build and run the computation and compare the result with the given // literal. shape_with_layout indicates the result layout to request when // calling Execute. - void ComputeAndCompareLiteral( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, - const Shape* shape_with_layout = nullptr); - void ComputeAndCompareLiteral( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error, - const Shape* shape_with_layout = nullptr); + void ComputeAndCompareLiteral(XlaBuilder* builder, const Literal& expected, + absl::Span arguments, + const Shape* shape_with_layout = nullptr); + void ComputeAndCompareLiteral(XlaBuilder* builder, const Literal& expected, + absl::Span arguments, + ErrorSpec error, + const Shape* shape_with_layout = nullptr); // ComputeAndCompare variant which returns an error status. Status ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const Shape* shape_with_layout = nullptr); Status ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error, + absl::Span arguments, ErrorSpec error, const Shape* shape_with_layout = nullptr); // Compare the result of the computation to a strings. In XLA strings are // represented using rank-1 U8 shapes. - void ComputeAndCompareR1U8( - XlaBuilder* builder, absl::string_view expected, - tensorflow::gtl::ArraySlice arguments); + void ComputeAndCompareR1U8(XlaBuilder* builder, absl::string_view expected, + absl::Span arguments); // Convenience method for running a built computation, transferring the // result, and comparing it to the expected tuple literal. - void ComputeAndCompareTuple( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments); - void ComputeAndCompareTuple( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error); + void ComputeAndCompareTuple(XlaBuilder* builder, const Literal& expected, + absl::Span arguments); + void ComputeAndCompareTuple(XlaBuilder* builder, const Literal& expected, + absl::Span arguments, + ErrorSpec error); // Convenience method for running a built computation and comparing the result // with the reference result. 
void ComputeAndCompare(XlaBuilder* builder, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); void ComputeAndCompare(XlaBuilder* builder, - tensorflow::gtl::ArraySlice arguments, - ErrorSpec error); + absl::Span arguments, ErrorSpec error); // Create scalar operations for use in reductions. XlaComputation CreateScalarRelu(); @@ -337,7 +333,7 @@ class ClientLibraryTestBase : public ::testing::Test { // converted to bfloat16. template std::unique_ptr CreateR1Parameter( - tensorflow::gtl::ArraySlice values, int64 parameter_number, + absl::Span values, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle); // Creates a parameter instruction that wraps the given constant array @@ -381,7 +377,7 @@ class ClientLibraryTestBase : public ::testing::Test { // actual). StatusOr, std::unique_ptr>> ComputeValueAndReference(XlaBuilder* builder, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); Client* client_; Client* ref_client_; // To compute reference result. @@ -390,12 +386,12 @@ class ClientLibraryTestBase : public ::testing::Test { private: Status ComputeAndCompareLiteralWithAllOutputLayouts( const xla::XlaComputation& computation, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const std::function& verify_output); Status ComputeAndCompareLiteralWithAllInputLayouts( const xla::XlaComputation& computation, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const std::function& verify_output, const Shape* output_with_layout = nullptr); @@ -415,7 +411,7 @@ class ClientLibraryTestBase : public ::testing::Test { template void ClientLibraryTestBase::ComputeAndCompareR0( XlaBuilder* builder, NativeT expected, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { std::unique_ptr expected_literal = LiteralUtil::CreateR0(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, @@ -425,7 +421,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0( template void ClientLibraryTestBase::ComputeAndCompareR0( XlaBuilder* builder, NativeT expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + absl::Span arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || @@ -440,8 +436,8 @@ void ClientLibraryTestBase::ComputeAndCompareR0( template void ClientLibraryTestBase::ComputeAndCompareR1( - XlaBuilder* builder, tensorflow::gtl::ArraySlice expected, - tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::Span expected, + absl::Span arguments) { std::unique_ptr expected_literal = LiteralUtil::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, @@ -450,8 +446,8 @@ void ClientLibraryTestBase::ComputeAndCompareR1( template void ClientLibraryTestBase::ComputeAndCompareR1( - XlaBuilder* builder, tensorflow::gtl::ArraySlice expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + XlaBuilder* builder, absl::Span expected, + absl::Span arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || @@ -467,7 +463,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1( template void ClientLibraryTestBase::ComputeAndCompareR2( XlaBuilder* builder, const Array2D& expected, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { std::unique_ptr expected_literal = LiteralUtil::CreateR2FromArray2D(expected); 
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, @@ -477,7 +473,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2( template void ClientLibraryTestBase::ComputeAndCompareR2( XlaBuilder* builder, const Array2D& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + absl::Span arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || @@ -493,7 +489,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2( template void ClientLibraryTestBase::ComputeAndCompareR3( XlaBuilder* builder, const Array3D& expected, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { std::unique_ptr expected_literal = LiteralUtil::CreateR3FromArray3D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, @@ -503,7 +499,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3( template void ClientLibraryTestBase::ComputeAndCompareR3( XlaBuilder* builder, const Array3D& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + absl::Span arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || @@ -519,7 +515,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3( template void ClientLibraryTestBase::ComputeAndCompareR4( XlaBuilder* builder, const Array4D& expected, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { std::unique_ptr expected_literal = LiteralUtil::CreateR4FromArray4D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, @@ -529,7 +525,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4( template void ClientLibraryTestBase::ComputeAndCompareR4( XlaBuilder* builder, const Array4D& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + absl::Span arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || @@ -558,7 +554,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( template std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( - tensorflow::gtl::ArraySlice values, int64 parameter_number, + absl::Span values, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { std::unique_ptr literal = LiteralUtil::CreateR1(values); if (use_bfloat16_ && literal->shape().element_type() == F32) { diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc index 7c52c9fbbb5..25f6bfd8430 100644 --- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc +++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc @@ -38,10 +38,9 @@ namespace { class CompilationCacheTest : public ClientLibraryTestBase { public: - void ExecuteComputationR0F32( - const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, float expected_result, - bool expect_cache_hit) { + void ExecuteComputationR0F32(const XlaComputation& computation, + absl::Span arguments, + float expected_result, bool expect_cache_hit) { ExecutionProfile execution_profile; std::unique_ptr result = client_ @@ -56,7 +55,7 @@ class CompilationCacheTest : public ClientLibraryTestBase { void ExecuteComputationR2F32( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, std::initializer_list> expected_result, bool expect_cache_hit) { ExecutionProfile execution_profile; diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 
50a9ebc1e99..526626c1ddd 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -54,7 +54,7 @@ class CopyOpTest : public HloTestBase { void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3); void TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, size_t n4, - tensorflow::gtl::ArraySlice permutation); + absl::Span permutation); }; XLA_TEST_F(CopyOpTest, CopyR0Bool) { @@ -187,9 +187,9 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) { LiteralTestUtil::ExpectR3EqualArray3D(a, *result); } -void CopyOpTest::TestCopyConstantLayoutR4( - size_t n1, size_t n2, size_t n3, size_t n4, - tensorflow::gtl::ArraySlice permutation) { +void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, + size_t n4, + absl::Span permutation) { Array4D a(n1, n2, n3, n4); for (size_t i = 0; i < n1; ++i) { for (size_t j = 0; j < n2; ++j) { diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc index 5f234f36a85..1ea94168812 100644 --- a/tensorflow/compiler/xla/tests/deallocation_test.cc +++ b/tensorflow/compiler/xla/tests/deallocation_test.cc @@ -36,7 +36,7 @@ class DeallocationTest : public ClientLibraryTestBase { // Build and execute the given computation then verify the results can be // transferred from the device successfully. std::unique_ptr ExecuteAndCheckTransfer( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::Span arguments) { XlaComputation computation = builder->Build().ConsumeValueOrDie(); auto global_data = client_->Execute(computation, arguments, &execution_options_) diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc index 2db6503afab..e1435cf8abd 100644 --- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -42,7 +42,7 @@ class DeconstructTupleTest : public ClientLibraryTestBase { // Build and execute the given computation then verify the results can be // transferred from the device successfully. std::unique_ptr ExecuteAndCheckTransfer( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::Span arguments) { XlaComputation computation = builder->Build().ConsumeValueOrDie(); auto global_data = client_->Execute(computation, arguments, &execution_options_) diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 7f6f203a1ba..9bf3767ca3e 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -114,14 +114,14 @@ class DynamicSliceTest : public ClientLibraryTestBase { } template - void RunR1(tensorflow::gtl::ArraySlice input_values_int, + void RunR1(absl::Span input_values_int, const std::vector slice_starts, const std::vector& slice_sizes, - tensorflow::gtl::ArraySlice expected_values_int) { + absl::Span expected_values_int) { // bfloat16 has explicit constructors, so it does not implicitly convert the // way built-in types do, which is why we can't take the parameter as an - // ArraySlice. We also can't convert it to a vector, because - // vector is special so that it cannot be an ArraySlice, which + // Span. We also can't convert it to a vector, because + // vector is special so that it cannot be a Span, which // is what the code below wants. So instead we do this. 
Literal input_values = std::move(*LiteralUtil::CreateR1(input_values_int) @@ -385,10 +385,10 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { } template - void RunR1(tensorflow::gtl::ArraySlice input_values_int, - tensorflow::gtl::ArraySlice update_values_int, + void RunR1(absl::Span input_values_int, + absl::Span update_values_int, const std::vector slice_starts, - tensorflow::gtl::ArraySlice expected_values_int) { + absl::Span expected_values_int) { Literal input_values = std::move(*LiteralUtil::CreateR1(input_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) diff --git a/tensorflow/compiler/xla/tests/floor_ceil_test.cc b/tensorflow/compiler/xla/tests/floor_ceil_test.cc index 4a835a8e219..313d10566ef 100644 --- a/tensorflow/compiler/xla/tests/floor_ceil_test.cc +++ b/tensorflow/compiler/xla/tests/floor_ceil_test.cc @@ -37,8 +37,8 @@ class FloorCeilTest : public ClientLibraryTestBase { }; // Runs a computation and comparison on expected vs f(input) - void TestR1F32(tensorflow::gtl::ArraySlice input, - tensorflow::gtl::ArraySlice expected, Function f) { + void TestR1F32(absl::Span input, + absl::Span expected, Function f) { LOG(INFO) << "input: {" << absl::StrJoin(expected, ", ") << "}"; XlaBuilder builder(TestName()); auto c = ConstantR1(&builder, input); diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 15a9d55bfe3..cbfe147953b 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -48,7 +48,6 @@ limitations under the License. #include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/types.h" -using tensorflow::gtl::ArraySlice; namespace xla { namespace { @@ -113,7 +112,7 @@ class FusionTest : public HloTestBase { hlos[0] = builder.AddInstruction(std::move(root_hlo)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction( - ArraySlice(hlos).subspan(0, Arity + 1), + absl::Span(hlos).subspan(0, Arity + 1), HloInstruction::FusionKind::kLoop); auto expected = LiteralUtil::CreateR2FromArray2D(answer_data); @@ -127,12 +126,12 @@ class FusionTest : public HloTestBase { private: template - T ComputeElementwiseAnswer(HloOpcode opcode, ArraySlice xs); + T ComputeElementwiseAnswer(HloOpcode opcode, absl::Span xs); }; template <> float FusionTest::ComputeElementwiseAnswer(HloOpcode opcode, - ArraySlice xs) { + absl::Span xs) { switch (opcode) { case HloOpcode::kAdd: return xs[0] + xs[1]; @@ -157,7 +156,7 @@ float FusionTest::ComputeElementwiseAnswer(HloOpcode opcode, template <> bool FusionTest::ComputeElementwiseAnswer(HloOpcode opcode, - ArraySlice xs) { + absl::Span xs) { switch (opcode) { case HloOpcode::kEq: return xs[0] == xs[1]; diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 205d417f0c6..6d634980449 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -34,8 +34,7 @@ class GatherOperationTest : public HloTestBase { RunTest(hlo_text, {operand, start_indices}); } - void RunTest(const string& hlo_text, - tensorflow::gtl::ArraySlice args) { + void RunTest(const string& hlo_text, absl::Span args) { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc index 
51450314b61..1115e50fe31 100644 --- a/tensorflow/compiler/xla/tests/half_test.cc +++ b/tensorflow/compiler/xla/tests/half_test.cc @@ -126,9 +126,8 @@ INSTANTIATE_TEST_CASE_P(half, UnaryPredTest, ::testing::Values(UnaryPredTestParam{ [](half x) { return isfinite(x); }, &IsFinite})); -using BinaryBuildFuncTy = - std::function)>; +using BinaryBuildFuncTy = std::function)>; struct BinaryOpTestParam { std::function compute_func; diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 93ea144438a..1c42a19414f 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -44,7 +44,6 @@ namespace { using absl::optional; using absl::string_view; -using tensorflow::gtl::ArraySlice; constexpr char kInterpreter[] = "interpreter"; @@ -130,14 +129,12 @@ DebugOptions HloTestBase::GetDebugOptionsForTest() { } StatusOr> HloTestBase::Execute( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments) { + std::unique_ptr module, absl::Span arguments) { return test_runner_.Execute(std::move(module), arguments); } std::unique_ptr HloTestBase::ExecuteNoHloPasses( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments) { + std::unique_ptr module, absl::Span arguments) { return test_runner_ .Execute(std::move(module), arguments, /*run_hlo_passes=*/false) @@ -145,8 +142,7 @@ std::unique_ptr HloTestBase::ExecuteNoHloPasses( } std::unique_ptr HloTestBase::ExecuteAndTransfer( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments) { + std::unique_ptr module, absl::Span arguments) { return test_runner_.Execute(std::move(module), arguments).ValueOrDie(); } @@ -169,7 +165,8 @@ StatusOr> HloTestBase::MakeReferenceModule( } StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( - std::unique_ptr module, const ArraySlice arguments, + std::unique_ptr module, + const absl::Span arguments, const optional& error, bool run_hlo_passes, const std::function& reference_preprocessor) { TF_RETURN_IF_ERROR(hlo_verifier_->Run(module.get()).status()); @@ -188,7 +185,8 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( } ::testing::AssertionResult HloTestBase::RunAndCompare( - std::unique_ptr module, const ArraySlice arguments, + std::unique_ptr module, + const absl::Span arguments, const optional& error, const std::function& reference_preprocessor) { auto result = @@ -201,7 +199,8 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( } ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( - std::unique_ptr module, const ArraySlice arguments, + std::unique_ptr module, + const absl::Span arguments, const optional& error, const std::function& reference_preprocessor) { auto result = diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 06bcc397417..1468930e9cf 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -114,18 +114,15 @@ class HloTestBase : public ::testing::Test { // Executes the given module and return the result as a Literal. StatusOr> Execute( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments); + std::unique_ptr module, absl::Span arguments); // Same as above, except the module will be executed without running any HLO // passes on it. 
std::unique_ptr ExecuteNoHloPasses( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments); + std::unique_ptr module, absl::Span arguments); std::unique_ptr ExecuteAndTransfer( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments); + std::unique_ptr module, absl::Span arguments); // Executes the given hlo module on two backends and compares results. // @@ -140,7 +137,7 @@ class HloTestBase : public ::testing::Test { // modified. ::testing::AssertionResult RunAndCompare( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, const absl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; @@ -149,7 +146,7 @@ class HloTestBase : public ::testing::Test { // optimization. ::testing::AssertionResult RunAndCompareNoHloPasses( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, const absl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; @@ -261,7 +258,7 @@ class HloTestBase : public ::testing::Test { // error happens before the results are computed, returns the error status. StatusOr<::testing::AssertionResult> RunAndCompareInternal( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, const absl::optional& error, bool run_hlo_passes, const std::function& reference_preprocessor); }; diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index 3dad91951e7..342a0aa735d 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -62,7 +62,7 @@ class LiteralTestUtil { static void ExpectR0Equal(NativeT expected, const LiteralSlice& actual); template - static void ExpectR1Equal(tensorflow::gtl::ArraySlice expected, + static void ExpectR1Equal(absl::Span expected, const LiteralSlice& actual); template static void ExpectR2Equal( @@ -102,7 +102,7 @@ class LiteralTestUtil { const ErrorSpec& error); template - static void ExpectR1Near(tensorflow::gtl::ArraySlice expected, + static void ExpectR1Near(absl::Span expected, const LiteralSlice& actual, const ErrorSpec& error); template @@ -160,7 +160,7 @@ template template /* static */ void LiteralTestUtil::ExpectR1Equal( - tensorflow::gtl::ArraySlice expected, const LiteralSlice& actual) { + absl::Span expected, const LiteralSlice& actual) { EXPECT_TRUE(Equal(*LiteralUtil::CreateR1(expected), actual)); } @@ -206,7 +206,7 @@ template template /* static */ void LiteralTestUtil::ExpectR1Near( - tensorflow::gtl::ArraySlice expected, const LiteralSlice& actual, + absl::Span expected, const LiteralSlice& actual, const ErrorSpec& error) { EXPECT_TRUE(Near(*LiteralUtil::CreateR1(expected), actual, error)); } diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index 948b60061e2..a8c68fc7fdb 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -156,7 +156,7 @@ ExecutableRunOptions LocalClientTestBase::DefaultExecutableRunOptions() const { ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(), DefaultExecutableRunOptions()) .ConsumeValueOrDie(); @@ -164,7 +164,7 @@ 
ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie( ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutableBuildOptions& build_options, const ExecutableRunOptions& run_options) { return ExecuteLocally(computation, arguments, build_options, run_options) @@ -173,14 +173,14 @@ ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie( StatusOr LocalClientTestBase::ExecuteLocally( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(), DefaultExecutableRunOptions()); } StatusOr LocalClientTestBase::ExecuteLocally( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutableBuildOptions& build_options, const ExecutableRunOptions& run_options) { std::vector argument_layouts(arguments.size()); diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index b4477e9a6b2..cb74ff71d45 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -93,19 +93,19 @@ class LocalClientTestBase : public ::testing::Test { // options. StatusOr ExecuteLocally( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); StatusOr ExecuteLocally( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutableBuildOptions& build_options, const ExecutableRunOptions& run_options); ScopedShapedBuffer ExecuteLocallyOrDie( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); ScopedShapedBuffer ExecuteLocallyOrDie( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutableBuildOptions& build_options, const ExecutableRunOptions& run_options); diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index b0a324f6fc9..fc0ba28b128 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -47,7 +47,6 @@ limitations under the License. 
namespace xla { namespace { -using ::tensorflow::gtl::ArraySlice; class MultiOutputFusionTest : public HloTestBase { protected: diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc index 2fc7f816b56..58539e6b061 100644 --- a/tensorflow/compiler/xla/tests/pred_test.cc +++ b/tensorflow/compiler/xla/tests/pred_test.cc @@ -31,7 +31,7 @@ class PredTest : public ClientLibraryTestBase { protected: void TestCompare(bool lhs, bool rhs, bool expected, std::function)> + absl::Span)> op) { XlaBuilder builder(TestName()); XlaOp lhs_op = ConstantR0(&builder, lhs); diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 326e13b3867..ee6bdfd0063 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -37,8 +37,7 @@ namespace { class PrngTest : public ClientLibraryTestBase { protected: template - std::unique_ptr UniformTest(T a, T b, - tensorflow::gtl::ArraySlice dims, + std::unique_ptr UniformTest(T a, T b, absl::Span dims, int64 seed = 42); // Computes the χ² statistic of a sample of the discrete uniform distribution @@ -50,8 +49,9 @@ class PrngTest : public ClientLibraryTestBase { }; template -std::unique_ptr PrngTest::UniformTest( - T a, T b, tensorflow::gtl::ArraySlice dims, int64 seed) { +std::unique_ptr PrngTest::UniformTest(T a, T b, + absl::Span dims, + int64 seed) { XlaBuilder builder(TestName()); RngUniform( ConstantR0(&builder, a), ConstantR0(&builder, b), @@ -61,7 +61,7 @@ std::unique_ptr PrngTest::UniformTest( auto actual = ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions())); - actual->EachCell([=](tensorflow::gtl::ArraySlice, T value) { + actual->EachCell([=](absl::Span, T value) { EXPECT_LE(a, value); EXPECT_LT(value, b); }); @@ -117,7 +117,7 @@ XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16CountTests))) { for (int64 seed = 0; seed < count; ++seed) { auto result = UniformTest(low, high, {}, /*seed=*/seed); result->Literal::EachCell( - [&](tensorflow::gtl::ArraySlice, bfloat16 value) { + [&](absl::Span, bfloat16 value) { int64 index = static_cast((value - low) / interval); counts[index]++; }); @@ -149,8 +149,8 @@ double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count, auto actual = ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); std::vector counts(range_size, 0); - actual->EachCell([&counts](tensorflow::gtl::ArraySlice, - int32 value) { ++counts[value]; }); + actual->EachCell( + [&counts](absl::Span, int32 value) { ++counts[value]; }); int64 sum = 0; for (int32 i = 0; i < range_size; ++i) { sum += Square(static_cast(counts[i] - expected_count)); diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 346f7024886..51d429ddeae 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -115,8 +115,7 @@ class ReduceTest : public ClientLibraryTestBase { ErrorSpec(0.001)); } - void RunR1ToR0PredTest(bool and_reduce, - tensorflow::gtl::ArraySlice input_data) { + void RunR1ToR0PredTest(bool and_reduce, absl::Span input_data) { const int element_count = input_data.size(); XlaBuilder builder(TestName()); const Shape input_shape = ShapeUtil::MakeShape(S32, {element_count}); @@ -261,8 +260,8 @@ class ReduceTest : public ClientLibraryTestBase { void ComputeAndCompareGeneric( typename 
std::enable_if::value, XlaBuilder>::type* builder, - tensorflow::gtl::ArraySlice expected, - tensorflow::gtl::ArraySlice arguments) { + absl::Span expected, + absl::Span arguments) { ComputeAndCompareR1(builder, expected, arguments, ErrorSpec(0.01, 1e-4)); } @@ -271,8 +270,8 @@ class ReduceTest : public ClientLibraryTestBase { void ComputeAndCompareGeneric( typename std::enable_if::value, XlaBuilder>::type* builder, - tensorflow::gtl::ArraySlice expected, - tensorflow::gtl::ArraySlice arguments) { + absl::Span expected, + absl::Span arguments) { ComputeAndCompareR1(builder, expected, arguments); } @@ -304,7 +303,7 @@ class ReduceTest : public ClientLibraryTestBase { client_->TransferToServer(*input_literal).ConsumeValueOrDie(); // NativeT can be bool, and std::vector does not convert to - // ArraySlice. + // Span. std::unique_ptr expected(new NativeT[cols]); for (int64 colno = 0; colno < cols; ++colno) { NativeT column_result = initial_value; @@ -316,7 +315,7 @@ class ReduceTest : public ClientLibraryTestBase { } ComputeAndCompareGeneric( - &builder, tensorflow::gtl::ArraySlice(expected.get(), cols), + &builder, absl::Span(expected.get(), cols), {input_global_data.get()}); } diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 60167619a4e..679ee4d482f 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -70,8 +70,8 @@ class ReduceWindowTest : public ::testing::WithParamInterface, ReduceWindowTest() : builder_(TestName()) { set_use_bfloat16(GetParam()); } void ReduceWindowAdd(const XlaOp& input, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding) { auto init = CreateConstantFromLiteral(*LiteralUtil::CreateR0(0.0f), &builder_); @@ -81,8 +81,8 @@ class ReduceWindowTest : public ::testing::WithParamInterface, } void ReduceWindowMax(const XlaOp& input, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding) { auto init = CreateConstantFromLiteral(LiteralUtil::MinValue(F32), &builder_); @@ -92,8 +92,8 @@ class ReduceWindowTest : public ::testing::WithParamInterface, } void ReduceWindowMin(const XlaOp& input, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding) { auto init = CreateConstantFromLiteral(LiteralUtil::MaxValue(F32), &builder_); @@ -1303,7 +1303,7 @@ TEST_P(R1ReduceWindowTest, DoIt) { std::vector input_vector(param.base_bounds[0]); std::iota(std::begin(input_vector), std::end(input_vector), 0); std::unique_ptr input_literal = - LiteralUtil::CreateR1(tensorflow::gtl::ArraySlice(input_vector)); + LiteralUtil::CreateR1(absl::Span(input_vector)); XlaOp parameter; auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", &b, ¶meter); @@ -1327,7 +1327,7 @@ TEST_P(R1ReduceWindowTest, DoIt) { ? 
+[](float a, float b) { return a + b; } : +[](float a, float b) { return std::max(a, b); }; auto expected = ReferenceUtil::ReduceWindow1DGeneric( - /*operand=*/tensorflow::gtl::ArraySlice(input_vector), + /*operand=*/absl::Span(input_vector), /*init=*/kInitValue, /*reduce_func=*/reduce_func, /*window=*/param.window_bounds, diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index 382d1b1ae74..ec4a2bcd3dd 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -689,9 +689,8 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(2, 1, 1, 1); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); @@ -711,9 +710,8 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(2, 1, 4, 1); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); @@ -734,9 +732,8 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(5, 10, 2, 3); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); @@ -747,7 +744,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { /*new_sizes=*/{5, 60}); Array2D expected_array(5, 60); - input.Each([&](tensorflow::gtl::ArraySlice indices, float* cell) { + input.Each([&](absl::Span indices, float* cell) { expected_array(indices[0], indices[2] * 30 + indices[1] * 3 + indices[3]) = *cell; }); @@ -762,7 +759,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { std::uniform_real_distribution distribution; Array4D input_array(2, 3, 5, 7); input_array.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, + [&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( @@ -842,9 +839,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { std::vector bounds = {2, 2, 2, 2}; std::vector new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]}; Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); @@ -871,9 +867,8 @@ XLA_TEST_P(ReshapeTest, 
R4TwoMinorTransposeMajorFirstEffectiveR2) { std::vector bounds = {1, 1, 250, 300}; std::vector new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]}; Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); @@ -900,9 +895,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { std::vector bounds = {5, 5, 1, 10}; std::vector new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]}; Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); @@ -930,9 +924,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { std::vector bounds = {5, 5, 10, 1}; std::vector new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]}; Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); @@ -959,9 +952,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { std::vector bounds = {3, 3, 1, 3}; std::vector new_bounds = {bounds[1], bounds[0], bounds[2], bounds[3]}; Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({0, 1, 2, 3})); diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc index c755ff63c90..74ded82ddfa 100644 --- a/tensorflow/compiler/xla/tests/reverse_test.cc +++ b/tensorflow/compiler/xla/tests/reverse_test.cc @@ -39,8 +39,8 @@ static std::array use_bfloat16_params{false}; #endif struct ReverseSpec { - tensorflow::gtl::ArraySlice input_dims; - tensorflow::gtl::ArraySlice reversal; + absl::Span input_dims; + absl::Span reversal; bool use_bfloat16; string ToTestCaseName() const { @@ -91,17 +91,16 @@ TEST_P(FloatReverseTest, Reverses) { std::unique_ptr expected = input_literal->CloneToUnique(); std::vector output_indices(spec.input_dims.size()); - expected->EachCell( - [&](tensorflow::gtl::ArraySlice indices, float) { - for (int64 i = 0; i < indices.size(); ++i) { - output_indices[i] = indices[i]; - } - float value = input_literal->Get(indices); - for (int64 dim : spec.reversal) { - output_indices[dim] = (spec.input_dims[dim] - 1) - indices[dim]; - } - expected->Set(output_indices, value); - }); + expected->EachCell([&](absl::Span indices, float) { + for (int64 i = 0; i < 
indices.size(); ++i) { + output_indices[i] = indices[i]; + } + float value = input_literal->Get(indices); + for (int64 dim : spec.reversal) { + output_indices[dim] = (spec.input_dims[dim] - 1) - indices[dim]; + } + expected->Set(output_indices, value); + }); ComputeAndCompareLiteral(&builder, *expected, {}); } diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc index a620fe19085..d5a579cfaf9 100644 --- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc @@ -47,8 +47,7 @@ class RoundTripPackedLiteralTest : public ClientLibraryTestBase { TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) { string data(sizeof(float) * 2, 0); - tensorflow::gtl::MutableArraySlice floats( - tensorflow::bit_cast(data.data()), 2); + absl::Span floats(tensorflow::bit_cast(data.data()), 2); floats[0] = 42.0; floats[1] = 24.0; @@ -70,8 +69,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) { TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { string data(sizeof(float) * 4, 0); - tensorflow::gtl::MutableArraySlice floats( - tensorflow::bit_cast(data.data()), 4); + absl::Span floats(tensorflow::bit_cast(data.data()), 4); // With x as the minor dimension, these will become: floats[0] = 42.0; // y=0,x=0 floats[1] = 24.0; // y=0,x=1 @@ -105,8 +103,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { string data(sizeof(float) * 4, 0); - tensorflow::gtl::MutableArraySlice floats( - tensorflow::bit_cast(data.data()), 4); + absl::Span floats(tensorflow::bit_cast(data.data()), 4); // With y as the minor dimension, these will become: floats[0] = 42.0; // y=0,x=0 floats[1] = 24.0; // y=1,x=0 diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index cf2d453f43c..d60b8e969d3 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -46,9 +46,8 @@ class ScalarComputationsTest : public ClientLibraryTestBase { // A template for building and running a binary comparison test. 
template void TestCompare(NativeT lhs, NativeT rhs, bool expected, - std::function)> - op) { + const std::function)>& op) { XlaBuilder builder(TestName()); XlaOp lhs_op = ConstantR0(&builder, lhs); XlaOp rhs_op = ConstantR0(&builder, rhs); @@ -58,9 +57,8 @@ class ScalarComputationsTest : public ClientLibraryTestBase { template void TestMinMax(NativeT lhs, NativeT rhs, NativeT expected, - std::function)> - op) { + const std::function)>& op) { XlaBuilder builder(TestName()); XlaOp lhs_op = ConstantR0(&builder, lhs); XlaOp rhs_op = ConstantR0(&builder, rhs); diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc index 99eeb12e2bd..1858dcea612 100644 --- a/tensorflow/compiler/xla/tests/scatter_test.cc +++ b/tensorflow/compiler/xla/tests/scatter_test.cc @@ -32,8 +32,7 @@ class ScatterTest : public HloTestBase { RunTest(hlo_text, {operand, scatter_indices, updates}); } - void RunTest(const string& hlo_text, - tensorflow::gtl::ArraySlice args) { + void RunTest(const string& hlo_text, absl::Span args) { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc index e3d4f98dd74..f737b5158b3 100644 --- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc +++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc @@ -42,8 +42,8 @@ struct SelectAndScatterTestParam { std::vector operand_shape; std::vector source_shape; Padding padding_type; - tensorflow::gtl::ArraySlice window_dimensions; - tensorflow::gtl::ArraySlice window_strides; + absl::Span window_dimensions; + absl::Span window_strides; }; class SelectAndScatterTest diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index 69585ae39a7..01252a0cf15 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -194,7 +194,7 @@ class SliceR1Test : public ClientLibraryTestBase, protected: template void Run(const R1Spec& spec) { - // This can't be an std::vector, since you can't grab an ArraySlice of a + // This can't be an std::vector, since you can't grab a Span of a // vector. absl::InlinedVector input(spec.input_dim0); std::iota(input.begin(), input.end(), NativeT()); diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 60ada58b2a0..c20a7c8fe49 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -183,8 +183,8 @@ StatusOr> MakeFakeLiteralInternal( break; case PRED: { std::uniform_int_distribution generator(0, 1); - TF_CHECK_OK(literal->Populate( - [&](tensorflow::gtl::ArraySlice /*indices*/) { + TF_CHECK_OK( + literal->Populate([&](absl::Span /*indices*/) { return generator(*engine); })); break; @@ -236,8 +236,8 @@ bool NeedsInitValue(const HloUse& use) { // Generate random values that are constrained to the input_shape minus the // output_shape so as not to produce wrapping slices, for instance. 
-std::unique_ptr MakeRandomIndex( - tensorflow::gtl::ArraySlice index_space, std::minstd_rand0* engine) { +std::unique_ptr MakeRandomIndex(absl::Span index_space, + std::minstd_rand0* engine) { std::vector start_indices(index_space.size()); if (engine != nullptr) { for (int i = 0; i < index_space.size(); ++i) { @@ -294,7 +294,7 @@ std::vector FindConstrainedUses( // generate a constrained literal (either bounded in the case of indices, or // zero in the case of init_values for reductions). StatusOr> CreateLiteralForConstrainedUses( - const tensorflow::gtl::ArraySlice constrained_uses, + const absl::Span constrained_uses, const HloInstruction& param, std::minstd_rand0* engine) { std::vector index_space; bool no_duplicates = false; diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index c101cd2d201..f2b3b49015c 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -507,7 +507,7 @@ XLA_TEST_F(TupleTest, ComplexTuples) { {{10011, 20022}, {30031, 40042}}}); auto prod = absl::make_unique(sum->shape()); ASSERT_TRUE(prod->Populate( - [&sum](tensorflow::gtl::ArraySlice indexes) { + [&sum](absl::Span indexes) { return sum->Get(indexes) * (indexes[indexes.size() - 1] == 0 ? complex64(1, 2) diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index e3a06257805..7fd42944deb 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -84,7 +84,7 @@ struct ParsedProfileOutputLine { Status ParseOneProfileOutputLine( const string& line, bool expect_hlo, gtl::FlatMap* parsed_results, - tensorflow::gtl::ArraySlice opcodes_to_ignore = {}) { + absl::Span opcodes_to_ignore = {}) { string separator = "[^:]*:: +"; string match_percentage = R"(\d+\.\d*% +\d+Σ)"; string match_cycles = R"((\d+) cycles +\( *()" + match_percentage + R"()\))"; diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc index 00147015a6b..395e3b69ac6 100644 --- a/tensorflow/compiler/xla/text_literal_writer.cc +++ b/tensorflow/compiler/xla/text_literal_writer.cc @@ -46,8 +46,7 @@ namespace xla { Status status; tensorflow::WritableFile* f_ptr = f.get(); literal.EachCellAsString( - [f_ptr, &status](tensorflow::gtl::ArraySlice indices, - const string& value) { + [f_ptr, &status](absl::Span indices, const string& value) { if (!status.ok()) { return; } diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc index d15b71b7925..e60b32cb548 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc @@ -46,7 +46,7 @@ limitations under the License. 
namespace xla { namespace tools { -void RealMain(tensorflow::gtl::ArraySlice args) { +void RealMain(absl::Span args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { HloSnapshot module; @@ -77,7 +77,7 @@ int main(int argc, char** argv) { } tensorflow::port::InitMain(argv[0], &argc, &argv); - tensorflow::gtl::ArraySlice args(argv, argc); + absl::Span args(argv, argc); args.remove_prefix(1); // Pop off the binary name, argv[0] xla::tools::RealMain(args); return 0; diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index c446b27a040..b9073a3f432 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -59,7 +59,7 @@ class OperationDumper : public DfsHloVisitorWithDefault { string path_; }; -void RealMain(tensorflow::gtl::ArraySlice args) { +void RealMain(absl::Span args) { LocalClient* client = ClientLibrary::LocalClientOrDie(); LocalService* local_service = ClientLibrary::GetXlaService(client->platform()); @@ -104,7 +104,7 @@ void RealMain(tensorflow::gtl::ArraySlice args) { int main(int argc, char** argv) { tensorflow::port::InitMain(argv[0], &argc, &argv); - tensorflow::gtl::ArraySlice args(argv, argc); + absl::Span args(argv, argc); args.remove_prefix(1); // Pop off the binary name, argv[0] xla::tools::RealMain(args); return 0; diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index d86a4474b32..959fc8311b1 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -34,7 +34,7 @@ limitations under the License. 
namespace xla { namespace tools { -void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { +void RealMain(absl::Span args, bool compile) { LocalClient* client = ClientLibrary::LocalClientOrDie(); LocalService* local_service = ClientLibrary::GetXlaService(client->platform()); @@ -102,7 +102,7 @@ int main(int argc, char** argv) { tensorflow::port::InitMain(usage.c_str(), &argc, &argv); QCHECK(argc > 1) << "\nERROR: must specify at least one module\n" << usage; - tensorflow::gtl::ArraySlice args(argv, argc); + absl::Span args(argv, argc); args.remove_prefix(1); // Pop off the binary name, argv[0] xla::tools::RealMain(args, compile); return 0; diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc index bd8b89542ff..ef4d09a04b8 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc @@ -45,7 +45,7 @@ using tensorflow::Env; namespace xla { namespace tools { -void RealMain(tensorflow::gtl::ArraySlice args) { +void RealMain(absl::Span args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { HloSnapshot module; @@ -78,7 +78,7 @@ int main(int argc, char** argv) { tensorflow::port::InitMain(argv[0], &argc, &argv); - tensorflow::gtl::ArraySlice args(argv, argc); + absl::Span args(argv, argc); args.remove_prefix(1); // Pop off the binary name, argv[0] xla::tools::RealMain(args); return 0; diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index e826d6fa936..2695e2e4b19 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -253,7 +253,7 @@ StatusOr ParseInputFile(const string& filename, return InvalidArgument("Could not parse %s.", filename); } -int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { +int RealMain(absl::Span args, const Options& opts) { LocalClient* client = ClientLibrary::LocalClientOrDie(); int exit_status = EXIT_SUCCESS; @@ -344,7 +344,7 @@ int main(int argc, char** argv) { LOG(QFATAL) << usage; } - tensorflow::gtl::ArraySlice args(argv, argc); + absl::Span args(argv, argc); args.remove_prefix(1); // Pop off the binary name, argv[0] return xla::tools::RealMain(args, opts); } diff --git a/tensorflow/compiler/xla/tools/show_signature.cc b/tensorflow/compiler/xla/tools/show_signature.cc index 10e7202acfb..f7eb4a79ca5 100644 --- a/tensorflow/compiler/xla/tools/show_signature.cc +++ b/tensorflow/compiler/xla/tools/show_signature.cc @@ -45,7 +45,7 @@ limitations under the License. 
namespace xla { namespace tools { -void RealMain(tensorflow::gtl::ArraySlice args) { +void RealMain(absl::Span args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { HloSnapshot module; @@ -66,7 +66,7 @@ void RealMain(tensorflow::gtl::ArraySlice args) { int main(int argc, char** argv) { tensorflow::port::InitMain(argv[0], &argc, &argv); - tensorflow::gtl::ArraySlice args(argv, argc); + absl::Span args(argv, argc); args.remove_prefix(1); // Pop off the binary name, argv[0] xla::tools::RealMain(args); return 0; diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index 0f607a0c8af..68cab7387cf 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -76,7 +76,7 @@ string Reindent(absl::string_view original, }); } -bool IsPermutation(tensorflow::gtl::ArraySlice permutation, int64 rank) { +bool IsPermutation(absl::Span permutation, int64 rank) { if (rank != permutation.size()) { return false; } @@ -90,7 +90,7 @@ bool IsPermutation(tensorflow::gtl::ArraySlice permutation, int64 rank) { } std::vector InversePermutation( - tensorflow::gtl::ArraySlice input_permutation) { + absl::Span input_permutation) { DCHECK(IsPermutation(input_permutation, input_permutation.size())); std::vector output_permutation(input_permutation.size(), -1); for (size_t i = 0; i < input_permutation.size(); ++i) { @@ -99,8 +99,8 @@ std::vector InversePermutation( return output_permutation; } -std::vector ComposePermutations(tensorflow::gtl::ArraySlice p1, - tensorflow::gtl::ArraySlice p2) { +std::vector ComposePermutations(absl::Span p1, + absl::Span p2) { CHECK_EQ(p1.size(), p2.size()); std::vector output; for (size_t i = 0; i < p1.size(); ++i) { @@ -109,7 +109,7 @@ std::vector ComposePermutations(tensorflow::gtl::ArraySlice p1, return output; } -bool IsIdentityPermutation(tensorflow::gtl::ArraySlice permutation) { +bool IsIdentityPermutation(absl::Span permutation) { for (int64 i = 0; i < permutation.size(); ++i) { if (permutation[i] != i) { return false; @@ -130,7 +130,7 @@ PaddingConfig MakeNoPaddingConfig(int64 rank) { } PaddingConfig MakeEdgePaddingConfig( - tensorflow::gtl::ArraySlice> padding) { + absl::Span> padding) { PaddingConfig padding_config; for (const std::pair& dim : padding) { auto dimension = padding_config.add_dimensions(); @@ -207,14 +207,13 @@ void LogLines(int sev, absl::string_view text, const char* fname, int lineno) { } } -int64 Product(tensorflow::gtl::ArraySlice xs) { +int64 Product(absl::Span xs) { return std::accumulate(xs.begin(), xs.end(), static_cast(1), std::multiplies()); } -std::vector> CommonFactors( - tensorflow::gtl::ArraySlice a, - tensorflow::gtl::ArraySlice b) { +std::vector> CommonFactors(absl::Span a, + absl::Span b) { CHECK_EQ(Product(a), Product(b)); if (0 == Product(a)) { return {std::make_pair(0, 0), std::make_pair(a.size(), b.size())}; diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index c8b48c5ab45..6f7d97e30e2 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -101,74 +101,99 @@ struct ScopedLoggingTimer { uint64 start_micros; }; -// Given a vector, returns a MutableArraySlice that points at its +// Given a vector, returns a Span that points at its // internals. // // Warning: if the vector is updated its storage pointer may change, so use this // with caution (ideally in limited scopes with temporary lifetimes). 
template -tensorflow::gtl::MutableArraySlice MutableByteSlice(std::vector* v) { - return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(v->data()), v->size() * sizeof(T)); +absl::Span MutableByteSlice(std::vector* v) { + return absl::Span(reinterpret_cast(v->data()), + v->size() * sizeof(T)); } // Turns an immutable slice of type T into an immutable slice of bytes with the // same byte size. template -tensorflow::gtl::ArraySlice CastToByteSlice( - tensorflow::gtl::ArraySlice slice) { - return tensorflow::gtl::ArraySlice( - reinterpret_cast(slice.data()), slice.size() * sizeof(T)); +absl::Span CastToByteSlice(absl::Span slice) { + return absl::Span(reinterpret_cast(slice.data()), + slice.size() * sizeof(T)); } // Casts a byte slice to a non-byte type T, checking that the original slice // length is a multiple of sizeof(T). template -tensorflow::gtl::ArraySlice CastByteSlice( - tensorflow::gtl::ArraySlice slice) { +absl::Span CastByteSlice(absl::Span slice) { CHECK_EQ(0, slice.size() % sizeof(T)); - return tensorflow::gtl::ArraySlice( - reinterpret_cast(slice.data()), slice.size() / sizeof(T)); + return absl::Span(reinterpret_cast(slice.data()), + slice.size() / sizeof(T)); } // Convenience function to force a vector to convert to an immutable slice. template -tensorflow::gtl::ArraySlice AsSlice(const std::vector& v) { - return tensorflow::gtl::ArraySlice(v); +absl::Span AsSlice(const std::vector& v) { + return absl::Span(v); } -// Converts a mutable vector pointer into a MutableArraySlice of the same +// Converts a mutable vector pointer into a Span of the same // type. template -tensorflow::gtl::MutableArraySlice AsMutableSlice(std::vector* v) { - return tensorflow::gtl::MutableArraySlice(v->data(), v->size()); +absl::Span AsMutableSlice(std::vector* v) { + return absl::Span(v->data(), v->size()); } // xla::int64 is not the same type as tensorflow::protobuf_int64 in open-source. // Wrapper function that gives an int64 array slice view of a repeated int64 // protobuf field. -static inline tensorflow::gtl::ArraySlice AsInt64Slice( +static inline absl::Span AsInt64Slice( const tensorflow::protobuf::RepeatedField& v) { - tensorflow::gtl::ArraySlice slice(v); - return tensorflow::gtl::ArraySlice( - reinterpret_cast(slice.data()), slice.size()); + absl::Span slice(v); + return absl::Span(reinterpret_cast(slice.data()), + slice.size()); } // As above, but for uint64 types. -static inline tensorflow::gtl::ArraySlice AsUInt64Slice( +static inline absl::Span AsUInt64Slice( const tensorflow::protobuf::RepeatedField& v) { - tensorflow::gtl::ArraySlice slice(v); - return tensorflow::gtl::ArraySlice( - reinterpret_cast(slice.data()), slice.size()); + absl::Span slice(v); + return absl::Span(reinterpret_cast(slice.data()), + slice.size()); +} + +// Compares two containers for equality. Returns true iff the two containers +// have the same size and all their elements compare equal using their +// operator==. Like std::equal, but forces size equality. +template +bool ContainersEqual(const Container1T& c1, const Container2T& c2) { + return ((c1.size() == c2.size()) && + std::equal(std::begin(c1), std::end(c1), std::begin(c2))); +} + +template +bool ContainersEqual(const Container1T& c1, + std::initializer_list il) { + absl::Span c2{il}; + return ContainersEqual(c1, c2); +} + +// Compares two containers for equality. Returns true iff the two containers +// have the same size and all their elements compare equal using the predicate +// p. Like std::equal, but forces size equality. 
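// Illustrative usage sketch (not part of this change); assumes two float
// containers xs and ys of equal size:
//   bool close = ContainersEqual(
//       xs, ys, [](float a, float b) { return std::abs(a - b) < 1e-3f; });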
+template +bool ContainersEqual(const Container1T& c1, const Container2T& c2, + PredicateT p) { + return ((c1.size() == c2.size()) && + std::equal(std::begin(c1), std::end(c1), std::begin(c2), p)); } // Performs a copy of count values from src to dest, using different strides for // source and destination. The source starting index is src_base, while the // destination one is dest_base. template -void StridedCopy(tensorflow::gtl::MutableArraySlice dest, int64 dest_base, - int64 dest_stride, tensorflow::gtl::ArraySlice src, - int64 src_base, int64 src_stride, int64 count) { +void StridedCopy(absl::Span dest, int64 dest_base, int64 dest_stride, + absl::Span src, int64 src_base, int64 src_stride, + int64 count) { for (; count > 0; --count, dest_base += dest_stride, src_base += src_stride) { dest[dest_base] = static_cast(src[src_base]); } @@ -258,7 +283,7 @@ Status ResourceExhaustedStrCat(Args&&... concat) { string Reindent(absl::string_view original, absl::string_view indentation); // Checks whether permutation is a permutation of the [0, rank) integer range. -bool IsPermutation(tensorflow::gtl::ArraySlice permutation, int64 rank); +bool IsPermutation(absl::Span permutation, int64 rank); // Applies `permutation` on `input` and returns the permuted array. // For each i, output[permutation[i]] = input[i]. @@ -268,9 +293,9 @@ bool IsPermutation(tensorflow::gtl::ArraySlice permutation, int64 rank); // 2. permutation.size() == input.size(). template std::vector Permute( - tensorflow::gtl::ArraySlice permutation, const Container& input) { + absl::Span permutation, const Container& input) { using T = typename Container::value_type; - tensorflow::gtl::ArraySlice data(input); + absl::Span data(input); CHECK(IsPermutation(permutation, data.size())); std::vector output(data.size()); for (size_t i = 0; i < permutation.size(); ++i) { @@ -281,14 +306,14 @@ std::vector Permute( // Inverts a permutation, i.e., output_permutation[input_permutation[i]] = i. std::vector InversePermutation( - tensorflow::gtl::ArraySlice input_permutation); + absl::Span input_permutation); // Composes two permutations: output[i] = p1[p2[i]]. -std::vector ComposePermutations(tensorflow::gtl::ArraySlice p1, - tensorflow::gtl::ArraySlice p2); +std::vector ComposePermutations(absl::Span p1, + absl::Span p2); // Returns true iff permutation == {0, 1, 2, ...}. -bool IsIdentityPermutation(tensorflow::gtl::ArraySlice permutation); +bool IsIdentityPermutation(absl::Span permutation); template int64 PositionInContainer(const Container& container, int64 value) { @@ -342,7 +367,7 @@ PaddingConfig MakeNoPaddingConfig(int64 rank); // Returns a PaddingConfig object where 'padding' contains // (low edge padding, high edge padding) pairs for each dimension. PaddingConfig MakeEdgePaddingConfig( - tensorflow::gtl::ArraySlice> padding); + absl::Span> padding); // Returns true if the padding configuration has at least one dimension with // non-zero interior padding. @@ -409,7 +434,7 @@ std::unique_ptr unique_ptr_static_cast(std::unique_ptr ptr) { return std::unique_ptr(static_cast(ptr.release())); } -int64 Product(tensorflow::gtl::ArraySlice xs); +int64 Product(absl::Span xs); // Returns the start indices of consecutive non-overlapping subsequences of `a` // and `b` with the same product, i.e. `(i, j)` so @@ -422,8 +447,8 @@ int64 Product(tensorflow::gtl::ArraySlice xs); // // If the given shapes have non-zero size, returns the bounds of the shortest // possible such subsequences; else, returns `{(0, 0), (a.size, b.size)}`. 
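// Illustrative example (not part of this change):
//   CommonFactors({2, 3, 4}, {6, 4}) yields {{0, 0}, {2, 1}, {3, 2}},
//   since 2 * 3 == 6 and 4 == 4.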
-std::vector> CommonFactors( - tensorflow::gtl::ArraySlice a, tensorflow::gtl::ArraySlice b); +std::vector> CommonFactors(absl::Span a, + absl::Span b); // Removes illegal characters from filenames. string SanitizeFileName(string file_name); @@ -445,7 +470,7 @@ void EraseAt(C* c, int64 index) { } template -std::vector ArraySliceToVector(tensorflow::gtl::ArraySlice slice) { +std::vector ArraySliceToVector(absl::Span slice) { return std::vector(slice.begin(), slice.end()); } diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc index 268dc5db01a..83735d239fb 100644 --- a/tensorflow/compiler/xla/window_util.cc +++ b/tensorflow/compiler/xla/window_util.cc @@ -24,7 +24,7 @@ limitations under the License. namespace xla { namespace window_util { -Window MakeWindow(tensorflow::gtl::ArraySlice sizes) { +Window MakeWindow(absl::Span sizes) { Window window; for (int64 size : sizes) { auto* dimension = window.add_dimensions(); @@ -36,7 +36,7 @@ Window MakeWindow(tensorflow::gtl::ArraySlice sizes) { return window; } -PaddingConfig MakeSymmetricPadding(tensorflow::gtl::ArraySlice sizes) { +PaddingConfig MakeSymmetricPadding(absl::Span sizes) { PaddingConfig config; for (int64 size : sizes) { auto* dimension = config.add_dimensions(); diff --git a/tensorflow/compiler/xla/window_util.h b/tensorflow/compiler/xla/window_util.h index ba473e2c8c3..ed93ecc30da 100644 --- a/tensorflow/compiler/xla/window_util.h +++ b/tensorflow/compiler/xla/window_util.h @@ -25,13 +25,13 @@ namespace window_util { // Creates a window with the given sizes in the dimensions and all strides set // to 1. -Window MakeWindow(tensorflow::gtl::ArraySlice sizes); +Window MakeWindow(absl::Span sizes); // Creates a padding config with symmetrical padding in each dimension, of value // given by sizes; e.g. {0, 1, 2} would create a R3 padding config that had zero // pixels of padding in dimension 0, one pixel of padding symmetrically, on each // side of dimension 1, and two pixels of padding symmetrically on dimension 2. -PaddingConfig MakeSymmetricPadding(tensorflow::gtl::ArraySlice sizes); +PaddingConfig MakeSymmetricPadding(absl::Span sizes); string ToString(const WindowDimension& dim); string ToString(const Window& window);
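Post-patch note (illustrative, not part of the diff above): the rename is intended to be source-compatible at typical call sites, because absl::Span, like the old gtl::ArraySlice, converts implicitly from std::vector, and absl::Span<const T> additionally accepts brace-enclosed initializer lists. The minimal, self-contained sketch below shows that pattern; it mirrors the Product() helper touched in util.cc but uses plain int64_t instead of XLA's int64 typedef, and is only a sketch under those assumptions, not code from this change.

  #include <cstdint>
  #include <functional>
  #include <numeric>
  #include <vector>

  #include "absl/types/span.h"

  // Hypothetical helper written in the post-rename parameter style.
  int64_t Product(absl::Span<const int64_t> xs) {
    return std::accumulate(xs.begin(), xs.end(), int64_t{1},
                           std::multiplies<int64_t>());
  }

  int main() {
    std::vector<int64_t> dims = {2, 3, 4};
    int64_t a = Product(dims);    // std::vector converts implicitly to Span.
    int64_t b = Product({5, 6});  // Initializer list binds to Span<const T>.
    return (a == 24 && b == 30) ? 0 : 1;
  }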