From e924d67bff8c4fb58c8316d00b662f8d1e80eb95 Mon Sep 17 00:00:00 2001
From: Justin Lebar
Date: Mon, 20 Aug 2018 20:20:14 -0700
Subject: [PATCH] [XLA] Use absl::make_unique instead of xla::MakeUnique.

Same for WrapUnique.

PiperOrigin-RevId: 209531124
---
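The rewrite is mechanical; a minimal sketch of the before/after pattern is
below. The type `Foo`, the functions `MakeFoo`/`AdoptFoo`, and the pointer
`raw_ptr` are illustrative placeholders, not names taken from this patch:

    #include <memory>

    #include "absl/memory/memory.h"
    // Before this patch, "tensorflow/compiler/xla/ptr_util.h" exposed
    // xla::MakeUnique and xla::WrapUnique as aliases for the tensorflow::
    // helpers; this patch switches callers to the Abseil equivalents.

    struct Foo {
      explicit Foo(int x) : x(x) {}
      int x;
    };

    std::unique_ptr<Foo> MakeFoo() {
      // Old spelling: return xla::MakeUnique<Foo>(42);
      return absl::make_unique<Foo>(42);
    }

    std::unique_ptr<Foo> AdoptFoo(Foo* raw_ptr) {
      // Old spelling: return xla::WrapUnique(raw_ptr);
      return absl::WrapUnique(raw_ptr);
    }
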
 tensorflow/compiler/aot/BUILD | 2 +
 tensorflow/compiler/aot/codegen.cc | 3 +-
 .../compiler/aot/embedded_protocol_buffers.cc | 6 +-
 tensorflow/compiler/jit/BUILD | 6 +-
 .../compiler/jit/create_xla_launch_op.cc | 5 +-
 .../compiler/jit/create_xla_launch_op_test.cc | 9 +-
 tensorflow/compiler/jit/xla_device.cc | 5 +-
 tensorflow/compiler/jit/xla_launch_util.cc | 3 +-
 tensorflow/compiler/jit/xla_tensor.h | 3 +-
 tensorflow/compiler/tf2xla/BUILD | 7 +-
 .../compiler/tf2xla/functionalize_cond.cc | 5 +-
 .../tf2xla/functionalize_control_flow.cc | 2 +-
 .../compiler/tf2xla/functionalize_while.cc | 6 +-
 tensorflow/compiler/tf2xla/xla_compiler.cc | 3 +-
 tensorflow/compiler/xla/BUILD | 16 +-
 tensorflow/compiler/xla/array2d.h | 4 +-
 tensorflow/compiler/xla/client/BUILD | 9 +-
 tensorflow/compiler/xla/client/client.cc | 12 +-
 .../compiler/xla/client/client_library.cc | 10 +-
 .../xla/client/compile_only_client.cc | 2 +-
 .../compiler/xla/client/local_client.cc | 8 +-
 tensorflow/compiler/xla/client/xla_builder.cc | 3 +-
 .../compiler/xla/client/xla_computation.cc | 4 +-
 tensorflow/compiler/xla/iterator_util_test.cc | 6 +-
 tensorflow/compiler/xla/literal.cc | 41 ++--
 tensorflow/compiler/xla/literal.h | 8 +-
 tensorflow/compiler/xla/literal_test.cc | 13 +-
 tensorflow/compiler/xla/literal_util.cc | 20 +-
 tensorflow/compiler/xla/literal_util.h | 24 +--
 .../compiler/xla/packed_literal_reader.cc | 4 +-
 tensorflow/compiler/xla/ptr_util.h | 35 ----
 tensorflow/compiler/xla/python/BUILD | 1 +
 .../xla/python/local_computation_builder.cc | 2 +-
 tensorflow/compiler/xla/reference_util.cc | 51 ++---
 tensorflow/compiler/xla/reference_util.h | 50 ++---
 .../compiler/xla/reference_util_test.cc | 12 +-
 tensorflow/compiler/xla/service/BUILD | 55 +++++-
 .../xla/service/algebraic_simplifier.cc | 3 +-
 .../xla/service/algebraic_simplifier_test.cc | 2 +-
 .../xla/service/allocation_tracker.cc | 7 +-
 tensorflow/compiler/xla/service/backend.cc | 5 +-
 .../xla/service/batchnorm_expander_test.cc | 2 +-
 .../compiler/xla/service/buffer_assignment.cc | 20 +-
 .../xla/service/buffer_assignment_test.cc | 18 +-
 .../xla/service/buffer_liveness_test.cc | 74 ++++----
 tensorflow/compiler/xla/service/call_graph.cc | 6 +-
 .../compiler/xla/service/call_inliner_test.cc | 2 +-
 .../compiler/xla/service/channel_tracker.cc | 2 +-
 .../xla/service/computation_placer.cc | 8 +-
 .../convolution_feature_group_converter.cc | 4 +-
 tensorflow/compiler/xla/service/cpu/BUILD | 6 +
 .../xla/service/cpu/compiler_functor.cc | 4 +-
 .../compiler/xla/service/cpu/cpu_compiler.cc | 30 +--
 .../xla/service/cpu/cpu_runtime_test.cc | 8 +-
 .../xla/service/cpu/cpu_transfer_manager.cc | 5 +-
 .../service/cpu/parallel_task_assignment.cc | 3 +-
 .../xla/service/cpu/simple_orc_jit.cc | 2 +-
 .../compiler/xla/service/cpu/tests/BUILD | 2 +
 .../xla/service/cpu/tests/cpu_fusion_test.cc | 2 +-
 .../xla/service/cpu/tests/cpu_noalias_test.cc | 5 +-
 tensorflow/compiler/xla/service/executable.cc | 5 +-
 .../compiler/xla/service/execution_tracker.cc | 6 +-
 tensorflow/compiler/xla/service/gpu/BUILD | 16 +-
 .../xla/service/gpu/buffer_allocations.cc | 4 +-
 .../xla/service/gpu/conditional_thunk.cc | 2 +-
 .../compiler/xla/service/gpu/for_thunk.cc | 4 +-
 .../xla/service/gpu/gpu_executable.cc | 4 +-
 .../xla/service/gpu/gpu_transfer_manager.cc | 8 +-
 .../xla/service/gpu/hlo_execution_profiler.cc | 5 +-
 .../compiler/xla/service/gpu/hlo_schedule.cc | 10 +-
 .../xla/service/gpu/hlo_schedule_test.cc | 3 +-
 .../xla/service/gpu/infeed_manager.cc | 4 +-
 .../xla/service/gpu/ir_emitter_unnested.cc | 108 +++++------
 .../compiler/xla/service/gpu/kernel_thunk.cc | 4 +-
 .../xla/service/gpu/llvm_gpu_backend/BUILD | 1 +
 .../gpu/llvm_gpu_backend/nvptx_backend_lib.cc | 4 +-
 .../xla/service/gpu/nvptx_compiler.cc | 8 +-
 .../xla/service/gpu/outfeed_manager.cc | 2 +-
 .../compiler/xla/service/gpu/pad_insertion.cc | 7 +-
 .../xla/service/gpu/partition_assignment.cc | 2 +-
 .../xla/service/gpu/stream_assignment.cc | 4 +-
 .../xla/service/gpu/stream_assignment_test.cc | 3 +-
 .../compiler/xla/service/gpu/tests/BUILD | 6 +-
 .../xla/service/gpu/tests/gpu_codegen_test.cc | 4 +-
 .../xla/service/gpu/tests/gpu_copy_test.cc | 2 +-
 .../xla/service/gpu/tests/gpu_index_test.cc | 2 +-
 .../xla/service/gpu/tests/gpu_ldg_test.cc | 2 +-
 .../xla/service/gpu/tests/gpu_noalias_test.cc | 2 +-
 .../compiler/xla/service/gpu/tuple_thunk.cc | 3 +-
 .../compiler/xla/service/gpu/while_thunk.cc | 6 +-
 .../compiler/xla/service/graphviz_example.cc | 4 +-
 .../compiler/xla/service/heap_simulator.cc | 12 +-
 .../xla/service/heap_simulator_test.cc | 25 ++-
 .../xla/service/hlo_alias_analysis.cc | 2 +-
 .../compiler/xla/service/hlo_computation.cc | 16 +-
 .../xla/service/hlo_constant_folding.cc | 3 +-
 .../xla/service/hlo_creation_utils.cc | 6 +-
 .../xla/service/hlo_creation_utils_test.cc | 2 +-
 .../compiler/xla/service/hlo_cse_test.cc | 2 +-
 .../xla/service/hlo_dataflow_analysis.cc | 4 +-
 .../compiler/xla/service/hlo_dce_test.cc | 2 +-
 .../compiler/xla/service/hlo_domain_map.cc | 11 +-
 .../compiler/xla/service/hlo_domain_test.cc | 11 +-
 .../compiler/xla/service/hlo_evaluator.cc | 91 +++++----
 .../compiler/xla/service/hlo_evaluator.h | 4 +-
 .../xla/service/hlo_evaluator_test.cc | 39 ++--
 .../xla/service/hlo_evaluator_typed_visitor.h | 29 +--
 .../xla/service/hlo_execution_profile.cc | 3 +-
 .../compiler/xla/service/hlo_instruction.cc | 177 ++++++++++--------
 .../compiler/xla/service/hlo_instruction.h | 3 +-
 .../compiler/xla/service/hlo_instructions.cc | 106 ++++++-----
 .../compiler/xla/service/hlo_instructions.h | 5 +-
 .../xla/service/hlo_liveness_analysis.cc | 4 +-
 tensorflow/compiler/xla/service/hlo_module.cc | 6 +-
 .../compiler/xla/service/hlo_module_config.cc | 2 +-
 .../xla/service/hlo_module_group_metadata.cc | 6 +-
 .../xla/service/hlo_module_group_util.cc | 4 +-
 .../compiler/xla/service/hlo_module_test.cc | 2 +-
 tensorflow/compiler/xla/service/hlo_parser.cc | 15 +-
 tensorflow/compiler/xla/service/hlo_parser.h | 2 +-
 .../compiler/xla/service/hlo_pass_pipeline.h | 2 +-
 tensorflow/compiler/xla/service/hlo_runner.cc | 8 +-
 .../xla/service/hlo_sharding_metadata.cc | 13 +-
 tensorflow/compiler/xla/service/hlo_value.cc | 2 +-
 .../compiler/xla/service/hlo_verifier.h | 5 +-
 .../compiler/xla/service/inliner_test.cc | 2 +-
 .../compiler/xla/service/interpreter/BUILD | 8 +-
 .../xla/service/interpreter/compiler.cc | 10 +-
 .../xla/service/interpreter/executable.cc | 2 +-
 .../interpreter_transfer_manager.cc | 4 +-
 .../xla/service/interpreter/platform.cc | 5 +-
 .../compiler/xla/service/layout_assignment.cc | 24 +--
 .../compiler/xla/service/local_service.cc | 2 +-
 .../xla/service/logical_buffer_analysis.cc | 3 +-
 .../xla/service/reshape_mover_test.cc | 2 +-
 tensorflow/compiler/xla/service/service.cc | 13 +-
 .../compiler/xla/service/shaped_buffer.cc | 2 +-
 .../xla/service/shaped_buffer_test.cc | 3 +-
 .../compiler/xla/service/stream_pool.cc | 4 +-
 .../compiler/xla/service/transfer_manager.cc | 5 +-
 .../xla/service/tuple_points_to_analysis.cc | 3 +-
 tensorflow/compiler/xla/shape_tree.h | 2 +-
 tensorflow/compiler/xla/shape_tree_test.cc | 3 +-
 tensorflow/compiler/xla/tests/BUILD | 21 ++-
 .../compiler/xla/tests/broadcast_test.cc | 2 +-
 .../xla/tests/client_library_test_base.cc | 6 +-
 .../xla/tests/client_library_test_base.h | 4 +-
 .../convolution_dimension_numbers_test.cc | 6 +-
 .../compiler/xla/tests/convolution_test.cc | 14 +-
 tensorflow/compiler/xla/tests/copy_test.cc | 2 +-
 .../compiler/xla/tests/custom_call_test.cc | 2 +-
 tensorflow/compiler/xla/tests/fusion_test.cc | 2 +-
 .../compiler/xla/tests/hlo_test_base.cc | 6 +-
 .../xla/tests/hlo_verified_test_base.cc | 3 +-
 .../compiler/xla/tests/llvm_compiler_test.cc | 3 +-
 .../xla/tests/local_client_test_base.cc | 2 +-
 .../xla/tests/matrix_ops_simple_test.cc | 4 +-
 .../xla/tests/multioutput_fusion_test.cc | 2 +-
 tensorflow/compiler/xla/tests/pad_test.cc | 30 +--
 .../compiler/xla/tests/reduce_window_test.cc | 5 +-
 tensorflow/compiler/xla/tests/test_utils.cc | 11 +-
 tensorflow/compiler/xla/tests/test_utils.h | 2 +-
 tensorflow/compiler/xla/tests/tuple_test.cc | 3 +-
 .../compiler/xla/text_literal_reader.cc | 4 +-
 164 files changed, 978 insertions(+), 797 deletions(-)
 delete mode 100644 tensorflow/compiler/xla/ptr_util.h

diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD
index 1899a32e4dc..2220d0786d3 100644
--- a/tensorflow/compiler/aot/BUILD
+++ b/tensorflow/compiler/aot/BUILD
@@ -55,6 +55,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/memory",
     ],
 )
@@ -193,6 +194,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
         "//tensorflow/core:lib",
+        "@com_google_absl//absl/memory",
         "@llvm//:core",
         "@llvm//:support",
         "@llvm//:target",
diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc
index 89fefdad54f..a8485576ac1 100644
--- a/tensorflow/compiler/aot/codegen.cc
+++ b/tensorflow/compiler/aot/codegen.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include
 #include
+#include "absl/memory/memory.h"
 #include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
 #include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
 #include "tensorflow/compiler/tf2xla/str_util.h"
@@ -617,7 +618,7 @@ Status GenerateMetadata(const CodegenOpts& opts,
   if (opts.gen_program_shape) {
     program_shape =
-        tensorflow::MakeUnique(compile_result.program_shape);
+        absl::make_unique(compile_result.program_shape);
     // The parameter names are currently meaningless, and redundant with the
     // rest of our metadata, so clear them out to avoid confusion and save
     // space.
diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc
index 4e27aafec77..8fb2fad31c6 100644
--- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc
+++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include
 #include
+#include "absl/memory/memory.h"
 #include "llvm/ADT/Triple.h"
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/LLVMContext.h"
@@ -27,7 +28,6 @@ limitations under the License.
#include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" #include "tensorflow/compiler/tf2xla/str_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/util.h" @@ -105,7 +105,7 @@ GetTargetMachineFromTriple(StringPiece target_triple) { error.c_str()); } - return WrapUnique(target->createTargetMachine( + return absl::WrapUnique(target->createTargetMachine( normalized_triple, /*CPU=*/"", /*Features=*/"", llvm::TargetOptions(), llvm::None)); } @@ -118,7 +118,7 @@ StatusOr CreateEmbeddedProtocolBuffers( llvm::LLVMContext llvm_context; std::unique_ptr module_with_serialized_proto = - MakeUnique("embedded_data_module", llvm_context); + absl::make_unique("embedded_data_module", llvm_context); EmbeddedProtocolBuffers result; diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index e059f77563b..2c9adfe4f0d 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -128,11 +128,11 @@ cc_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:shaped_buffer", - "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", ], ) @@ -191,6 +191,7 @@ cc_library( "//tensorflow/core/kernels/data:generator_dataset_op", "//tensorflow/core/kernels/data:iterator_ops", "//tensorflow/core/kernels/data:prefetch_dataset_op", + "@com_google_absl//absl/memory", ], ) @@ -235,6 +236,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:variable_ops", + "@com_google_absl//absl/memory", ], ) @@ -283,6 +285,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/memory", ], alwayslink = 1, ) @@ -303,6 +306,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "@com_google_absl//absl/memory", ], ) diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc index a2e6285339f..1b1ce78ed2b 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/jit/create_xla_launch_op.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/kernels/xla_launch_op.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" @@ -223,8 +224,8 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def, &fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types, fbody->ret_types, output_memory_types, flr->graph_def_version(), &s); - *kernel = MakeUnique(&construction, constant_arg_indices, - resource_arg_indices, function); + *kernel = absl::make_unique( + &construction, constant_arg_indices, resource_arg_indices, function); return s; } diff --git a/tensorflow/compiler/jit/create_xla_launch_op_test.cc b/tensorflow/compiler/jit/create_xla_launch_op_test.cc index b75ab486b80..73866607621 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op_test.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/jit/create_xla_launch_op.h" +#include "absl/memory/memory.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function_testlib.h" @@ -65,11 +66,11 @@ class CreateXlaLaunchOpTest : public ::testing::Test { for (const auto& fdef : flib) { *(proto.add_function()) = fdef; } - lib_def_ = - MakeUnique(OpRegistry::Global(), proto); + lib_def_ = absl::make_unique( + OpRegistry::Global(), proto); OptimizerOptions opts; - device_mgr_ = MakeUnique(devices_); - pflr_ = MakeUnique( + device_mgr_ = absl::make_unique(devices_); + pflr_ = absl::make_unique( device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 2a2691a6a40..70e6d0be0f2 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device_context.h" @@ -101,7 +102,7 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( } std::unique_ptr alloc = - xla::MakeUnique(); + absl::make_unique(); XlaDeviceAllocator* alloc_ptr = alloc.get(); state.allocators_[{backend, device_ordinal}] = std::move(alloc); return alloc_ptr; @@ -327,7 +328,7 @@ xla::StatusOr XlaDevice::GetDeviceContextLocked() { // to those methods; see the bug for details. Our only saving grace at the // moment is that this race doesn't seem to occur in practice. if (use_gpu_device_info_) { - auto gpu_device_info = MakeUnique(); + auto gpu_device_info = absl::make_unique(); gpu_device_info->stream = stream_.get(); gpu_device_info->default_context = device_context_; set_tensorflow_gpu_device_info(gpu_device_info.get()); diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 4efbb2d5d7c..2ffce9298d9 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include +#include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -175,7 +176,7 @@ void XlaComputationLaunchContext::PopulateInputs( << " not the same as on-host shape " << xla::ShapeUtil::HumanStringWithLayout(shape); se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t); - arg_buffers_[i] = xla::MakeUnique( + arg_buffers_[i] = absl::make_unique( /*on_host_shape=*/shape, /*on_device_shape=*/shape, client_->platform(), client_->default_device_ordinal()); arg_buffers_[i]->set_buffer(dmem, /*index=*/{}); diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index 8d36d0fa0a8..07a9bf0d4a7 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/core/framework/allocator.h" @@ -70,7 +71,7 @@ class XlaTensor { // Mutates the XlaTensor to set the ShapedBuffer. void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) { shaped_buffer_ = - xla::MakeUnique(std::move(shaped_buffer)); + absl::make_unique(std::move(shaped_buffer)); } // Some tensors on the device may have known values on the host. We use these diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index c4fdaef940a..575917d078d 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -211,6 +211,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], alwayslink = 1, ) @@ -475,12 +476,12 @@ cc_library( "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:util", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -501,12 +502,12 @@ cc_library( "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:util", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -525,12 +526,12 @@ cc_library( "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:util", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index d24b5b1bbe3..0f5471616e1 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -21,11 +21,11 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" @@ -399,7 +399,8 @@ Status Conditional::BuildArgumentNodes() { Status Conditional::ExtractBodies(Graph* graph) { VLOG(2) << "Extracting bodies for " << name(); for (auto b : {BranchType::kElseBranch, BranchType::kThenBranch}) { - bodies_[static_cast(b)] = xla::MakeUnique(graph->op_registry()); + bodies_[static_cast(b)] = + absl::make_unique(graph->op_registry()); } auto find_branch = [&](const Edge* e) { diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 2cfa3c046e1..188ada7255f 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -21,13 +21,13 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/functionalize_cond.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/functionalize_while.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index fd3e3c6e30f..4fd134c6980 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -21,11 +21,11 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" @@ -143,7 +143,7 @@ StatusOr BuildArgNode(Graph* graph, DataType type, int index) { Status BuildLoopCondition(const Graph& graph, Frame* frame, std::unique_ptr* cond_output) { VLOG(2) << "Building loop condition for " << frame->name; - *cond_output = xla::MakeUnique(graph.op_registry()); + *cond_output = absl::make_unique(graph.op_registry()); Graph* output = cond_output->get(); // Map from nodes in the original graph to the condition graph. @@ -180,7 +180,7 @@ Status BuildLoopBody(const Graph& graph, Frame* frame, DataTypeVector* arg_types, std::unique_ptr* body_output) { VLOG(2) << "Building loop body for " << frame->name; - *body_output = xla::MakeUnique(graph.op_registry()); + *body_output = absl::make_unique(graph.op_registry()); Graph* output = body_output->get(); // Map from nodes in the original graph to the condition graph. 
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 226c89bcf1e..43ff5fcef89 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" @@ -310,7 +311,7 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, // unique_ptr so we can capture the cleanup status in the end. xla_context->Ref(); Status status; - auto step_container = xla::MakeUnique( + auto step_container = absl::make_unique( step_id, [&status, device](const string& name) { status = device->resource_manager()->Cleanup(name); }); diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index e36429f62d0..2cf77b71fb2 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -161,7 +161,6 @@ cc_library( "iterator_util.h", "map_util.h", "overflow_util.h", - "ptr_util.h", "util.h", ], visibility = ["//visibility:public"], @@ -172,8 +171,8 @@ cc_library( ":types", ":xla_data_proto", "//tensorflow/core:lib", - "//tensorflow/core:ptr_util", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", ], ) @@ -211,6 +210,7 @@ tf_cc_test( ":test", ":util", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -298,6 +298,7 @@ cc_library( ":util", ":xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -316,6 +317,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -336,6 +338,7 @@ cc_library( ":util", ":xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -406,8 +409,8 @@ cc_library( deps = [ ":array", ":types", - ":util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -490,6 +493,7 @@ cc_library( ":util", ":xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -522,6 +526,7 @@ cc_library( ":xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", ], ) @@ -577,10 +582,10 @@ cc_library( deps = [ ":shape_util", ":status_macros", - ":util", ":xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", ], ) @@ -594,6 +599,7 @@ tf_cc_test( ":xla_data_proto", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -643,6 +649,7 @@ cc_library( "//tensorflow/compiler/xla/service:shape_inference", "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -661,6 +668,7 @@ tf_cc_test( "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) diff --git a/tensorflow/compiler/xla/array2d.h b/tensorflow/compiler/xla/array2d.h index a17e81f4483..340f94fab72 100644 --- a/tensorflow/compiler/xla/array2d.h +++ b/tensorflow/compiler/xla/array2d.h @@ -24,8 +24,8 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -101,7 +101,7 @@ class Array2D : public Array { template std::unique_ptr> MakeLinspaceArray2D(double from, double to, int64 n1, int64 n2) { - auto array = MakeUnique>(n1, n2); + auto array = absl::make_unique>(n1, n2); int64 count = n1 * n2; NativeT step = static_cast((count > 1) ? (to - from) / (count - 1) : 0); diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 0ecf26e7723..6be44b1c390 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -71,12 +71,12 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -104,7 +104,6 @@ cc_library( "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:compiler", @@ -117,6 +116,7 @@ cc_library( "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", "@llvm//:support", ], ) @@ -130,11 +130,11 @@ cc_library( ":xla_computation", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:compile_only_service", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", "@llvm//:support", ], ) @@ -159,6 +159,7 @@ cc_library( "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -186,6 +187,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_proto", + "@com_google_absl//absl/memory", ], ) @@ -212,6 +214,7 @@ cc_library( "//tensorflow/compiler/xla/service:shape_inference", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", ], ) diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index d0ce5e8a6af..25608d6616f 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -18,11 +18,11 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/errors.h" @@ -89,7 +89,7 @@ StatusOr> Client::TransferToServer( "TransferToServer request"); } - return MakeUnique(stub_, response.data()); + return absl::make_unique(stub_, response.data()); } Status Client::TransferToInfeed(const LiteralSlice& literal, int64 replica_id, @@ -248,7 +248,7 @@ StatusOr> Client::Execute( } } - return MakeUnique(stub_, response.output()); + return absl::make_unique(stub_, response.output()); } StatusOr>> Client::ExecuteParallel( @@ -278,7 +278,7 @@ StatusOr>> Client::ExecuteParallel( std::vector> outputs; for (size_t i = 0; i < computations.size(); ++i) { outputs.push_back( - MakeUnique(stub_, response.responses(i).output())); + absl::make_unique(stub_, response.responses(i).output())); if (computations[i].execution_profile != nullptr) { *computations[i].execution_profile = response.responses(i).profile(); } @@ -340,7 +340,7 @@ StatusOr>> Client::DeconstructTuple( std::vector> handles; for (auto& handle : response.element_handles()) { - handles.push_back(MakeUnique(stub_, handle)); + handles.push_back(absl::make_unique(stub_, handle)); } return std::move(handles); } @@ -369,7 +369,7 @@ StatusOr Client::GetComputationStats( StatusOr> Client::GetComputationShape( const XlaComputation& computation) { TF_ASSIGN_OR_RETURN(const auto& result, computation.GetProgramShape()); - return MakeUnique(result); + return absl::make_unique(result); } StatusOr Client::GetShape(const GlobalData& data) { diff --git a/tensorflow/compiler/xla/client/client_library.cc b/tensorflow/compiler/xla/client/client_library.cc index 803a9e40094..27b7fa7b292 100644 --- a/tensorflow/compiler/xla/client/client_library.cc +++ b/tensorflow/compiler/xla/client/client_library.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/client_library.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -94,10 +95,10 @@ ClientLibrary::~ClientLibrary() = default; service_options.set_intra_op_parallelism_threads( options.intra_op_parallelism_threads()); - auto instance = MakeUnique(); + auto instance = absl::make_unique(); TF_ASSIGN_OR_RETURN(instance->service, LocalService::NewService(service_options)); - instance->client = MakeUnique(instance->service.get()); + instance->client = absl::make_unique(instance->service.get()); LocalClient* cl = instance->client.get(); client_library.local_instances_.insert( @@ -134,10 +135,11 @@ ClientLibrary::GetOrCreateCompileOnlyClient(se::Platform* platform) { return it->second->client.get(); } - auto instance = MakeUnique(); + auto instance = absl::make_unique(); TF_ASSIGN_OR_RETURN(instance->service, CompileOnlyService::NewService(platform)); - instance->client = MakeUnique(instance->service.get()); + instance->client = + absl::make_unique(instance->service.get()); CompileOnlyClient* cl = instance->client.get(); client_library.compile_only_instances_.insert( diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc index 5c9abad4c31..b6012a03520 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.cc +++ b/tensorflow/compiler/xla/client/compile_only_client.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/compile_only_client.h" +#include "absl/memory/memory.h" #include "llvm/ADT/Triple.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" namespace xla { diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index cffb24e29be..1cd3e9b22f9 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -17,9 +17,9 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "llvm/ADT/Triple.h" #include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/service/source_map_util.h" @@ -257,9 +257,9 @@ StatusOr> LocalClient::Compile( TF_ASSIGN_OR_RETURN(std::unique_ptr executable, local_service_->CompileExecutable( computation, argument_layouts, updated_options)); - return WrapUnique(new LocalExecutable(std::move(executable), - local_service_->mutable_backend(), - updated_options)); + return absl::WrapUnique(new LocalExecutable(std::move(executable), + local_service_->mutable_backend(), + updated_options)); } StatusOr LocalClient::LiteralToShapedBuffer( diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index e65dd5cbb4a..54fe87a7a8d 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include #include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" @@ -2297,7 +2298,7 @@ StatusOr XlaBuilder::BuildConstantSubGraph( std::unique_ptr XlaBuilder::CreateSubBuilder( const string& computation_name) { - auto sub_builder = MakeUnique(computation_name); + auto sub_builder = absl::make_unique(computation_name); sub_builder->parent_builder_ = this; sub_builder->die_immediately_on_error_ = this->die_immediately_on_error_; return sub_builder; diff --git a/tensorflow/compiler/xla/client/xla_computation.cc b/tensorflow/compiler/xla/client/xla_computation.cc index 3543d41fc26..22c9e83bb2a 100644 --- a/tensorflow/compiler/xla/client/xla_computation.cc +++ b/tensorflow/compiler/xla/client/xla_computation.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -32,7 +32,7 @@ StatusOr> XlaComputation::Snapshot() const { if (IsNull()) { return InvalidArgument("Computation is invalid."); } - auto session = MakeUnique(); + auto session = absl::make_unique(); *session->mutable_hlo()->mutable_hlo_module() = proto_; return std::move(session); } diff --git a/tensorflow/compiler/xla/iterator_util_test.cc b/tensorflow/compiler/xla/iterator_util_test.cc index 7bc3189507e..ec8b66df2db 100644 --- a/tensorflow/compiler/xla/iterator_util_test.cc +++ b/tensorflow/compiler/xla/iterator_util_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/test.h" namespace xla { @@ -27,7 +27,7 @@ namespace { TEST(UnwrappingIteratorTest, Simple) { std::vector> v; for (int i = 0; i < 3; ++i) { - v.push_back(MakeUnique(i)); + v.push_back(absl::make_unique(i)); } int i = 0; for (auto iter = MakeUnwrappingIterator(v.begin()); @@ -51,7 +51,7 @@ TEST(UnwrappingIteratorTest, PostincrementOperator) { TEST(UnwrappingIteratorTest, StdFind) { std::list> l; for (int i = 0; i < 3; ++i) { - l.push_back(MakeUnique(i)); + l.push_back(absl::make_unique(i)); } EXPECT_EQ(l.begin()->get(), *std::find(MakeUnwrappingIterator(l.begin()), diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 36e472568ec..d54f051a1a9 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -134,7 +135,7 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { Literal::Literal(const Shape& shape, bool allocate_arrays) : MutableLiteralBase() { - shape_ = MakeUnique(shape); + shape_ = absl::make_unique(shape); CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = new Piece(); root_piece_->set_subshape(shape_.get()); @@ -175,7 +176,7 @@ Literal& Literal::operator=(Literal&& other) { } std::unique_ptr LiteralBase::CreateFromShape(const Shape& shape) { - auto literal = MakeUnique(shape); + auto literal = absl::make_unique(shape); literal->root_piece_->ForEachMutableSubpiece( [&](const ShapeIndex& index, Piece* piece) { if (ShapeUtil::IsArray(piece->subshape())) { @@ -289,7 +290,7 @@ MutableLiteralBase::CreateFromProto(const LiteralProto& proto) { return InvalidArgument("LiteralProto has no layout"); } - auto literal = MakeUnique(proto.shape()); + auto literal = absl::make_unique(proto.shape()); TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus( [&](const ShapeIndex& index, Piece* piece) { @@ -479,7 +480,7 @@ Status Literal::MoveFrom(Literal&& src_literal, dest_piece.set_sparse_indices(src_piece.sparse_indices()); }); - src_literal.shape_ = MakeUnique(ShapeUtil::MakeNil()); + src_literal.shape_ = absl::make_unique(ShapeUtil::MakeNil()); delete src_literal.root_piece_; src_literal.root_piece_ = new LiteralBase::Piece(); src_literal.root_piece_->set_subshape(src_literal.shape_.get()); @@ -566,7 +567,7 @@ std::unique_ptr LiteralBase::Relayout( Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index); TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape)); *subshape->mutable_layout() = new_layout; - auto result = MakeUnique(new_shape); + auto result = absl::make_unique(new_shape); TF_CHECK_OK(result->CopyFrom(*this)); return result; } @@ -602,7 +603,7 @@ StatusOr> LiteralBase::Broadcast( result_shape.dimensions(dimensions[i])); } - std::unique_ptr result = MakeUnique(result_shape); + std::unique_ptr result = absl::make_unique(result_shape); // scratch_source_index is temporary storage space for the computed index into // the input literal. 
We put it here to avoid allocating an std::vector in @@ -691,7 +692,7 @@ std::unique_ptr LiteralBase::Transpose( for (auto index : LayoutUtil::MinorToMajor(shape())) { layout->add_minor_to_major(inverse_permutation[index]); } - auto new_literal = MakeUnique(permuted_shape); + auto new_literal = absl::make_unique(permuted_shape); DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()), ShapeUtil::ByteSizeOf(shape())); std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes()); @@ -702,7 +703,7 @@ template std::unique_ptr LiteralBase::SliceInternal( const Shape& result_shape, tensorflow::gtl::ArraySlice start_indices) const { - auto result_literal = MakeUnique(result_shape); + auto result_literal = absl::make_unique(result_shape); DimensionVector new_indices(ShapeUtil::Rank(result_shape)); result_literal->EachCell( [&](tensorflow::gtl::ArraySlice indices, NativeT /*value*/) { @@ -756,7 +757,7 @@ Literal LiteralBase::Clone() const { } std::unique_ptr LiteralBase::CloneToUnique() const { - auto result = MakeUnique(shape()); + auto result = absl::make_unique(shape()); TF_CHECK_OK(result->CopyFrom(*this)); return result; } @@ -1203,7 +1204,7 @@ template std::unique_ptr ConvertBetweenNativeTypesWithConverter( const LiteralBase& src_literal, const ConverterType& converter) { CHECK(ShapeUtil::IsArray(src_literal.shape())); - auto result_literal = MakeUnique(ShapeUtil::ChangeElementType( + auto result_literal = absl::make_unique(ShapeUtil::ChangeElementType( src_literal.shape(), primitive_util::NativeToPrimitiveType())); auto src_data = src_literal.data(); @@ -1249,7 +1250,7 @@ BitcastBetweenNativeTypes(const LiteralBase& src_literal) { template std::unique_ptr ConvertToC64(const LiteralBase& src_literal) { CHECK(ShapeUtil::IsArray(src_literal.shape())); - auto result_literal = MakeUnique( + auto result_literal = absl::make_unique( ShapeUtil::ChangeElementType(src_literal.shape(), C64)); using NativeSrcT = typename primitive_util::PrimitiveTypeToNative::type; @@ -1396,7 +1397,7 @@ StatusOr> LiteralBase::ConvertToShape( element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i}))); elements.push_back(std::move(*new_element)); } - auto converted = MakeUnique(); + auto converted = absl::make_unique(); *converted = MutableLiteralBase::MoveIntoTuple(&elements); return std::move(converted); } @@ -1956,7 +1957,7 @@ MutableLiteralBase::~MutableLiteralBase() {} MutableBorrowingLiteral::MutableBorrowingLiteral( const MutableBorrowingLiteral& literal) : MutableLiteralBase() { - shape_ = MakeUnique(literal.shape()); + shape_ = absl::make_unique(literal.shape()); CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = new Piece(); @@ -1967,7 +1968,7 @@ MutableBorrowingLiteral::MutableBorrowingLiteral( MutableBorrowingLiteral& MutableBorrowingLiteral::operator=( const MutableBorrowingLiteral& literal) { - shape_ = MakeUnique(literal.shape()); + shape_ = absl::make_unique(literal.shape()); CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = new Piece(); @@ -1981,7 +1982,7 @@ MutableBorrowingLiteral& MutableBorrowingLiteral::operator=( MutableBorrowingLiteral::MutableBorrowingLiteral( const MutableLiteralBase& literal) : MutableLiteralBase() { - shape_ = MakeUnique(literal.shape()); + shape_ = absl::make_unique(literal.shape()); CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = new Piece(); @@ -1992,7 +1993,7 @@ MutableBorrowingLiteral::MutableBorrowingLiteral( MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal) : MutableLiteralBase() { - shape_ = 
MakeUnique(literal->shape()); + shape_ = absl::make_unique(literal->shape()); CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = new Piece(); @@ -2004,7 +2005,7 @@ MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal) MutableBorrowingLiteral::MutableBorrowingLiteral( MutableBorrowingLiteral literal, const ShapeIndex& view_root) : MutableLiteralBase() { - shape_ = MakeUnique(literal.piece(view_root).subshape()); + shape_ = absl::make_unique(literal.piece(view_root).subshape()); CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = new Piece(); @@ -2016,7 +2017,7 @@ MutableBorrowingLiteral::MutableBorrowingLiteral( MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape) : MutableLiteralBase() { - shape_ = MakeUnique(shape); + shape_ = absl::make_unique(shape); CHECK(LayoutUtil::HasLayout(*shape_)); CHECK(!ShapeUtil::IsTuple(*shape_)); @@ -2061,7 +2062,7 @@ void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { } BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) - : LiteralBase(), shape_(MakeUnique(shape)) { + : LiteralBase(), shape_(absl::make_unique(shape)) { CHECK(ShapeUtil::IsArray(*shape_)); CHECK(LayoutUtil::HasLayout(*shape_)); @@ -2072,7 +2073,7 @@ BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) BorrowingLiteral::BorrowingLiteral( tensorflow::gtl::ArraySlice src_buf_ptrs, const Shape& shape) - : LiteralBase(), shape_(MakeUnique(shape)) { + : LiteralBase(), shape_(absl::make_unique(shape)) { CHECK(ShapeUtil::IsTuple(*shape_)); CHECK(!ShapeUtil::IsNestedTuple(*shape_)); CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_)); diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index 92c0f903cbe..ed9de652994 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -25,13 +25,13 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/sparse_index_array.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -312,7 +312,7 @@ class LiteralBase { // Note: It's an antipattern to use this method then immediately call // MutableLiteralBase::Populate on the result (since that results in zero // initialization, then reinitialization. Conside if a call to - // MakeUnique(shape), followed by the call to + // absl::make_unique(shape), followed by the call to // MutableLiteralBase::Populate can be used instead. 
static std::unique_ptr CreateFromShape(const Shape& shape); @@ -1154,8 +1154,8 @@ std::unique_ptr LiteralBase::Replicate(int64 times) const { for (int64 bound : shape().dimensions()) { bounds.push_back(bound); } - auto literal = - MakeUnique(ShapeUtil::MakeShape(shape().element_type(), bounds)); + auto literal = absl::make_unique( + ShapeUtil::MakeShape(shape().element_type(), bounds)); int64 elements = ShapeUtil::ElementsIn(literal->shape()); if (elements == 0) { return literal; diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index e8f919950f0..c5d0c2c267e 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -355,15 +356,15 @@ TEST_F(LiteralUtilTest, TokenEquality) { TEST_F(LiteralUtilTest, DifferentLayoutEquality) { // Test equality with literals which have different layouts. - auto colmajor = - MakeUnique(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})); + auto colmajor = absl::make_unique( + ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})); colmajor->Set({0, 0}, 1.0); colmajor->Set({0, 1}, 2.0); colmajor->Set({1, 0}, 3.0); colmajor->Set({1, 1}, 4.0); - auto rowmajor = - MakeUnique(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})); + auto rowmajor = absl::make_unique( + ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})); rowmajor->Set({0, 0}, 1.0); rowmajor->Set({0, 1}, 2.0); rowmajor->Set({1, 0}, 3.0); @@ -1089,7 +1090,7 @@ TEST_F(LiteralUtilTest, Populate) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = MakeUnique(shape); + auto literal = absl::make_unique(shape); auto generator = [&](ArraySlice indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. @@ -1131,7 +1132,7 @@ TEST_F(LiteralUtilTest, PopulateParallel) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = MakeUnique(shape); + auto literal = absl::make_unique(shape); auto generator = [&](ArraySlice indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 5d33df7d40b..d4c7b76b281 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -57,7 +58,7 @@ std::unique_ptr ConvertType(LiteralSlice literal) { primitive_util::NativeToPrimitiveType()); } }); - auto result = MakeUnique(result_shape); + auto result = absl::make_unique(result_shape); // Then copy over the data from 'literal' converting FromNativeT values to // ToNativeT values as necessary. 
@@ -102,7 +103,7 @@ std::unique_ptr ConvertType(LiteralSlice literal) { } /* static */ std::unique_ptr LiteralUtil::CreateToken() { - return MakeUnique(ShapeUtil::MakeTokenShape()); + return absl::make_unique(ShapeUtil::MakeTokenShape()); } /* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) { @@ -279,7 +280,7 @@ std::unique_ptr ConvertType(LiteralSlice literal) { /* static */ std::unique_ptr LiteralUtil::CreateR1( const tensorflow::core::Bitmap& values) { - auto literal = MakeUnique( + auto literal = absl::make_unique( ShapeUtil::MakeShape(PRED, {static_cast(values.bits())})); literal->PopulateR1(values); return literal; @@ -287,7 +288,7 @@ std::unique_ptr ConvertType(LiteralSlice literal) { /* static */ std::unique_ptr LiteralUtil::CreateR1U8( tensorflow::StringPiece value) { - auto literal = MakeUnique( + auto literal = absl::make_unique( ShapeUtil::MakeShape(U8, {static_cast(value.size())})); for (int i = 0; i < value.size(); ++i) { literal->Set({i}, value[i]); @@ -312,7 +313,7 @@ std::unique_ptr ConvertType(LiteralSlice literal) { CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements); CHECK_EQ(new_dimensions.size(), minor_to_major.size()); - auto new_literal = MakeUnique( + auto new_literal = absl::make_unique( ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions)); // Create a new shape with the given minor-to-major layout. This shape is used @@ -436,7 +437,8 @@ std::unique_ptr ConvertType(LiteralSlice literal) { for (const auto* element : elements) { element_shapes.push_back(element->shape()); } - auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); + auto literal = + absl::make_unique(ShapeUtil::MakeTupleShape(element_shapes)); for (int i = 0; i < elements.size(); ++i) { TF_CHECK_OK(literal->CopyFrom(*elements[i], /*dest_shape_index=*/{i})); } @@ -449,7 +451,8 @@ std::unique_ptr ConvertType(LiteralSlice literal) { for (const auto& element : elements) { element_shapes.push_back(element.shape()); } - auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); + auto literal = + absl::make_unique(ShapeUtil::MakeTupleShape(element_shapes)); for (int i = 0; i < elements.size(); ++i) { TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i})); } @@ -463,7 +466,8 @@ std::unique_ptr ConvertType(LiteralSlice literal) { for (const auto& element : elements) { element_shapes.push_back(element->shape()); } - auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); + auto literal = + absl::make_unique(ShapeUtil::MakeTupleShape(element_shapes)); for (int64 i = 0; i < elements.size(); ++i) { TF_CHECK_OK( literal->MoveFrom(std::move(*elements[i]), /*dest_shape_index=*/{i})); diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index e3737a9d005..1109021ea89 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -27,6 +27,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -34,7 +35,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/sparse_index_array.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -327,7 +327,7 @@ std::ostream& operator<<(std::ostream& out, const Literal& literal); template /* static */ std::unique_ptr LiteralUtil::CreateR0(NativeT value) { - auto literal = MakeUnique(ShapeUtil::MakeShape( + auto literal = absl::make_unique(ShapeUtil::MakeShape( primitive_util::NativeToPrimitiveType(), {})); literal->Set({}, value); return literal; @@ -336,7 +336,7 @@ template template /* static */ std::unique_ptr LiteralUtil::CreateR1( tensorflow::gtl::ArraySlice values) { - auto literal = MakeUnique( + auto literal = absl::make_unique( ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {static_cast(values.size())})); literal->PopulateR1(values); @@ -347,7 +347,7 @@ template /* static */ std::unique_ptr LiteralUtil::CreateR2WithLayout( std::initializer_list> values, const Layout& layout) { - auto literal = MakeUnique(ShapeUtil::MakeShapeWithLayout( + auto literal = absl::make_unique(ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), {static_cast(values.size()), static_cast(values.begin()->size())}, @@ -433,9 +433,10 @@ template int64 rank = dimensions.size(); CHECK_EQ(num_elements, indices.index_count()); CHECK_EQ(rank, indices.rank()); - auto literal = MakeUnique(ShapeUtil::MakeShapeWithSparseLayout( - primitive_util::NativeToPrimitiveType(), dimensions, - indices.max_indices())); + auto literal = + absl::make_unique(ShapeUtil::MakeShapeWithSparseLayout( + primitive_util::NativeToPrimitiveType(), dimensions, + indices.max_indices())); literal->PopulateSparse(indices, values, sort); return literal; } @@ -451,7 +452,7 @@ template template /* static */ std::unique_ptr LiteralUtil::CreateFromArrayWithLayout( const Array& values, const Layout& layout) { - auto literal = MakeUnique(ShapeUtil::MakeShapeWithLayout( + auto literal = absl::make_unique(ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), values.dimensions(), AsInt64Slice(layout.minor_to_major()))); literal->PopulateFromArray(values); @@ -571,8 +572,9 @@ template /* static */ std::unique_ptr LiteralUtil::CreateFullWithDescendingLayout( tensorflow::gtl::ArraySlice dimensions, NativeT value) { - auto literal = MakeUnique(ShapeUtil::MakeShapeWithDescendingLayout( - primitive_util::NativeToPrimitiveType(), dimensions)); + auto literal = + absl::make_unique(ShapeUtil::MakeShapeWithDescendingLayout( + primitive_util::NativeToPrimitiveType(), dimensions)); literal->PopulateWithValue(value); return literal; } @@ -584,7 +586,7 @@ LiteralUtil::CreateRandomLiteral( const std::function)>& generator) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; TF_RET_CHECK(shape.element_type() == type); - auto literal = MakeUnique(shape); + auto literal = absl::make_unique(shape); TF_RETURN_IF_ERROR(literal.get()->Populate( [&](tensorflow::gtl::ArraySlice indexes) { return generator(indexes); diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index 6b7fd10d63f..55c4a80e29b 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -19,9 +19,9 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -57,7 +57,7 @@ StatusOr> PackedLiteralReader::Read( PrimitiveType_Name(shape.element_type()).c_str()); } - auto result = MakeUnique(literal_shape); + auto result = absl::make_unique(literal_shape); result->PopulateWithValue(std::numeric_limits::quiet_NaN()); int64 elements = ShapeUtil::ElementsIn(shape); diff --git a/tensorflow/compiler/xla/ptr_util.h b/tensorflow/compiler/xla/ptr_util.h deleted file mode 100644 index bfcdfc62f95..00000000000 --- a/tensorflow/compiler/xla/ptr_util.h +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_ -#define TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_ - -// As this was moved to tensorflow/core/util, provide indirections here to -// maintain current functionality of the library. - -#include - -#include -#include -#include - -#include "tensorflow/core/util/ptr_util.h" - -namespace xla { -using tensorflow::MakeUnique; -using tensorflow::WrapUnique; -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_ diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index c8f2d65c223..a91336c3ac9 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -59,6 +59,7 @@ cc_library( "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:framework_lite", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 212439dec8c..c133a204197 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/python/local_computation_builder.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/executable_run_options.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/thread_annotations.h" diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index a8035208769..3de7ee2bc8c 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" @@ -43,7 +44,7 @@ std::unique_ptr> MatmulArray2DImpl( int m = lhs.height(); int n = rhs.width(); int k = lhs.width(); - auto result = MakeUnique>(m, n); + auto result = absl::make_unique>(m, n); // Because Eigen is a header-oriented library, make sure that the Eigen code // is the same as the code used by the CPU backend (otherwise the linker will // randomly pick *some* definition). @@ -77,7 +78,8 @@ std::unique_ptr> MatmulArray2DImpl( /* static */ std::unique_ptr> ReferenceUtil::Array2DF32ToF64( const Array2D& input) { - auto result = MakeUnique>(input.height(), input.width()); + auto result = + absl::make_unique>(input.height(), input.width()); for (int64 rowno = 0; rowno < input.height(); ++rowno) { for (int64 colno = 0; colno < input.height(); ++colno) { (*result)(rowno, colno) = input(rowno, colno); @@ -126,8 +128,8 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated( a4dlhs, a4drhs, {kernel_stride, 1}, padding, {lhs_dilation, 1}, {rhs_dilation, 1}, dnums2d); - auto convr3 = MakeUnique>(convr4->planes(), convr4->depth(), - convr4->height()); + auto convr3 = absl::make_unique>( + convr4->planes(), convr4->depth(), convr4->height()); convr4->Each( [&](tensorflow::gtl::ArraySlice indices, float* value_ptr) { CHECK_EQ(indices[3], 0); @@ -201,7 +203,7 @@ ReferenceUtil::ReduceWindow1DGeneric( window_util::StridedBound(padded_width, window[i], stride[i]); pad_low[i] = padding[i].first; } - auto result = MakeUnique>(window_counts[0]); + auto result = absl::make_unique>(window_counts[0]); // Do a full 1D reduce window. for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { @@ -247,7 +249,8 @@ ReferenceUtil::ReduceWindow2DGeneric( window_util::StridedBound(padded_width, window[i], stride[i]); pad_low[i] = padding[i].first; } - auto result = MakeUnique>(window_counts[0], window_counts[1]); + auto result = + absl::make_unique>(window_counts[0], window_counts[1]); // Do a full 2D reduce window. for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { @@ -296,8 +299,8 @@ ReferenceUtil::ReduceWindow2DGeneric( WindowCount(dim_lengths[i], window[i], stride[i], padding); pad_low[i] = padding_both[i].first; } - auto result = MakeUnique>(window_counts[0], window_counts[1], - window_counts[2]); + auto result = absl::make_unique>( + window_counts[0], window_counts[1], window_counts[2]); for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { @@ -358,8 +361,8 @@ ReferenceUtil::ReduceWindow4DGeneric( window_util::StridedBound(padded_width, window[i], stride[i]); pad_low[i] = padding[i].first; } - auto result = MakeUnique>(window_counts[0], window_counts[1], - window_counts[2], window_counts[3]); + auto result = absl::make_unique>( + window_counts[0], window_counts[1], window_counts[2], window_counts[3]); // Do a full 4D reduce window. for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { @@ -426,8 +429,8 @@ ReferenceUtil::SelectAndScatter4DGePlus( const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, bool same_padding) { Padding padding = same_padding ? 
Padding::kSame : Padding::kValid; - auto result = MakeUnique>(operand.n1(), operand.n2(), - operand.n3(), operand.n4()); + auto result = absl::make_unique>(operand.n1(), operand.n2(), + operand.n3(), operand.n4()); std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3(), operand.n4()}; auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); @@ -583,10 +586,10 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4); auto result = - MakeUnique>(result_literal->shape().dimensions(0), - result_literal->shape().dimensions(1), - result_literal->shape().dimensions(2), - result_literal->shape().dimensions(3)); + absl::make_unique>(result_literal->shape().dimensions(0), + result_literal->shape().dimensions(1), + result_literal->shape().dimensions(2), + result_literal->shape().dimensions(3)); result->Each([&](tensorflow::gtl::ArraySlice indices, float* value) { *value = result_literal->Get(indices); @@ -601,7 +604,7 @@ ReferenceUtil::ReduceToColArray2D( const std::function& reduce_function) { int64 rows = matrix.height(); int64 cols = matrix.width(); - auto result = MakeUnique>(); + auto result = absl::make_unique>(); for (int64 i = 0; i < rows; ++i) { float acc = init; for (int64 j = 0; j < cols; ++j) { @@ -618,7 +621,7 @@ ReferenceUtil::ReduceToRowArray2D( const std::function& reduce_function) { int64 rows = matrix.height(); int64 cols = matrix.width(); - auto result = MakeUnique>(); + auto result = absl::make_unique>(); for (int64 i = 0; i < cols; ++i) { float acc = init; for (int64 j = 0; j < rows; ++j) { @@ -674,8 +677,8 @@ ReferenceUtil::ReduceToRowArray2D( /* static */ std::unique_ptr> ReferenceUtil::Broadcast1DTo4D( const std::vector& array, const std::vector& bounds, int64 broadcast_from_dim) { - auto result = - MakeUnique>(bounds[0], bounds[1], bounds[2], bounds[3]); + auto result = absl::make_unique>(bounds[0], bounds[1], + bounds[2], bounds[3]); for (int64 i = 0; i < result->n1(); ++i) { for (int64 j = 0; j < result->n2(); ++j) { for (int64 k = 0; k < result->n3(); ++k) { @@ -710,7 +713,7 @@ ReferenceUtil::ReduceToRowArray2D( CHECK_EQ(dims.size(), 1); int64 rows = dims[0] == 0 ? array.n2() : array.n1(); int64 cols = dims[0] == 2 ? 
array.n2() : array.n3(); - auto result = MakeUnique>(rows, cols); + auto result = absl::make_unique>(rows, cols); result->Fill(init); for (int i0 = 0; i0 < array.n1(); ++i0) { for (int i1 = 0; i1 < array.n2(); ++i1) { @@ -730,7 +733,7 @@ ReferenceUtil::ReduceToRowArray2D( const std::function& map_function) { int64 rows = matrix.height(); int64 cols = matrix.width(); - auto result = MakeUnique>(rows, cols); + auto result = absl::make_unique>(rows, cols); for (int64 i = 0; i < rows; ++i) { for (int64 j = 0; j < cols; ++j) { (*result)(i, j) = map_function(matrix(i, j)); @@ -746,7 +749,7 @@ ReferenceUtil::ReduceToRowArray2D( CHECK_EQ(lhs.width(), rhs.width()); int64 rows = lhs.height(); int64 cols = rhs.width(); - auto result = MakeUnique>(rows, cols); + auto result = absl::make_unique>(rows, cols); for (int64 i = 0; i < rows; ++i) { for (int64 j = 0; j < cols; ++j) { (*result)(i, j) = map_function(lhs(i, j), rhs(i, j)); @@ -760,7 +763,7 @@ ReferenceUtil::ReduceToRowArray2D( const std::function& map_function) { int64 rows = matrix.height(); int64 cols = matrix.width(); - auto result = MakeUnique>(rows, cols); + auto result = absl::make_unique>(rows, cols); for (int64 i = 0; i < rows; ++i) { for (int64 j = 0; j < cols; ++j) { (*result)(i, j) = map_function(matrix(i, j), i, j); diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 8fa6961d197..88f853a3591 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -22,11 +22,11 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -42,7 +42,8 @@ class ReferenceUtil { template static std::unique_ptr> TransposeArray2D( const Array2D& operand) { - auto result = MakeUnique>(operand.width(), operand.height()); + auto result = + absl::make_unique>(operand.width(), operand.height()); for (int64 w = 0; w < operand.width(); ++w) { for (int64 h = 0; h < operand.height(); ++h) { (*result)(w, h) = operand(h, w); @@ -242,7 +243,7 @@ class ReferenceUtil { const Array2D& rhs, int concatenate_dimension) { CHECK(0 <= concatenate_dimension && concatenate_dimension < 2); - auto result = MakeUnique>( + auto result = absl::make_unique>( concatenate_dimension == 0 ? lhs.n1() + rhs.n1() : lhs.n1(), concatenate_dimension == 1 ? 
lhs.n2() + rhs.n2() : lhs.n2()); for (int64 i0 = 0; i0 < result->n1(); ++i0) { @@ -276,7 +277,8 @@ class ReferenceUtil { out_dims[i] = lhs_dims[i] + rhs_dims[i]; } } - auto result = MakeUnique>(out_dims[0], out_dims[1], out_dims[2]); + auto result = + absl::make_unique>(out_dims[0], out_dims[1], out_dims[2]); for (int64 i0 = 0; i0 < result->n1(); ++i0) { for (int64 i1 = 0; i1 < result->n2(); ++i1) { for (int64 i2 = 0; i2 < result->n3(); ++i2) { @@ -310,8 +312,8 @@ class ReferenceUtil { out_dims[i] = lhs_dims[i] + rhs_dims[i]; } } - auto result = MakeUnique>(out_dims[0], out_dims[1], out_dims[2], - out_dims[3]); + auto result = absl::make_unique>(out_dims[0], out_dims[1], + out_dims[2], out_dims[3]); for (int64 i0 = 0; i0 < result->n1(); ++i0) { for (int64 i1 = 0; i1 < result->n2(); ++i1) { for (int64 i2 = 0; i2 < result->n3(); ++i2) { @@ -355,9 +357,9 @@ class ReferenceUtil { CHECK_LE(limits[1], input.n2()); CHECK_GE(strides[0], 1); CHECK_GE(strides[1], 1); - auto result = - MakeUnique>(CeilOfRatio(limits[0] - starts[0], strides[0]), - CeilOfRatio(limits[1] - starts[1], strides[1])); + auto result = absl::make_unique>( + CeilOfRatio(limits[0] - starts[0], strides[0]), + CeilOfRatio(limits[1] - starts[1], strides[1])); for (int64 i0 = 0; i0 < result->n1(); ++i0) { for (int64 i1 = 0; i1 < result->n2(); ++i1) { (*result)(i0, i1) = @@ -381,10 +383,10 @@ class ReferenceUtil { CHECK_GE(strides[0], 1); CHECK_GE(strides[1], 1); CHECK_GE(strides[2], 1); - auto result = - MakeUnique>(CeilOfRatio(limits[0] - starts[0], strides[0]), - CeilOfRatio(limits[1] - starts[1], strides[1]), - CeilOfRatio(limits[2] - starts[2], strides[2])); + auto result = absl::make_unique>( + CeilOfRatio(limits[0] - starts[0], strides[0]), + CeilOfRatio(limits[1] - starts[1], strides[1]), + CeilOfRatio(limits[2] - starts[2], strides[2])); for (int64 i0 = 0; i0 < result->n1(); ++i0) { for (int64 i1 = 0; i1 < result->n2(); ++i1) { @@ -415,11 +417,11 @@ class ReferenceUtil { CHECK_GE(strides[1], 1); CHECK_GE(strides[2], 1); CHECK_GE(strides[3], 1); - auto result = - MakeUnique>(CeilOfRatio(limits[0] - starts[0], strides[0]), - CeilOfRatio(limits[1] - starts[1], strides[1]), - CeilOfRatio(limits[2] - starts[2], strides[2]), - CeilOfRatio(limits[3] - starts[3], strides[3])); + auto result = absl::make_unique>( + CeilOfRatio(limits[0] - starts[0], strides[0]), + CeilOfRatio(limits[1] - starts[1], strides[1]), + CeilOfRatio(limits[2] - starts[2], strides[2]), + CeilOfRatio(limits[3] - starts[3], strides[3])); for (int64 i0 = 0; i0 < result->n1(); ++i0) { for (int64 i1 = 0; i1 < result->n2(); ++i1) { for (int64 i2 = 0; i2 < result->n3(); ++i2) { @@ -460,8 +462,8 @@ class ReferenceUtil { template static std::unique_ptr> MapWithIndexArray4D( const Array4D& input, F&& map_function) { - auto result = MakeUnique>(input.planes(), input.depth(), - input.height(), input.width()); + auto result = absl::make_unique>( + input.planes(), input.depth(), input.height(), input.width()); for (int64 plane = 0; plane < input.planes(); ++plane) { for (int64 depth = 0; depth < input.depth(); ++depth) { for (int64 height = 0; height < input.height(); ++height) { @@ -495,8 +497,8 @@ class ReferenceUtil { template static std::unique_ptr> MapWithIndexArray4D( const Array4D& lhs, const Array4D& rhs, F&& map_function) { - auto result = MakeUnique>(lhs.planes(), lhs.depth(), - lhs.height(), lhs.width()); + auto result = absl::make_unique>(lhs.planes(), lhs.depth(), + lhs.height(), lhs.width()); for (int64 plane = 0; plane < lhs.planes(); ++plane) { for (int64 
depth = 0; depth < lhs.depth(); ++depth) { for (int64 height = 0; height < lhs.height(); ++height) { @@ -530,7 +532,7 @@ class ReferenceUtil { int64 out1 = in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1; - auto result = MakeUnique>(out0, out1); + auto result = absl::make_unique>(out0, out1); result->Fill(pad); int64 o0 = low_padding0; for (int64 i0 = 0; i0 < in0; ++i0) { @@ -669,7 +671,7 @@ class ReferenceUtil { static std::unique_ptr> ApplyElementwise2D( F&& f, const Array2D& array1, const Array2D&... arrays) { AssertSameSize2D(array1, arrays...); - auto result = MakeUnique>(array1.n1(), array1.n2()); + auto result = absl::make_unique>(array1.n1(), array1.n2()); for (int64 i = 0; i < array1.n1(); ++i) { for (int64 j = 0; j < array1.n2(); ++j) { (*result)(i, j) = f(array1(i, j), arrays(i, j)...); diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index 8091bed4996..3ec01921484 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -18,12 +18,12 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -36,7 +36,7 @@ namespace { class ReferenceUtilTest : public ::testing::Test { protected: ReferenceUtilTest() { - matrix_ = MakeUnique>(rows_, cols_); + matrix_ = absl::make_unique>(rows_, cols_); // [1.f 2.f 3.f] // [4.f 5.f 6.f] for (int64 i = 0; i < rows_; ++i) { @@ -112,8 +112,8 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) { } TEST_F(ReferenceUtilTest, MapArray4D) { - auto input = MakeUnique>(/*planes=*/2, /*depth=*/3, - /*height=*/4, /*width=*/5); + auto input = absl::make_unique>(/*planes=*/2, /*depth=*/3, + /*height=*/4, /*width=*/5); input->FillWithMultiples(1.0f); auto multiply_by_two = [](float value) { return 2 * value; }; auto result = ReferenceUtil::MapArray4D(*input, multiply_by_two); @@ -126,8 +126,8 @@ TEST_F(ReferenceUtilTest, MapArray4D) { } TEST_F(ReferenceUtilTest, MapWithIndexArray4D) { - auto input = MakeUnique>(/*planes=*/2, /*depth=*/3, - /*height=*/4, /*width=*/5); + auto input = absl::make_unique>(/*planes=*/2, /*depth=*/3, + /*height=*/4, /*width=*/5); input->FillWithMultiples(1.0f); auto subtract_index = [](float value, int64 plane, int64 depth, int64 height, int64 width) { diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 12ec38736ec..01f273ad1f7 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -239,6 +239,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", ], ) @@ -265,6 +266,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -314,6 +316,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", ], ) @@ -452,6 +455,7 @@ cc_library( 
"//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -520,6 +524,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -577,6 +582,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", + "@com_google_absl//absl/memory", ], ) @@ -618,6 +624,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:ptr_util", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], alwayslink = 1, ) @@ -650,6 +657,7 @@ cc_library( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -722,6 +730,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -739,6 +748,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:ptr_util", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -769,6 +779,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/stream_executor", + "@com_google_absl//absl/memory", ], ) @@ -816,6 +827,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -834,6 +846,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -850,6 +863,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -867,6 +881,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", ], ) @@ -926,6 +941,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/memory", ], ) @@ -953,6 +969,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", ], ) @@ -980,6 +997,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1034,6 +1052,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1052,6 +1071,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1068,6 +1088,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1085,6 +1106,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1186,6 +1208,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", 
"//tensorflow/compiler/xla:util", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", ], ) @@ -1203,6 +1226,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1274,6 +1298,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1297,6 +1322,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", ], ) @@ -1320,6 +1346,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1421,6 +1448,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1610,6 +1638,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1643,6 +1672,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/memory", ], ) @@ -1662,6 +1692,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], alwayslink = True, # Contains per-platform computation placer registration ) @@ -1753,6 +1784,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", ], ) @@ -1798,6 +1830,7 @@ tf_cc_binary( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1814,6 +1847,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1873,6 +1907,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1891,6 +1926,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1932,6 +1968,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2025,6 +2062,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", ], ) @@ -2037,7 +2075,6 @@ cc_library( ":hlo_dataflow_analysis", ":logical_buffer", ":logical_buffer_analysis", - "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -2045,6 +2082,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2095,6 +2133,7 @@ 
cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2185,6 +2224,7 @@ cc_library( ":shape_inference", "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2267,6 +2307,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -2348,6 +2389,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2385,6 +2427,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2401,6 +2444,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2432,6 +2476,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2446,6 +2491,7 @@ cc_library( "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2506,6 +2552,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -2606,10 +2653,10 @@ cc_library( ":computation_layout", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2790,9 +2837,9 @@ cc_library( hdrs = ["stream_pool.h"], deps = [ "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -2890,6 +2937,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", + "@com_google_absl//absl/memory", ], ) @@ -3085,6 +3133,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", ], ) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 2c539eb99aa..1d26e306519 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -541,7 +542,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { // If a literal is all the same element replace it with a scalar broadcast. 
if (ShapeUtil::ElementsIn(constant->shape()) > 1 && constant->literal().IsAllFirst()) { - std::unique_ptr unique_scalar = MakeUnique( + std::unique_ptr unique_scalar = absl::make_unique( LiteralUtil::GetFirstScalarLiteral(constant->literal())); HloInstruction* scalar = computation_->AddInstruction( HloInstruction::CreateConstant(std::move(unique_scalar))); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index d3785006d59..427069af5f4 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index 51ebc4763b6..d0806d24a22 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -17,8 +17,8 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -91,8 +91,9 @@ StatusOr AllocationTracker::RegisterInternal( // If ShapedBufferTy is ScopedShapedBuffer, release the ScopedShapedBuffer // into a regular ShapedBuffer, which is stored in // handle_to_shaped_buffers_. - handle_to_shaped_buffers_[handle].emplace_back(MakeUnique( - ReleaseIfScopedShapedBuffer(std::move(shaped_buffer)))); + handle_to_shaped_buffers_[handle].emplace_back( + absl::make_unique( + ReleaseIfScopedShapedBuffer(std::move(shaped_buffer)))); } GlobalDataHandle result; diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index d12be3e007f..841d0fa85bb 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/platform_util.h" @@ -127,8 +128,8 @@ Backend::Backend( } } // Create a memory allocator for the valid stream executors. - memory_allocator_ = - MakeUnique(platform, stream_executors); + memory_allocator_ = absl::make_unique( + platform, stream_executors); CHECK(!stream_executors_.empty()) << "Service found no devices for backend " << platform_->Name() << '.'; diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc index a7253514628..f62ab12319b 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -18,9 +18,9 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index cfd26fc778c..cc15c7122fc 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -22,8 +22,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_value_containers.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -1100,8 +1100,8 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( options.buffers_to_assign = &buffer_value_set; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique( - MakeUnique(alignment)), + HeapSimulator::Run(absl::make_unique( + absl::make_unique(alignment)), assignment->module(), module_sequence, assignment->points_to_analysis(), assignment->buffer_size_, options)); @@ -1130,11 +1130,12 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( options.buffers_to_assign = &buffer_value_set; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique( - MakeUnique(alignment)), - *computation, *instruction_sequence, - assignment->points_to_analysis(), - assignment->buffer_size_, options)); + HeapSimulator::Run( + absl::make_unique( + absl::make_unique(alignment)), + *computation, *instruction_sequence, + assignment->points_to_analysis(), assignment->buffer_size_, + options)); AssignBuffersFromHeapSimulator(result, assignment, single_colored_set.first); } @@ -1646,7 +1647,8 @@ StatusOr> BufferAssigner::CreateAssignment( XLA_VLOG_LINES(3, liveness->ToString()); XLA_VLOG_LINES(3, liveness->points_to_analysis().ToString()); - // Can't use MakeUnique because BufferAssignment constructor is private. + // Can't use absl::make_unique because BufferAssignment constructor is + // private. std::unique_ptr assignment( new BufferAssignment(module, std::move(liveness), std::move(buffer_size), std::move(color_alignment))); diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index eccb146a0d7..52abda16c4e 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -21,8 +21,8 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/copy_insertion.h" @@ -87,7 +87,7 @@ class BufferAssignmentTest : public HloTestBase { std::unique_ptr RunBufferAssignment(HloModule* module, int64 alignment = 1) { return BufferAssigner::Run( - module, xla::MakeUnique(module), + module, absl::make_unique(module), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -98,7 +98,7 @@ class BufferAssignmentTest : public HloTestBase { std::unique_ptr RunBufferAssignmentNoBuffersForConstants( HloModule* module, int64 alignment = 1) { return BufferAssigner::Run( - module, xla::MakeUnique(module), + module, absl::make_unique(module), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -109,7 +109,7 @@ class BufferAssignmentTest : public HloTestBase { std::unique_ptr RunColoredBufferAssignment( HloModule* module, BufferLiveness::Colorer colorer, int64 alignment = 1) { return BufferAssigner::Run( - module, xla::MakeUnique(module), + module, absl::make_unique(module), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -127,7 +127,8 @@ class BufferAssignmentTest : public HloTestBase { instruction_sequence.end()); return BufferAssigner::Run( module, - xla::MakeUnique(module, module_sequence), + absl::make_unique(module, + module_sequence), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -1769,7 +1770,8 @@ class WhileBufferAssignmentTest : public HloTestBase { auto sequence = ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); return BufferAssigner::Run( - module, xla::MakeUnique(module, sequence), + module, + absl::make_unique(module, sequence), ByteSizeOf, [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -2083,7 +2085,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { auto assignment, BufferAssigner::Run( module.get(), - xla::MakeUnique(module.get(), sequence), + absl::make_unique(module.get(), sequence), backend().compiler()->BufferSizeBytesFunction(), [](LogicalBuffer::Color) { return 1; }, /*allow_input_output_aliasing=*/false, @@ -2340,7 +2342,7 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { auto assignment = BufferAssigner::Run( module.get(), - xla::MakeUnique(module.get(), sequence), + absl::make_unique(module.get(), sequence), ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, /*allow_input_output_aliasing=*/false, /*allocate_buffers_for_constants=*/true) diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 4a927b57674..3ffb7de65fb 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -18,7 +18,7 @@ limitations under the License. 
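// Sketch of the ownership-transfer pattern used throughout these liveness and
// assignment tests: a concrete ordering is built with absl::make_unique and
// handed to the analysis as a std::unique_ptr to the base class. The types
// below are simplified stand-ins, not the real XLA classes.
#include <iostream>
#include <memory>
#include "absl/memory/memory.h"

struct Ordering {
  virtual ~Ordering() = default;
  virtual const char* name() const = 0;
};

struct DependencyOrdering : Ordering {
  const char* name() const override { return "dependency"; }
};

// Mirrors BufferLiveness::Run(module, std::unique_ptr<HloOrdering> ordering).
void RunAnalysis(std::unique_ptr<Ordering> ordering) {
  std::cout << "running analysis with " << ordering->name() << " ordering\n";
}

int main() {
  // unique_ptr<DependencyOrdering> converts implicitly to unique_ptr<Ordering>.
  RunAnalysis(absl::make_unique<DependencyOrdering>());
  return 0;
}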
#include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -119,8 +119,8 @@ TEST_F(BufferLivenessTest, ElementwiseChain) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate)); @@ -167,10 +167,10 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) { SequentialHloOrdering::HloModuleSequence sequence; sequence.insert({entry, {param0, negate, param1, exp, add}}); - auto liveness = - BufferLiveness::Run(module.get(), xla::MakeUnique( - module.get(), sequence)) - .ConsumeValueOrDie(); + auto liveness = BufferLiveness::Run(module.get(), + absl::make_unique( + module.get(), sequence)) + .ConsumeValueOrDie(); // Entry parameters interfere as if they are defined simultaneously at // the very beginning. @@ -215,8 +215,8 @@ TEST_F(BufferLivenessTest, NonElementwiseOperand) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp)); @@ -249,8 +249,8 @@ TEST_F(BufferLivenessTest, OverlappedBuffers) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); @@ -293,10 +293,10 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { SequentialHloOrdering::HloModuleSequence module_sequence; std::vector order = {param, negate, exp, add}; module_sequence.emplace(computation, order); - auto liveness = - BufferLiveness::Run(module.get(), xla::MakeUnique( - module.get(), module_sequence)) - .ConsumeValueOrDie(); + auto liveness = BufferLiveness::Run(module.get(), + absl::make_unique( + module.get(), module_sequence)) + .ConsumeValueOrDie(); EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp)); @@ -342,10 +342,10 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) { std::vector order = {param, add, recv, recv_done, send, send_done}; module_sequence.emplace(computation, order); - auto liveness = - BufferLiveness::Run(module.get(), xla::MakeUnique( - module.get(), module_sequence)) - .ConsumeValueOrDie(); + auto liveness = BufferLiveness::Run(module.get(), + absl::make_unique( + module.get(), module_sequence)) + .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add)); // Check the root instruction (add) buffer interferes with the recv buffer. 
@@ -376,8 +376,8 @@ TEST_F(BufferLivenessTest, TupleLiveOut) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); // All buffers should be live out except the param @@ -412,8 +412,8 @@ TEST_F(BufferLivenessTest, EmbeddedComputation) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); // Buffers in different computations should always interfere. @@ -453,8 +453,8 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); // Only the element buffers of the tuple constant which are pointed to by @@ -518,8 +518,8 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) { module->AddEmbeddedComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); // We compare tuple element pairs that are input/output to the computation: @@ -580,8 +580,8 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { module->AddEmbeddedComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); // We compare tuple element pairs that are input/output to the computation: @@ -668,10 +668,10 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { } // Run BufferLiveness on 'module'. - auto liveness = - BufferLiveness::Run( - module.get(), xla::MakeUnique(module.get())) - .ConsumeValueOrDie(); + auto liveness = BufferLiveness::Run( + module.get(), + absl::make_unique(module.get())) + .ConsumeValueOrDie(); // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}); @@ -780,10 +780,10 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { module->AddEntryComputation(BuildDummyComputation()); module->AddEmbeddedComputation(builder.Build()); // Run BufferLiveness on 'module'. - auto liveness = - BufferLiveness::Run( - module.get(), xla::MakeUnique(module.get())) - .ConsumeValueOrDie(); + auto liveness = BufferLiveness::Run( + module.get(), + absl::make_unique(module.get())) + .ConsumeValueOrDie(); // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}); diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index 985ff30e80a..d6efef5f12f 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -17,8 +17,8 @@ limitations under the License. 
 
 #include <queue>
 
+#include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/lib/core/errors.h"
@@ -237,8 +237,8 @@ void CallGraph::SetCallContexts() {
 
 /* static */
 std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) {
-  // Constructor for CallGraph is private so MakeUnique can't be used.
-  auto call_graph = WrapUnique(new CallGraph(module));
+  // Constructor for CallGraph is private so absl::make_unique can't be used.
+  auto call_graph = absl::WrapUnique(new CallGraph(module));
 
   VLOG(2) << "Building call graph for:";
   XLA_VLOG_LINES(2, module->ToString());
diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc
index ff968bca297..e75f6f146d7 100644
--- a/tensorflow/compiler/xla/service/call_inliner_test.cc
+++ b/tensorflow/compiler/xla/service/call_inliner_test.cc
@@ -18,9 +18,9 @@ limitations under the License.
 #include <memory>
 #include <utility>
 
+#include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
diff --git a/tensorflow/compiler/xla/service/channel_tracker.cc b/tensorflow/compiler/xla/service/channel_tracker.cc
index 13008efed14..9c9e373821d 100644
--- a/tensorflow/compiler/xla/service/channel_tracker.cc
+++ b/tensorflow/compiler/xla/service/channel_tracker.cc
@@ -15,7 +15,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/channel_tracker.h"
 
-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/status.h"
diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc
index 187ce568cbb..afbbea35b89 100644
--- a/tensorflow/compiler/xla/service/computation_placer.cc
+++ b/tensorflow/compiler/xla/service/computation_placer.cc
@@ -19,8 +19,8 @@ limitations under the License.
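// Sketch of why CallGraph::Build above keeps absl::WrapUnique: make_unique
// cannot construct a class whose constructor is private, because make_unique
// itself performs the `new`. `Widget` is a hypothetical class, not an XLA type.
#include <memory>
#include "absl/memory/memory.h"

class Widget {
 public:
  static std::unique_ptr<Widget> Create(int id) {
    // absl::make_unique<Widget>(id) would fail to compile even here, since the
    // private constructor would be invoked from inside make_unique, not from a
    // member of Widget. WrapUnique merely adopts the pointer we new ourselves.
    return absl::WrapUnique(new Widget(id));
  }
  int id() const { return id_; }

 private:
  explicit Widget(int id) : id_(id) {}
  int id_;
};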
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -60,8 +60,8 @@ DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) { "computation_count=%d", proto.replica_count(), proto.computation_count()); } - auto assignment = MakeUnique(proto.replica_count(), - proto.computation_count()); + auto assignment = absl::make_unique( + proto.replica_count(), proto.computation_count()); for (int computation = 0; computation < proto.computation_count(); ++computation) { const auto& computation_device = proto.computation_devices(computation); @@ -156,7 +156,7 @@ ComputationPlacer::GetPlatformComputationPlacers() { } // namespace xla static std::unique_ptr CreateComputationPlacer() { - return xla::MakeUnique(); + return absl::make_unique(); } static bool InitModule() { diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc index 45252fc1eee..8affa08b652 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -214,7 +214,7 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { expanded_filter = add(HloInstruction::CreateConcatenate( expanded_filter_shape, concat_operands, input_feature_dim)); } - auto zero = add(HloInstruction::CreateConstant(MakeUnique( + auto zero = add(HloInstruction::CreateConstant(absl::make_unique( LiteralUtil::Zero(expanded_filter_shape.element_type())))); auto zero_filter = add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {})); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 9cad6749345..850948b54b8 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -50,6 +50,7 @@ cc_library( "//tensorflow/compiler/xla/service/cpu:cpu_runtime", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], alwayslink = True, # Contains per-platform transfer manager registration ) @@ -85,6 +86,7 @@ cc_library( ":ir_emitter", ":parallel_task_assignment", ":simple_orc_jit", + "@com_google_absl//absl/memory", "//tensorflow/compiler/tf2xla:cpu_function_runtime", "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla:literal", @@ -178,6 +180,7 @@ cc_library( ":runtime_single_threaded_conv2d", ":runtime_single_threaded_fft", ":runtime_single_threaded_matmul", + "@com_google_absl//absl/memory", "@llvm//:execution_engine", "@llvm//:core", "@llvm//:mc", # fixdeps: keep @@ -418,6 +421,7 @@ cc_library( "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", "@llvm//:analysis", "@llvm//:core", 
"@llvm//:ipo", @@ -634,6 +638,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//third_party/eigen3", + "@com_google_absl//absl/memory", ], ) @@ -810,6 +815,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_pass", + "@com_google_absl//absl/memory", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 128eea4828b..73b03440cbb 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -35,7 +36,6 @@ limitations under the License. #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/AlwaysInliner.h" #include "llvm/Transforms/IPO/PassManagerBuilder.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -205,7 +205,7 @@ void CompilerFunctor::AddTargetInfoPasses( llvm::legacy::PassManagerBase* passes) const { llvm::Triple target_triple(target_machine_->getTargetTriple()); auto target_library_info_impl = - MakeUnique(target_triple); + absl::make_unique(target_triple); target_library_info_impl->addVectorizableFunctions( VectorFunctionsForTargetLibraryInfoImpl()); passes->add( diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index fde8fbd4862..5116f926f50 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -26,6 +26,7 @@ limitations under the License. // IWYU pragma: no_include "llvm/Config/Disassemblers.def.inc" // IWYU pragma: no_include "llvm/Config/Targets.def.inc" +#include "absl/memory/memory.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Triple.h" #include "llvm/IR/Function.h" @@ -42,7 +43,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/batch_dot_simplification.h" #include "tensorflow/compiler/xla/service/batchnorm_expander.h" @@ -453,7 +453,7 @@ Status CreateHloProfilingArtifacts( computation_to_profile_idx, std::unique_ptr* hlo_profile_index_map, std::unique_ptr* hlo_profile_printer_data) { - *hlo_profile_index_map = MakeUnique(module); + *hlo_profile_index_map = absl::make_unique(module); const HloComputation& entry_computation = *module.entry_computation(); TF_ASSIGN_OR_RETURN( @@ -520,11 +520,11 @@ StatusOr> CpuCompiler::RunBackend( &pre_optimization_ir_hook, &post_optimization_ir_hook)); // Compile must be thread-safe so create a new LLVM context for the module. 
- auto llvm_context = xla::MakeUnique(); + auto llvm_context = absl::make_unique(); auto llvm_module = - xla::MakeUnique("__compute_module", *llvm_context); + absl::make_unique("__compute_module", *llvm_context); - auto jit = xla::MakeUnique( + auto jit = absl::make_unique( CompilerTargetOptions(module->config()), CodeGenOptLevel(module->config()), options::OptimizeForSizeRequested(module->config()), @@ -566,12 +566,12 @@ StatusOr> CpuCompiler::RunBackend( // temporary buffers are required to run the computation. TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, - BufferAssigner::Run( - module.get(), - xla::MakeUnique(module.get(), module_sequence), - BufferSizeBytesFunction(), memory_alignment, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true)); + BufferAssigner::Run(module.get(), + absl::make_unique( + module.get(), module_sequence), + BufferSizeBytesFunction(), memory_alignment, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); @@ -716,7 +716,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, llvm::StringRef cpu_name = llvm_ir::AsStringRef(options.cpu_name()); llvm::StringRef features = llvm_ir::AsStringRef(options.features()); llvm::CodeGenOpt::Level opt_level = CodeGenOptLevel(modules[0]->config()); - std::unique_ptr target_machine = WrapUnique( + std::unique_ptr target_machine = absl::WrapUnique( target->createTargetMachine(triple.getTriple(), cpu_name, features, CompilerTargetOptions(modules[0]->config()), reloc_model, llvm::None, opt_level)); @@ -757,7 +757,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, std::unique_ptr assignment, BufferAssigner::Run( module, - xla::MakeUnique(module, module_sequence), + absl::make_unique(module, module_sequence), BufferSizeBytesFunction(), memory_alignment, /*allow_input_output_aliasing=*/false, /*allocate_buffers_for_constants=*/true)); @@ -851,7 +851,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, assignment->GetUniqueTopLevelOutputSlice()); - results.emplace_back(MakeUnique( + results.emplace_back(absl::make_unique( std::move(object_file_data), std::move(buffer_infos), result_slice.index(), std::move(hlo_profile_printer_data))); } @@ -874,7 +874,7 @@ HloCostAnalysis::ShapeSizeFunction CpuCompiler::ShapeSizeBytesFunction() const { static bool InitModule() { xla::Compiler::RegisterCompilerFactory( stream_executor::host::kHostPlatformId, - []() { return xla::MakeUnique(); }); + []() { return absl::make_unique(); }); return true; } static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc index 2ac950e6d93..bc4cfc09996 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc @@ -19,10 +19,10 @@ limitations under the License. 
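// Sketch of the registration idiom in the cpu_compiler.cc InitModule change
// above: a platform id is mapped to a factory lambda that builds the compiler
// with absl::make_unique. The string-keyed registry is a simplified stand-in
// for the real Compiler::RegisterCompilerFactory machinery.
#include <functional>
#include <map>
#include <memory>
#include <string>
#include "absl/memory/memory.h"

struct Compiler {
  virtual ~Compiler() = default;
};
struct HostCompiler : Compiler {};

using CompilerFactory = std::function<std::unique_ptr<Compiler>()>;

std::map<std::string, CompilerFactory>& CompilerRegistry() {
  static auto* registry = new std::map<std::string, CompilerFactory>();
  return *registry;
}

bool RegisterHostCompiler() {
  CompilerRegistry()["host"] = [] { return absl::make_unique<HostCompiler>(); };
  return true;
}

// Mirrors `static bool module_initialized = InitModule();` in the patch.
static bool host_compiler_registered = RegisterHostCompiler();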
#include #include +#include "absl/memory/memory.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" @@ -46,7 +46,7 @@ std::unique_ptr> MaybeTransposeArray2D(const Array2D& array, if (transpose) { std::swap(output_width, output_height); } - auto output = MakeUnique>(output_height, output_width); + auto output = absl::make_unique>(output_height, output_width); for (int y = 0; y < array.height(); y++) { for (int x = 0; x < array.width(); x++) { if (transpose) { @@ -93,7 +93,7 @@ std::unique_ptr> EigenMatrixMultiply(const Array2D& a, // Since we're going to transpose c before returning it. Swap the order of the // dimension sizes to ensure the returned array is properly dimensioned. - auto c_transpose = MakeUnique>(n, m); + auto c_transpose = absl::make_unique>(n, m); if (single_threaded) { __xla_cpu_runtime_EigenSingleThreadedMatMulF32( nullptr, c_transpose->data(), a_transpose->data(), b_transpose->data(), @@ -204,7 +204,7 @@ std::unique_ptr> MKLMatrixMultiply(const Array2D& a, // Since we're going to transpose c before returning it, swap the order of the // dimension sizes to ensure the returned array is properly dimensioned. - auto c_transpose = MakeUnique>(n, m); + auto c_transpose = absl::make_unique>(n, m); if (single_threaded) { __xla_cpu_runtime_MKLSingleThreadedMatMulF32( nullptr, c_transpose->data(), a_transpose->data(), b_transpose->data(), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index 59bc7e0e16f..b07cd675ffc 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" @@ -256,7 +257,7 @@ StatusOr CpuTransferManager::TransferBuffersFromOutfeedInternal( VLOG(2) << "Enqueueing outfeed buffer (for the device to populate) of length " << size_32 << "B"; - buffers.emplace_back(MakeUnique(b.first, size_32)); + buffers.emplace_back(absl::make_unique(b.first, size_32)); } std::vector buffer_pointers; @@ -283,7 +284,7 @@ StatusOr CpuTransferManager::TransferBuffersFromOutfeedInternal( } // namespace xla static std::unique_ptr CreateCpuTransferManager() { - return xla::MakeUnique(); + return absl::make_unique(); } static bool InitModule() { diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index 4fa5984b046..286d407ca6e 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -15,6 +15,7 @@ limitations under the License. 
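// Sketch of the container-filling pattern from the outfeed code above: a
// std::vector of owning pointers is populated with
// emplace_back(absl::make_unique<T>(...)). `Buffer` is a placeholder for the
// real CpuOutfeedBuffer type.
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>
#include "absl/memory/memory.h"

struct Buffer {
  Buffer(void* data, int32_t size) : data(data), size(size) {}
  void* data;
  int32_t size;
};

std::vector<std::unique_ptr<Buffer>> CollectBuffers(
    const std::vector<std::pair<void*, int32_t>>& raw) {
  std::vector<std::unique_ptr<Buffer>> buffers;
  buffers.reserve(raw.size());
  for (const auto& entry : raw) {
    buffers.emplace_back(absl::make_unique<Buffer>(entry.first, entry.second));
  }
  return buffers;
}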
#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" @@ -109,7 +110,7 @@ ParallelTaskAssignment::ParallelTaskAssignment( : target_machine_features_(*target_machine_features) { VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism; // Run cost analysis on 'module'. - auto cost_analysis = MakeUnique(shape_size); + auto cost_analysis = absl::make_unique(shape_size); HloComputation* computation = module->entry_computation(); Status status = computation->root_instruction()->Accept(cost_analysis.get()); if (status.ok()) { diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index be772cfb7e5..b026aef3fec 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -20,13 +20,13 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "llvm/ExecutionEngine/ExecutionEngine.h" #include "llvm/ExecutionEngine/JITSymbol.h" #include "llvm/ExecutionEngine/SectionMemoryManager.h" #include "llvm/IR/Mangler.h" #include "llvm/Support/CodeGen.h" #include "llvm/Support/Host.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h" diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 181cec3cddd..4635fa5d74f 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -51,6 +51,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -94,6 +95,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:filecheck", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", "@llvm//:core", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc index d98856fdbf4..b68ac67574d 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc index 01daed4bcd3..bb105194f1c 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -16,9 +16,9 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -62,7 +62,8 @@ TEST_F(CpuNoAliasTest, Concat) { // Now that we have an HLO module, build an llvm_ir::AliasAnalysis for it. auto status_or_buffer_assn = BufferAssigner::Run( - hlo_module.get(), MakeUnique(hlo_module.get()), + hlo_module.get(), + absl::make_unique(hlo_module.get()), backend().compiler()->BufferSizeBytesFunction(), [](LogicalBuffer::Color) { return /*alignment=*/1; }); ASSERT_EQ(status_or_buffer_assn.status(), Status::OK()); diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index fd75847d0c0..1c9f396b68f 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/executable.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/status.h" @@ -76,8 +77,8 @@ StatusOr Executable::ExecuteOnStreamWrapper( std::unique_ptr profile_ptr = module_config().debug_options().xla_hlo_profile() && hlo_profiling_enabled() - ? MakeUnique(&hlo_profile_printer_data(), - &hlo_profile_index_map()) + ? absl::make_unique(&hlo_profile_printer_data(), + &hlo_profile_index_map()) : nullptr; StatusOr return_value = diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc index 228c3fac95c..70a78c8a2b6 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.cc +++ b/tensorflow/compiler/xla/service/execution_tracker.cc @@ -17,7 +17,7 @@ limitations under the License. 
#include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -53,8 +53,8 @@ ExecutionHandle ExecutionTracker::Register(Backend* backend, tensorflow::mutex_lock lock(execution_mutex_); int64 handle = next_handle_++; auto inserted = handle_to_execution_.emplace( - handle, - MakeUnique(backend, std::move(streams), profile, result)); + handle, absl::make_unique(backend, std::move(streams), + profile, result)); CHECK(inserted.second); ExecutionHandle execution_handle; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index fd1e34a5477..17eefc430d2 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -56,6 +56,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -91,6 +92,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_reachability", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -107,6 +109,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -181,6 +184,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", "@llvm//:core", "@llvm//:support", ], @@ -244,6 +248,7 @@ cc_library( "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -258,6 +263,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:ptr_util", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -338,6 +344,7 @@ cc_library( "//tensorflow/core/platform/default/build_config:cufft_plugin", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep "//tensorflow/stream_executor", + "@com_google_absl//absl/memory", ], ) @@ -547,6 +554,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_creation_utils", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:shape_inference", + "@com_google_absl//absl/memory", ], ) @@ -603,6 +611,7 @@ cc_library( "//tensorflow/compiler/xla/service/gpu:infeed_manager", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", "@llvm//:core", ], alwayslink = True, # Contains per-platform transfer manager registration @@ -673,6 +682,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", "@llvm//:core", ], alwayslink = True, # Contains compiler registration @@ -705,8 +715,8 @@ cc_library( ":xfeed_queue", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -721,6 +731,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -770,12 +781,12 @@ cc_library( ":stream_assignment", 
"//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:buffer_value", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_reachability", "//tensorflow/compiler/xla/service:hlo_scheduling", + "@com_google_absl//absl/memory", ], ) @@ -792,6 +803,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/memory", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc index 537295292b6..e208ad61e33 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -17,8 +17,8 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -40,7 +40,7 @@ StatusOr> BufferAllocations::Builder::Build( const BufferAssignment* buffer_assignment, int device_ordinal, DeviceMemoryAllocator* memory_allocator) { const int64 num_buffers = buffer_assignment->Allocations().size(); - auto buffer_allocations = WrapUnique(new BufferAllocations( + auto buffer_allocations = absl::WrapUnique(new BufferAllocations( num_buffers, device_ordinal, memory_allocator, buffer_assignment)); for (BufferAllocation::Index i = 0; i < num_buffers; ++i) { diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc index 5780e0af406..8b0426aa27f 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc index 2fd2206324e..88f0b4d71c9 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/for_thunk.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -28,7 +28,7 @@ ForThunk::ForThunk(const int64 loop_limit, const HloInstruction* hlo) : Thunk(Kind::kWhile, hlo), loop_limit_(loop_limit), - body_thunk_sequence_(MakeUnique( + body_thunk_sequence_(absl::make_unique( // Pass nullptr as the HloInstruction* to the body_thunk_sequence_ // constructor because this SequentialThunk is logically "part of" // this ForThunk, and shouldn't be profiled separately from it. 
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 70608379048..a1fbd8022db 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -144,7 +144,7 @@ Status GpuExecutable::ExecuteThunks( TF_RETURN_IF_ERROR( thunk->ExecuteOnStream(buffer_allocations, stream, &profiler)); if (thunk_schedule_->Depended(thunk)) { - auto finish_event = MakeUnique(main_stream->parent()); + auto finish_event = absl::make_unique(main_stream->parent()); finish_event->Init(); stream->ThenRecordEvent(finish_event.get()); thunk_to_finish_event[thunk] = std::move(finish_event); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index a2f53f84461..44303724bb5 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "llvm/IR/DataLayout.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -160,9 +161,10 @@ Status GpuTransferManager::TransferLiteralFromOutfeed( if (ShapeUtil::IsTuple(shape)) { return; } - *buffer = MakeUnique(GetByteSizeRequirement(shape)); + *buffer = absl::make_unique( + GetByteSizeRequirement(shape)); (*buffer)->set_destination( - MakeUnique(literal, index)); + absl::make_unique(literal, index)); }); // Give the tree of buffers to the outfeed mananger. The device will fill it @@ -179,7 +181,7 @@ Status GpuTransferManager::TransferLiteralFromOutfeed( } // namespace xla static std::unique_ptr CreateNVPTXTransferManager() { - return xla::MakeUnique( + return absl::make_unique( /*id=*/stream_executor::cuda::kCudaPlatformId, /*pointer_size=*/llvm::DataLayout(xla::gpu::NVPTXCompiler::kDataLayout) .getPointerSize(0 /* default address space */)); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc index 17226769302..b9c21e8edb2 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -33,7 +34,7 @@ namespace gpu { namespace { void InitAndStartTimer(std::stack>* timers, se::Stream* stream) { - timers->push(MakeUnique(stream->parent())); + timers->push(absl::make_unique(stream->parent())); stream->InitTimer(timers->top().get()).ThenStartTimer(timers->top().get()); } @@ -115,7 +116,7 @@ HloExecutionProfiler::MakeScopedInstructionProfiler( CHECK(hlo_instructions_.insert(hlo_instruction).second) << hlo_instruction->name(); } - return MakeUnique(this, hlo_instruction); + return absl::make_unique(this, hlo_instruction); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc index 19de37b0fbe..76055ff009c 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" @@ -59,8 +59,8 @@ GpuHloOrdering::GpuHloOrdering( : PredecessorHloOrdering(module) { // The entry computation has a total order when there's only one stream. if (stream_assignment.StreamCount() == 1) { - entry_sequence_ = - MakeUnique>(thunk_launch_order); + entry_sequence_ = absl::make_unique>( + thunk_launch_order); } // The ordering of instructions for the entry computation is determined by the @@ -75,7 +75,7 @@ GpuHloOrdering::GpuHloOrdering( // same-stream predecessors of each instruction. // Compute the set of all instructions we will want to set reachability on. - auto predecessor_map = MakeUnique( + auto predecessor_map = absl::make_unique( module->entry_computation()->MakeInstructionPostOrder()); // The most recently visited instruction per stream. @@ -208,7 +208,7 @@ StatusOr> HloSchedule::Build( BFSLaunchOrder(entry_computation, &schedule->thunk_launch_order_); } - schedule->hlo_ordering_ = MakeUnique( + schedule->hlo_ordering_ = absl::make_unique( &module, stream_assignment, schedule->thunk_launch_order_); return std::move(schedule); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc index 45f0a1c645b..d4a96cd5b35 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -47,7 +48,7 @@ class HloScheduleTest : public HloTestBase { auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); config.set_debug_options(debug_options); - return MakeUnique<HloModule>("test_module", config); + return absl::make_unique<HloModule>("test_module", config); } HloVec RemoveHlo(const HloVec& input, diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc index c5f0cdf6cd5..a4364b0deb6 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" namespace xla { namespace gpu { @@ -24,7 +24,7 @@ se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) { tensorflow::mutex_lock l(host_to_device_stream_mu_); if (host_to_device_executor_ == nullptr) { host_to_device_executor_ = executor; - host_to_device_stream_ = MakeUnique<se::Stream>(executor); + host_to_device_stream_ = absl::make_unique<se::Stream>(executor); host_to_device_stream_->Init(); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 71c30e19a2a..dea2a31920e 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" @@ -30,7 +31,6 @@ limitations under the License. 
#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" @@ -384,7 +384,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { int64 feature_index_value = feature_index->literal().Get({}); thunk_sequence_->emplace_back( - MakeUnique( + absl::make_unique( /*operand=*/GetAllocationSlice(*custom_call->operand(0)), /*scale=*/GetAllocationSlice(*custom_call->operand(1)), /*offset=*/GetAllocationSlice(*custom_call->operand(2)), @@ -414,7 +414,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { auto output_mean = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); auto output_inv_stddev = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie(); thunk_sequence_->emplace_back( - MakeUnique( + absl::make_unique( /*operand=*/GetAllocationSlice(*custom_call->operand(0)), /*scale=*/GetAllocationSlice(*custom_call->operand(1)), /*offset=*/GetAllocationSlice(*custom_call->operand(2)), @@ -444,19 +444,20 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { auto output_grad_scale = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); auto output_grad_offset = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie(); - thunk_sequence_->emplace_back(MakeUnique( - /*operand=*/GetAllocationSlice(*custom_call->operand(0)), - /*scale=*/GetAllocationSlice(*custom_call->operand(1)), - /*mean=*/GetAllocationSlice(*custom_call->operand(2)), - /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)), - /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)), - /*epsilon=*/epsilon_value, - /*feature_index=*/feature_index_value, - /*output_grad_data=*/output_grad_data, - /*output_grad_scale=*/output_grad_scale, - /*output_grad_offset=*/output_grad_offset, - /*output_tuple=*/GetAllocationSlice(*custom_call), - /*hlo=*/custom_call)); + thunk_sequence_->emplace_back( + absl::make_unique( + /*operand=*/GetAllocationSlice(*custom_call->operand(0)), + /*scale=*/GetAllocationSlice(*custom_call->operand(1)), + /*mean=*/GetAllocationSlice(*custom_call->operand(2)), + /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)), + /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)), + /*epsilon=*/epsilon_value, + /*feature_index=*/feature_index_value, + /*output_grad_data=*/output_grad_data, + /*output_grad_scale=*/output_grad_scale, + /*output_grad_offset=*/output_grad_offset, + /*output_tuple=*/GetAllocationSlice(*custom_call), + /*hlo=*/custom_call)); return Status::OK(); } @@ -476,7 +477,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { const auto& target = custom_call->custom_call_target(); std::unique_ptr thunk; if (target == kCudnnConvForwardCallTarget) { - thunk = MakeUnique( + thunk = absl::make_unique( CudnnConvKind::kForward, /*input_buffer=*/lhs_slice, /*filter_buffer=*/rhs_slice, @@ -490,7 +491,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { backend_config.algorithm(), backend_config.tensor_ops_enabled(), custom_call); } else if (target == kCudnnConvBackwardInputCallTarget) { - thunk = MakeUnique( + thunk = absl::make_unique( CudnnConvKind::kBackwardInput, /*input_buffer=*/conv_result_slice, /*filter_buffer=*/rhs_slice, @@ -504,7 +505,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* 
custom_call) { backend_config.algorithm(), backend_config.tensor_ops_enabled(), custom_call); } else if (target == kCudnnConvBackwardFilterCallTarget) { - thunk = MakeUnique( + thunk = absl::make_unique( CudnnConvKind::kBackwardFilter, /*input_buffer=*/lhs_slice, /*filter_buffer=*/conv_result_slice, @@ -577,7 +578,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { thunks.push_back( BuildKernelThunk(fusion, /*implements_whole_instruction=*/false)); thunk_sequence_->emplace_back( - MakeUnique(std::move(thunks), fusion)); + absl::make_unique(std::move(thunks), fusion)); std::vector parameter_arrays; for (HloInstruction* operand : fusion->operands()) { parameter_arrays.push_back(GetIrArray(*operand, *fusion)); @@ -1719,7 +1720,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { thunks.push_back( BuildKernelThunk(reduce, /*implements_whole_instruction=*/false)); thunk_sequence_->emplace_back( - MakeUnique(std::move(thunks), reduce)); + absl::make_unique(std::move(thunks), reduce)); return EmitReductionToVector( reduce, input->shape(), {[&](const IrArray::Index& index) { @@ -1761,7 +1762,7 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { for (const HloInstruction* tuple_element : tuple->operands()) { tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element)); } - thunk_sequence_->emplace_back(MakeUnique( + thunk_sequence_->emplace_back(absl::make_unique( tuple_element_buffers, GetAllocationSlice(*tuple), tuple)); return Status::OK(); } @@ -1793,8 +1794,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( thunks.push_back(std::move(initializer_thunk)); thunks.push_back(BuildKernelThunk(select_and_scatter, /*implements_whole_instruction=*/false)); - thunk_sequence_->emplace_back( - MakeUnique(std::move(thunks), select_and_scatter)); + thunk_sequence_->emplace_back(absl::make_unique( + std::move(thunks), select_and_scatter)); // TODO(b/31410564): Implement dilation rate for select-and-scatter. if (window_util::HasDilation(window)) { @@ -2019,7 +2020,7 @@ Status IrEmitterUnnested::HandleRng(HloInstruction* rng) { thunks.push_back(std::move(rng_thunk)); thunks.push_back(std::move(increment_seed_thunk)); thunk_sequence_->emplace_back( - MakeUnique(std::move(thunks), rng)); + absl::make_unique(std::move(thunks), rng)); return Status::OK(); } @@ -2044,7 +2045,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { auto values_destination = GetAllocationSlice(*sort, values_shape_index); if (keys_destination != GetAllocationSlice(*keys)) { - thunks.push_back(MakeUnique( + thunks.push_back(absl::make_unique( /*source_address=*/GetAllocationSlice(*keys), /*destination_buffer=*/keys_destination, /*mem_size=*/ShapeUtil::ByteSizeOf(keys->shape()), nullptr)); @@ -2052,7 +2053,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { if (values != nullptr && values_destination != GetAllocationSlice(*values)) { // TODO(b/26783907): Figure out why we never seem to share buffers for // key/value sort. 
- thunks.push_back(MakeUnique( + thunks.push_back(absl::make_unique( /*source_address=*/GetAllocationSlice(*values), /*destination_buffer=*/values_destination, /*mem_size=*/ShapeUtil::ByteSizeOf(values->shape()), nullptr)); @@ -2104,7 +2105,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { } thunk_sequence_->emplace_back( - MakeUnique(std::move(thunks), sort)); + absl::make_unique(std::move(thunks), sort)); return Status::OK(); } @@ -2131,7 +2132,7 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { if (crs->operand_count() == 1) { CHECK(ShapeUtil::IsArray(crs->operand(0)->shape())) << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); - thunk_sequence_->push_back(MakeUnique( + thunk_sequence_->push_back(absl::make_unique( /*source_address=*/GetAllocationSlice(*crs->operand(0)), /*destination_buffer=*/GetAllocationSlice(*crs), /*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()), crs)); @@ -2146,17 +2147,17 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { tuple_element_buffers.push_back(ir_emitter_context_->buffer_assignment() .GetUniqueSlice(crs, {i}) .ValueOrDie()); - thunks.push_back(MakeUnique( + thunks.push_back(absl::make_unique( /*source_address=*/GetAllocationSlice(*crs->operand(i)), /*destination_buffer=*/tuple_element_buffers.back(), /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), nullptr)); } // Output a tuple of the buffers above. - thunks.push_back(MakeUnique(tuple_element_buffers, - GetAllocationSlice(*crs), nullptr)); + thunks.push_back(absl::make_unique( + tuple_element_buffers, GetAllocationSlice(*crs), nullptr)); thunk_sequence_->push_back( - MakeUnique(std::move(thunks), crs)); + absl::make_unique(std::move(thunks), crs)); return Status::OK(); } @@ -2390,7 +2391,7 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( llvm::ConstantPointerNull::get(b_.getInt8PtrTy())); } - return MakeUnique( + return absl::make_unique( non_constant_buffers, llvm_ir::AsString(kernel->getName()), implements_whole_instruction ? 
inst : nullptr, unroll_factor); } @@ -2399,7 +2400,7 @@ std::unique_ptr IrEmitterUnnested::BuildHostToDeviceCopyThunk( const HloInstruction* inst) { const HloInstruction* operand = inst->operand(0); CHECK_EQ(HloOpcode::kConstant, operand->opcode()); - return MakeUnique( + return absl::make_unique( /*source_address=*/operand->literal().untyped_data(), /*destination_buffer=*/GetAllocationSlice(*inst), /*mem_size=*/ @@ -2411,7 +2412,7 @@ std::unique_ptr IrEmitterUnnested::BuildHostToDeviceCopyThunk( std::unique_ptr IrEmitterUnnested::BuildDeviceToDeviceCopyThunk( const HloInstruction* inst) { const HloInstruction* operand = inst->operand(0); - return MakeUnique( + return absl::make_unique( /*source_address=*/GetAllocationSlice(*operand), /*destination_buffer=*/GetAllocationSlice(*inst), /*mem_size=*/ @@ -2431,7 +2432,7 @@ std::unique_ptr IrEmitterUnnested::BuildInfeedThunk( .GetUniqueSlice(inst, index) .ConsumeValueOrDie(); }); - return MakeUnique(slices, inst); + return absl::make_unique(slices, inst); } std::unique_ptr IrEmitterUnnested::BuildOutfeedThunk( @@ -2448,7 +2449,7 @@ std::unique_ptr IrEmitterUnnested::BuildOutfeedThunk( *slice = status_or_slice.ConsumeValueOrDie(); } }); - return MakeUnique(std::move(slices), inst); + return absl::make_unique(std::move(slices), inst); } namespace { @@ -2471,7 +2472,7 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( if (inst->opcode() == HloOpcode::kDot) { const HloInstruction* lhs = inst->operand(0); const HloInstruction* rhs = inst->operand(1); - return MakeUnique( + return absl::make_unique( GetAllocationSlice(*lhs), // The buffer assigned to LHS. GetAllocationSlice(*rhs), // The buffer assigned to RHS. GetAllocationSlice(*inst), // The output buffer. @@ -2513,7 +2514,7 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( const HloInstruction* rhs = inst->operand(rhs_parameter->parameter_number()); - return MakeUnique( + return absl::make_unique( GetAllocationSlice(*lhs), // The buffer assigned to LHS. GetAllocationSlice(*rhs), // The buffer assigned to RHS. GetAllocationSlice(*inst), // The output buffer. 
@@ -2530,11 +2531,12 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( std::unique_ptr IrEmitterUnnested::BuildFftThunk( const HloInstruction* inst) { const HloInstruction* operand = inst->operand(0); - return MakeUnique(inst->fft_type(), inst->fft_length(), - /*input_buffer=*/GetAllocationSlice(*operand), - /*output_buffer=*/GetAllocationSlice(*inst), - /*input_shape=*/operand->shape(), - /*output_shape=*/inst->shape(), inst); + return absl::make_unique( + inst->fft_type(), inst->fft_length(), + /*input_buffer=*/GetAllocationSlice(*operand), + /*output_buffer=*/GetAllocationSlice(*inst), + /*input_shape=*/operand->shape(), + /*output_shape=*/inst->shape(), inst); } StatusOr> IrEmitterUnnested::BuildInitializerThunk( @@ -2584,8 +2586,8 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( ArraySlice literal_bytes( reinterpret_cast(literal.untyped_data()), num_bytes); if (absl::c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { - return { - MakeUnique(GetAllocationSlice(*hlo, index), nullptr)}; + return {absl::make_unique(GetAllocationSlice(*hlo, index), + nullptr)}; } // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by @@ -2602,7 +2604,7 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( memcpy(&pattern16, literal_bytes.data(), sizeof(pattern16)); } uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16); - return {MakeUnique( + return {absl::make_unique( pattern32, GetAllocationSlice(*hlo, index), nullptr)}; } @@ -2613,7 +2615,7 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( literal_bytes.size() - 4) == 0) { uint32 word; memcpy(&word, literal_bytes.data(), sizeof(word)); - return {MakeUnique( + return {absl::make_unique( word, GetAllocationSlice(*hlo, index), nullptr)}; } } @@ -2765,7 +2767,7 @@ std::unique_ptr IrEmitterUnnested::BuildWhileThunk( ir_emitter_context_); TF_CHECK_OK(body->Accept(&ir_emitter_body)); - return MakeUnique( + return absl::make_unique( GetAllocationSlice(*condition->root_instruction()), // cond result ir_emitter_condition.ConsumeThunkSequence(), ir_emitter_body.ConsumeThunkSequence(), hlo); @@ -2783,8 +2785,8 @@ std::unique_ptr IrEmitterUnnested::BuildForThunk( ir_emitter_context_); TF_CHECK_OK(body->Accept(&ir_emitter_body)); - return MakeUnique(loop_limit, - ir_emitter_body.ConsumeThunkSequence(), hlo); + return absl::make_unique( + loop_limit, ir_emitter_body.ConsumeThunkSequence(), hlo); } std::unique_ptr IrEmitterUnnested::BuildConditionalThunk( @@ -2804,7 +2806,7 @@ std::unique_ptr IrEmitterUnnested::BuildConditionalThunk( ir_emitter_context_); TF_CHECK_OK(false_computation->Accept(&ir_emitter_false)); - return MakeUnique( + return absl::make_unique( GetAllocationSlice(*hlo->operand(0)), GetAllocationSlice(*hlo->operand(1)), GetAllocationSlice(*hlo->operand(2)), diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index e76823ad103..6305396635e 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -15,7 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/types.h" @@ -95,7 +95,7 @@ Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, VLOG(3) << "Launching " << kernel->name(); // Launch the kernel with potentially multiple blocks and threads. static constexpr int kKernelArgsLimit = 1024; - auto kernel_args = MakeUnique<se::KernelArgsArray<kKernelArgsLimit>>(); + auto kernel_args = absl::make_unique<se::KernelArgsArray<kKernelArgsLimit>>(); for (const BufferAllocation* arg : args_) { const auto& buf = buffer_allocations.GetDeviceAddress(arg->index()); kernel_args->add_device_memory_argument(buf); diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD index eb93efc560e..6bd9c58f830 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -34,6 +34,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", "@llvm//:amdgpu_code_gen", "@llvm//:analysis", "@llvm//:bit_reader", diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc index ff4ae1f9ef2..cce6e481417 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -205,7 +205,7 @@ std::unique_ptr<llvm::TargetMachine> GetTargetMachine( default: codegen_opt_level = CodeGenOpt::None; } - return WrapUnique(target->createTargetMachine( + return absl::WrapUnique(target->createTargetMachine( triple.str(), llvm_ir::AsStringRef(cpu_name), "+ptx60", target_options, Optional(RelocModel), Optional(CMModel), codegen_opt_level)); diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 6c1eab4f8c7..5868c1a42e6 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -21,13 +21,13 @@ limitations under the License. #include // NOLINT(build/c++11): only using std::call_once, not mutex. 
#include +#include "absl/memory/memory.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/DiagnosticPrinter.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Verifier.h" #include "tensorflow/compiler/xla/protobuf_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" @@ -690,7 +690,7 @@ StatusOr> NVPTXCompiler::RunBackend( const std::vector cubin = CompilePtxOrGetCachedResult(ptx, cc_major, cc_minor); - auto thunk_schedule = MakeUnique( + auto thunk_schedule = absl::make_unique( ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment), hlo_schedule->ThunkLaunchOrder()); VLOG(2) << "Printing the thunk schedule..."; @@ -704,7 +704,7 @@ StatusOr> NVPTXCompiler::RunBackend( cost_analysis.set_bytes_per_second( stream_exec->GetDeviceDescription().memory_bandwidth()); TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis)); - profile_index_map = MakeUnique(*module); + profile_index_map = absl::make_unique(*module); profile_printer = CreateHloProfilePrinterData(*profile_index_map, cost_analysis); } @@ -813,7 +813,7 @@ se::Platform::Id NVPTXCompiler::PlatformId() const { static bool InitModule() { xla::Compiler::RegisterCompilerFactory( stream_executor::cuda::kCudaPlatformId, - []() { return xla::MakeUnique(); }); + []() { return absl::make_unique(); }); return true; } static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc index 4aaf0c9e142..2fa170964e9 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index b22040eee16..98cc21ccac5 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -69,7 +70,7 @@ HloInstruction* MaybePaddedAndSlicedInput( PrimitiveType element_type = input->shape().element_type(); HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(LiteralUtil::Zero(element_type)))); + absl::make_unique(LiteralUtil::Zero(element_type)))); input = MakePadHlo(input, padding, padding_config).ValueOrDie(); } @@ -126,7 +127,7 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, PrimitiveType element_type = kernel->shape().element_type(); HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(LiteralUtil::Zero(element_type)))); + absl::make_unique(LiteralUtil::Zero(element_type)))); return MakePadHlo(kernel, padding, padding_config).ValueOrDie(); } } // namespace @@ -236,7 +237,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( HloComputation* computation = backward_conv->parent(); HloInstruction* output = backward_conv->mutable_operand(1); HloInstruction* padding = computation->AddInstruction( - HloInstruction::CreateConstant(MakeUnique( + HloInstruction::CreateConstant(absl::make_unique( LiteralUtil::Zero(input->shape().element_type())))); HloInstruction* padded_input = MakePadHlo(input, padding, input_padding_config).ValueOrDie(); diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc index d3fd0544fb6..c927c5ee166 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc index 0806dd51614..5b6cf2c04d0 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc @@ -15,8 +15,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" @@ -119,7 +119,7 @@ int ComputeStreamToAssign( } // namespace std::unique_ptr AssignStreams(const HloModule& module) { - auto stream_assignment = MakeUnique(); + auto stream_assignment = absl::make_unique(); const HloComputation& computation = *module.entry_computation(); std::unique_ptr reachability = computation.ComputeReachability(); diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index 6f4bb0580e8..3f75d8b5595 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -33,7 +34,7 @@ class StreamAssignmentTest : public HloTestBase { auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); config.set_debug_options(debug_options); - return MakeUnique("test_module", config); + return absl::make_unique("test_module", config); } // Pre-canned shapes. diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index 4fad3f46cf9..db4a33dc564 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -35,13 +35,13 @@ cc_library( "requires-gpu-sm35", ], deps = [ - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:gpu_plugin", "//tensorflow/compiler/xla/service/gpu:gpu_executable", "//tensorflow/compiler/xla/tests:filecheck", "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -60,6 +60,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -94,6 +95,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -150,6 +152,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -168,6 +171,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc index 4b8415fe910..0e84ec7e621 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc @@ -14,8 +14,8 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/tests/filecheck.h" #include "tensorflow/core/platform/logging.h" @@ -32,7 +32,7 @@ std::unique_ptr GpuCodegenTest::CreateNewModuleWithFTZ(bool ftz) { debug_options.add_xla_disable_hlo_passes("constant_folding"); config.set_debug_options(debug_options); - return MakeUnique(TestName(), config); + return absl::make_unique(TestName(), config); } void GpuCodegenTest::CompileAndVerifyPtx(std::unique_ptr hlo_module, diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc index ce69e058e64..4550f36fdfc 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc index e5958165eff..a06576df7b8 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc index 6c9ae7bada5..6a9ecd9dae7 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc index c42e5704a4d..15198865bda 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc @@ -16,8 +16,8 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc index 8579b1545fd..989b542ff45 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" @@ -25,7 +26,7 @@ Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, se::Stream* stream, HloExecutionProfiler* profiler) { auto size = tuple_element_buffers_.size(); - auto tuple_element_buffer_addresses = MakeUnique(size); + auto tuple_element_buffer_addresses = absl::make_unique(size); for (int i = 0; i != size; ++i) { tuple_element_buffer_addresses[i] = buffer_allocations.GetDeviceAddress(tuple_element_buffers_[i]).opaque(); diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index d81d87e7dc5..828fc2884bd 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/while_thunk.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -34,9 +34,9 @@ WhileThunk::WhileThunk( // and body_thunk_sequence_ constructors because these SequentialThunks // are logically "part of" this WhileThunk, and shouldn't be profiled // separately from it. - condition_thunk_sequence_(MakeUnique( + condition_thunk_sequence_(absl::make_unique( std::move(*condition_thunk_sequence), nullptr)), - body_thunk_sequence_(MakeUnique( + body_thunk_sequence_(absl::make_unique( std::move(*body_thunk_sequence), nullptr)) {} Status WhileThunk::Initialize(const GpuExecutable& executable, diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index aa89567ee86..31431f115f8 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -22,9 +22,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -84,7 +84,7 @@ HloComputation* CallForwardingComputation(HloComputation* computation, // the module. 
std::unique_ptr MakeBigGraph() { HloModuleConfig config; - auto module = MakeUnique("BigGraph", config); + auto module = absl::make_unique("BigGraph", config); auto builder = HloComputation::Builder("TestBigGraphvizGraph"); diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 4005fc0d114..93a922b9046 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/util.h" @@ -45,7 +46,7 @@ StatusOr HeapSimulator::MinimumMemoryForModule( // bound, by minimizing the liveness of sub-computations. TF_ASSIGN_OR_RETURN( HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), *module, + HeapSimulator::Run(absl::make_unique(), *module, module_sequence, *points_to_analysis, size_function)); return result.heap_size; } @@ -60,9 +61,10 @@ StatusOr HeapSimulator::MinimumMemoryForComputation( memory_by_computation) { TF_ASSIGN_OR_RETURN( HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), computation, - sequence, points_to_analysis, size_function, - HeapSimulator::Options(), memory_by_computation)); + HeapSimulator::Run(absl::make_unique(), + computation, sequence, points_to_analysis, + size_function, HeapSimulator::Options(), + memory_by_computation)); return result.heap_size; } @@ -344,7 +346,7 @@ HeapSimulator::HeapSimulator( const SequentialHloOrdering::HloModuleSequence* module_sequence, const tensorflow::gtl::FlatMap* memory_by_computation) - : no_fragmentation_stats_(MakeUnique()), + : no_fragmentation_stats_(absl::make_unique()), algorithm_(std::move(algorithm)), size_fn_(size_fn), options_(options), diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index b41dc66fe9f..5f85f145657 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -137,7 +138,7 @@ class HeapSimulatorTracker { const string& name, std::unique_ptr computation, const std::vector& instruction_sequence) { HloModuleConfig config; - module_ = MakeUnique(name, config); + module_ = absl::make_unique(name, config); module_->AddEntryComputation(std::move(computation)); points_to_analysis_ = TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); @@ -146,8 +147,8 @@ class HeapSimulatorTracker { // the secondary sorting criteria of DecreasingSizeRunsHeap to sort calls by // buffer id, for determinism in the tests. 
auto zero_size = [](const BufferValue& buffer) { return 0; }; - auto algorithm = MakeUnique( - MakeUnique(&actual_calls_)); + auto algorithm = absl::make_unique( + absl::make_unique(&actual_calls_)); result_ = HeapSimulator::Run( std::move(algorithm), *module_->entry_computation(), instruction_sequence, *points_to_analysis_, zero_size) @@ -156,7 +157,7 @@ class HeapSimulatorTracker { explicit HeapSimulatorTracker(const string& name) { HloModuleConfig config; - module_ = MakeUnique(name, config); + module_ = absl::make_unique(name, config); } // Similar to the single entry computation constructor above, but runs the @@ -182,8 +183,8 @@ class HeapSimulatorTracker { auto size_fn = [&reverse_position](const BufferValue& buffer) { return reverse_position[buffer.instruction()]; }; - auto algorithm = MakeUnique( - MakeUnique(&actual_calls_)); + auto algorithm = absl::make_unique( + absl::make_unique(&actual_calls_)); result_ = HeapSimulator::Run(std::move(algorithm), *module_, module_sequence, *points_to_analysis_, size_fn) .ConsumeValueOrDie(); @@ -675,7 +676,8 @@ class HeapAlgorithmTestBase : public ::testing::Test { const BufferValue::Id id = buffers_.size(); auto const0 = builder_.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); - buffers_.emplace_back(MakeUnique(id, const0, ShapeIndex{})); + buffers_.emplace_back( + absl::make_unique(id, const0, ShapeIndex{})); return buffers_.back().get(); } @@ -724,7 +726,8 @@ class DecreasingSizeRunsHeapTest : public HeapAlgorithmTestBase {}; TEST_F(DecreasingSizeRunsHeapTest, Empty) { CallSequence call_sequence; - DecreasingSizeRunsHeap heap(MakeUnique(&call_sequence)); + DecreasingSizeRunsHeap heap( + absl::make_unique(&call_sequence)); heap.Finish(); EXPECT_EQ(call_sequence, CallSequence({ {kFinish, nullptr}, @@ -733,7 +736,8 @@ TEST_F(DecreasingSizeRunsHeapTest, Empty) { TEST_F(DecreasingSizeRunsHeapTest, Simple) { CallSequence call_sequence; - DecreasingSizeRunsHeap heap(MakeUnique(&call_sequence)); + DecreasingSizeRunsHeap heap( + absl::make_unique(&call_sequence)); heap.Alloc(buffer_a_, 10); heap.Alloc(buffer_b_, 20); heap.Alloc(buffer_c_, 30); @@ -760,7 +764,8 @@ TEST_F(DecreasingSizeRunsHeapTest, Simple) { TEST_F(DecreasingSizeRunsHeapTest, Mixed) { CallSequence call_sequence; - DecreasingSizeRunsHeap heap(MakeUnique(&call_sequence)); + DecreasingSizeRunsHeap heap( + absl::make_unique(&call_sequence)); heap.Alloc(buffer_a_, 10); heap.Alloc(buffer_b_, 20); heap.Free(buffer_b_, 20); diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index e8a4b034b43..0ca489846e7 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -457,7 +457,7 @@ StatusOr> HloAliasAnalysis::Run( VLOG(2) << "HloAliasAnalysis::Run on module " << module->name(); XLA_VLOG_LINES(2, module->ToString()); - auto alias_analysis = WrapUnique(new HloAliasAnalysis(module)); + auto alias_analysis = absl::WrapUnique(new HloAliasAnalysis(module)); TF_ASSIGN_OR_RETURN(alias_analysis->dataflow_analysis_, HloDataflowAnalysis::Run(*module, /*ssa_form=*/true, /*bitcast_defines_value=*/false, diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index db853360f1f..bae78c94bdc 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -24,9 +24,9 @@ limitations under the License. 
#include #include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -57,8 +57,8 @@ std::unique_ptr HloComputation::Builder::Build( HloInstruction* root = root_instruction ? root_instruction : last_added_instruction_; CHECK_NE(nullptr, root); - return WrapUnique(new HloComputation(name_, parameter_count, &instructions_, - root, fusion_instruction_)); + return absl::WrapUnique(new HloComputation( + name_, parameter_count, &instructions_, root, fusion_instruction_)); } HloComputation::HloComputation( @@ -494,9 +494,9 @@ HloComputation::CreateFromProto( return to_proto_id[a.get()] < to_proto_id[b.get()]; }); - return WrapUnique(new HloComputation(proto.name(), parameter_count, - &instructions, root, - /*fusion_instruction=*/nullptr)); + return absl::WrapUnique(new HloComputation(proto.name(), parameter_count, + &instructions, root, + /*fusion_instruction=*/nullptr)); } void HloComputation::FuseInstructionsInto( @@ -675,7 +675,7 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, std::unique_ptr HloComputation::ComputeReachability() const { const auto& all = MakeInstructionPostOrder(); - auto result = MakeUnique(all); + auto result = absl::make_unique(all); std::vector inputs; for (const HloInstruction* hlo : all) { @@ -830,7 +830,7 @@ std::unique_ptr HloComputation::CloneWithReplacements( HloCloneContext* context, const string& suffix) { std::unique_ptr context_ptr; if (context == nullptr) { - context_ptr = MakeUnique(parent(), suffix); + context_ptr = absl::make_unique(parent(), suffix); context = context_ptr.get(); } diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 7229031c0c7..6dddda1ca89 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -38,7 +39,7 @@ StatusOr HloConstantFolding::Run(HloModule* module) { // Limit the constant folding to 0 iterations to skip folding loops. This // retains the behavior from before while loop support in HloEvaluator and may // be revised. - auto evaluator = MakeUnique(/*max_loop_iterations=*/0); + auto evaluator = absl::make_unique(/*max_loop_iterations=*/0); XLA_VLOG_LINES(2, "HloConstantFolding::Run(), before:\n" + module->ToString()); diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 83adaddba42..c4e27dc558e 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -15,9 +15,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" @@ -319,7 +319,7 @@ StatusOr PadVectorWithZeros(HloInstruction* operand, *padding_config.add_dimensions() = padding_config_dim; HloInstruction* zero = computation->AddInstruction( - HloInstruction::CreateConstant(MakeUnique( + HloInstruction::CreateConstant(absl::make_unique( LiteralUtil::Zero(operand->shape().element_type())))); return MakePadHlo(operand, zero, padding_config); } @@ -329,7 +329,7 @@ StatusOr BroadcastZeros( ArraySlice broadcast_dimensions) { HloInstruction* zero = computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(LiteralUtil::Zero(element_type)))); + absl::make_unique(LiteralUtil::Zero(element_type)))); return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{}, /*result_shape_bounds=*/broadcast_dimensions); } diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc index d123dbb1a04..a8de285d16f 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/shape_util.h" diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 90fbaa37c5a..406d712ec67 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -20,9 +20,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index bbfb0c253f5..9b150579298 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -19,8 +19,8 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -886,7 +886,7 @@ StatusOr> HloDataflowAnalysis::Run( VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name(); XLA_VLOG_LINES(2, module.ToString()); - auto dataflow_analysis = WrapUnique(new HloDataflowAnalysis( + auto dataflow_analysis = absl::WrapUnique(new HloDataflowAnalysis( module, ssa_form, bitcast_defines_value, fusion_can_share_buffer)); TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets()); diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index 26e3736e012..3b5cde2996c 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc index 9e096320db5..edf0073f309 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/types.h" @@ -25,14 +26,14 @@ namespace xla { /* static */ StatusOr> HloDomainMap::Create( HloComputation* computation, string domain_kind) { - auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind))); + auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind))); TF_RETURN_IF_ERROR(domain_map->Populate(computation)); return std::move(domain_map); } /* static */ StatusOr> HloDomainMap::Create( HloModule* module, string domain_kind) { - auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind))); + auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind))); for (HloComputation* computation : module->computations()) { TF_RETURN_IF_ERROR(domain_map->Populate(computation)); } @@ -56,14 +57,14 @@ Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { // both sides. 
for (HloInstruction* operand : instruction->unique_operands()) { if (IsDomainInstruction(operand)) { - auto domain = MakeUnique(); + auto domain = absl::make_unique(); domain->enter_domains.insert(operand); domain->exit_domains.insert(instruction); TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); } } if (instruction == instruction->parent()->root_instruction()) { - auto domain = MakeUnique(); + auto domain = absl::make_unique(); domain->enter_domains.insert(instruction); TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); } @@ -143,7 +144,7 @@ Status HloDomainMap::ExpandDomain(HloInstruction* instruction, StatusOr> HloDomainMap::CreateDomain( HloInstruction* instruction) const { - auto domain = MakeUnique(); + auto domain = absl::make_unique(); TF_RETURN_IF_ERROR(ExpandDomain(instruction, domain.get())); domain->instructions = MakeNonDomainInstructions(domain->reach_set); return std::move(domain); diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index 70271be3043..7d48be15cfd 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" @@ -80,7 +81,7 @@ class OpNameMetadata : public DomainMetadata { explicit OpNameMetadata(string opname) : opname_(std::move(opname)) {} std::unique_ptr Clone() const override { - return MakeUnique(opname_); + return absl::make_unique(opname_); } tensorflow::StringPiece Kind() const override { return KindName(); } @@ -110,9 +111,9 @@ std::unique_ptr OpNameDomainCreator(HloInstruction* instruction, return nullptr; } std::unique_ptr operand_side_metadata = - MakeUnique(operand->metadata().op_name()); + absl::make_unique(operand->metadata().op_name()); std::unique_ptr user_side_metadata = - MakeUnique(instruction->metadata().op_name()); + absl::make_unique(instruction->metadata().op_name()); return HloInstruction::CreateDomain(operand->shape(), operand, std::move(operand_side_metadata), std::move(user_side_metadata)); @@ -474,8 +475,8 @@ ENTRY entry { TEST_F(HloDomainTest, DumpParseNullSharding) { auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {}); - auto sharding_md_0 = MakeUnique(nullptr); - auto sharding_md_1 = MakeUnique(nullptr); + auto sharding_md_0 = absl::make_unique(nullptr); + auto sharding_md_1 = absl::make_unique(nullptr); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p")); HloInstruction* domain = builder.AddInstruction(HloInstruction::CreateDomain( diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 0455c7f41ac..35d9e799df6 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -24,13 +24,13 @@ limitations under the License. 
#include #include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -96,7 +96,7 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, << HloOpcodeString(opcode); } - auto result = MakeUnique(shape); + auto result = absl::make_unique(shape); TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { return compare_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); @@ -126,7 +126,7 @@ StatusOr> Compare( << HloOpcodeString(opcode); } - auto result = MakeUnique(shape); + auto result = absl::make_unique(shape); TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { return compare_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); @@ -139,44 +139,57 @@ StatusOr> Compare( HloEvaluator::HloEvaluator(int64 max_loop_iterations) : max_loop_iterations_(max_loop_iterations) { - typed_visitors_[PRED] = MakeUnique>(this); - typed_visitors_[U8] = MakeUnique>(this); - typed_visitors_[U16] = MakeUnique([](HloInstruction*) { - return Unimplemented( - "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " - "U16."); - }); - typed_visitors_[U32] = MakeUnique>(this); - typed_visitors_[U64] = MakeUnique>(this); - typed_visitors_[S8] = MakeUnique>(this); - typed_visitors_[S16] = MakeUnique([](HloInstruction*) { - return Unimplemented( - "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " - "S16."); - }); - typed_visitors_[S32] = MakeUnique>(this); - typed_visitors_[S64] = MakeUnique>(this); + typed_visitors_[PRED] = + absl::make_unique>(this); + typed_visitors_[U8] = + absl::make_unique>(this); + typed_visitors_[U16] = + absl::make_unique([](HloInstruction*) { + return Unimplemented( + "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " + "U16."); + }); + typed_visitors_[U32] = + absl::make_unique>(this); + typed_visitors_[U64] = + absl::make_unique>(this); + typed_visitors_[S8] = absl::make_unique>(this); + typed_visitors_[S16] = + absl::make_unique([](HloInstruction*) { + return Unimplemented( + "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " + "S16."); + }); + typed_visitors_[S32] = + absl::make_unique>(this); + typed_visitors_[S64] = + absl::make_unique>(this); typed_visitors_[F16] = - MakeUnique>(this); - typed_visitors_[F32] = MakeUnique>(this); - typed_visitors_[F64] = MakeUnique>(this); - typed_visitors_[C64] = MakeUnique>(this); + absl::make_unique>(this); + typed_visitors_[F32] = + absl::make_unique>(this); + typed_visitors_[F64] = + absl::make_unique>(this); + typed_visitors_[C64] = + absl::make_unique>(this); // Most of the evaluator computations we use don't support BF16 (e.g., // std::ceil, std::tanh). To make evaluator work with BF16, we set all // elementwise computations to be done in F32 and do BF16<->F32 conversion // around the input and the output of the computations. 
typed_visitors_[BF16] = - MakeUnique>(this); + absl::make_unique>(this); - typed_visitors_[TUPLE] = MakeUnique([](HloInstruction*) { - return Unimplemented( - "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE."); - }); - typed_visitors_[OPAQUE] = MakeUnique([](HloInstruction*) { - return Unimplemented( - "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE."); - }); + typed_visitors_[TUPLE] = + absl::make_unique([](HloInstruction*) { + return Unimplemented( + "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE."); + }); + typed_visitors_[OPAQUE] = + absl::make_unique([](HloInstruction*) { + return Unimplemented( + "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE."); + }); } template @@ -956,7 +969,7 @@ Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) { const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand); - evaluated_[get_tuple_element] = MakeUnique( + evaluated_[get_tuple_element] = absl::make_unique( ShapeUtil::GetTupleElementShape(operand->shape(), index)); return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal, /*dest_shape_index=*/{}, @@ -1158,10 +1171,11 @@ StatusOr> EvaluateSortInternal( result_keys.push_back(key_value.first); result_values.push_back(key_value.second); } - auto result_keys_literal = MakeUnique(keys_literal.shape()); + auto result_keys_literal = absl::make_unique(keys_literal.shape()); result_keys_literal->PopulateR1( tensorflow::gtl::ArraySlice(result_keys)); - auto result_values_literal = MakeUnique(values_literal.shape()); + auto result_values_literal = + absl::make_unique(values_literal.shape()); result_values_literal->PopulateR1( tensorflow::gtl::ArraySlice(result_values)); return std::make_pair(std::move(result_keys_literal), @@ -1176,8 +1190,9 @@ StatusOr> EvaluateSortInternal( } else { // For R2 sort, the desired semantics are to sort each matrix row // independently. - auto keys_result_literal = MakeUnique(keys_literal.shape()); - auto values_result_literal = MakeUnique(values_literal.shape()); + auto keys_result_literal = absl::make_unique(keys_literal.shape()); + auto values_result_literal = + absl::make_unique(values_literal.shape()); int64 r1_length = keys_literal.shape().dimensions(1); for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) { TF_ASSIGN_OR_RETURN(auto keys_r1_slice, diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index a4c37ef3282..7588916de50 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -18,7 +18,7 @@ limitations under the License. 
#include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -226,7 +226,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { ShapeUtil::HumanString(operand->shape()).c_str()); } - auto result = MakeUnique(shape); + auto result = absl::make_unique(shape); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { return unary_op(operand_literal.Get(multi_index)); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 1394be68e4d..4b8e6260ac8 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -52,7 +53,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, public HloVerifiedTestBase { protected: HloEvaluatorTest() : use_bfloat16_(GetParam()) { - evaluator_ = MakeUnique(); + evaluator_ = absl::make_unique(); } std::unique_ptr Evaluate( @@ -523,7 +524,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { std::unique_ptr result = Evaluate(); - auto expected_array = MakeUnique>(8, 5, 1, 1); + auto expected_array = absl::make_unique>(8, 5, 1, 1); expected_array->Fill(kPadValue); (*expected_array)(1, 0, 0, 0) = 1.0f; (*expected_array)(1, 2, 0, 0) = 2.0f; @@ -547,7 +548,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { // { 9, 10, 11 }, // { 13, 14, 15 }, // } - auto input_array = MakeUnique>(4, 3); + auto input_array = absl::make_unique>(4, 3); input_array->FillUnique(1.0f); auto input = LiteralUtil::CreateR2FromArray2D(*input_array); HloInstruction* input_instruction = @@ -568,7 +569,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { std::unique_ptr result = Evaluate(); // f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 } - auto expected_array = MakeUnique>(1, 5); + auto expected_array = absl::make_unique>(1, 5); (*expected_array)(0, 0) = 7.0f; (*expected_array)(0, 1) = 2.718f; (*expected_array)(0, 2) = 2.718f; @@ -588,7 +589,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { // { 9, 10, 11 }, // { 13, 14, 15 }, // } - auto input_array = MakeUnique>(4, 3); + auto input_array = absl::make_unique>(4, 3); input_array->FillUnique(1.0f); auto input = LiteralUtil::CreateR2FromArray2D(*input_array); HloInstruction* input_instruction = @@ -612,7 +613,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { std::unique_ptr result = Evaluate(); - auto expected_array = MakeUnique>(0, 9); + auto expected_array = absl::make_unique>(0, 9); auto expected = LiteralUtil::CreateR2FromArray2D(*expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); @@ -628,7 +629,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { // { 3 }, // { 4 }, // } - auto lhs_array = MakeUnique>(4, 1); + auto lhs_array = absl::make_unique>(4, 1); lhs_array->FillUnique(1.0f); auto lhs_literal = LiteralUtil::CreateR2FromArray2D(*lhs_array); HloInstruction* lhs_instruction = @@ -679,7 +680,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { // { 3, 4 }, // { 5, 6 }, // } - auto rhs_array = MakeUnique>(3, 2); + auto rhs_array = 
absl::make_unique>(3, 2); rhs_array->FillUnique(1.0f); auto rhs_literal = LiteralUtil::CreateR2FromArray2D(*rhs_array); HloInstruction* rhs_instruction = @@ -710,7 +711,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { // { 9, 10, 11 }, // { 13, 14, 15 }, // } - auto lhs_array = MakeUnique>(4, 3); + auto lhs_array = absl::make_unique>(4, 3); lhs_array->FillUnique(1.0f); auto lhs_literal = LiteralUtil::CreateR2FromArray2D(*lhs_array); HloInstruction* lhs_instruction = @@ -722,7 +723,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { // { 3, 4 }, // { 5, 6 }, // } - auto rhs_array = MakeUnique>(3, 2); + auto rhs_array = absl::make_unique>(3, 2); rhs_array->FillUnique(1.0f); auto rhs_literal = LiteralUtil::CreateR2FromArray2D(*rhs_array); HloInstruction* rhs_instruction = @@ -1297,7 +1298,7 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto arg_array = MakeUnique>(2, 3); + auto arg_array = absl::make_unique>(2, 3); arg_array->FillUnique(1.0f); auto arg_literal = LiteralUtil::CreateR2FromArray2D(*arg_array); @@ -1339,7 +1340,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto arg_array = MakeUnique>(2, 3); + auto arg_array = absl::make_unique>(2, 3); arg_array->FillUnique(1.0f); auto arg_literal = LiteralUtil::CreateR2FromArray2D(*arg_array); @@ -1390,7 +1391,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto arg_array = MakeUnique>(2, 3); + auto arg_array = absl::make_unique>(2, 3); arg_array->FillUnique(1.0f); auto arg_literal = LiteralUtil::CreateR2FromArray2D(*arg_array); @@ -1511,7 +1512,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) { // { 9, 10, 11, 12, 13 }, // { 17, 18, 19, 20, 21 }, // } - auto operand_array = MakeUnique>(3, 5); + auto operand_array = absl::make_unique>(3, 5); operand_array->FillUnique(1.0f); auto operand_literal = LiteralUtil::CreateR2FromArray2D(*operand_array); @@ -1544,7 +1545,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { // { 1, 2, 3, 4 }, // { 5, 6, 7, 8 }, // } - auto operand_array = MakeUnique>(2, 4); + auto operand_array = absl::make_unique>(2, 4); operand_array->FillUnique(1.0f); auto operand_literal = LiteralUtil::CreateR2FromArray2D(*operand_array); @@ -1580,7 +1581,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { // { 1, 2, 3, 4 }, // { 5, 6, 7, 8 }, // } - auto operand_array = MakeUnique>(2, 4); + auto operand_array = absl::make_unique>(2, 4); operand_array->FillUnique(1.0f); auto operand_literal = LiteralUtil::CreateR2FromArray2D(*operand_array); @@ -1614,7 +1615,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto operand_array = MakeUnique>(2, 3); + auto operand_array = absl::make_unique>(2, 3); operand_array->FillUnique(1.0); auto operand_literal = LiteralUtil::CreateR2FromArray2D(*operand_array); @@ -1651,7 +1652,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto operand_array = MakeUnique>(2, 3); + auto operand_array = absl::make_unique>(2, 3); operand_array->FillUnique(1.0); auto operand_literal2 = LiteralUtil::CreateR2FromArray2D(*operand_array); @@ -1687,7 +1688,7 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto operand_array = MakeUnique>(2, 3); + auto operand_array = absl::make_unique>(2, 3); operand_array->FillUnique(1.0); HloInstruction* operand2 = b.AddInstruction(HloInstruction::CreateConstant( diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h 
b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index a7c5d71da01..83d7b404f0b 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ #include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -897,7 +898,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { << ShapeUtil::HumanString(inferred_return_shape); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto result = MakeUnique(result_shape); + auto result = absl::make_unique(result_shape); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice out_index) { @@ -1054,7 +1055,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return static_cast(result_val); }; - auto result = MakeUnique(result_shape); + auto result = absl::make_unique(result_shape); TF_RETURN_IF_ERROR(result->PopulateParallel(func)); parent_->evaluated_[conv] = std::move(result); @@ -1128,7 +1129,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } } - auto result = MakeUnique(dot->shape()); + auto result = absl::make_unique(dot->shape()); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice result_index) { ElementwiseT result_val = static_cast(0); @@ -1177,7 +1178,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Create new HLO of padded shape with padding value. ReturnT scalar = parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get({}); - auto result = MakeUnique(pad->shape()); + auto result = absl::make_unique(pad->shape()); TF_RETURN_IF_ERROR(result->Populate( [&scalar](tensorflow::gtl::ArraySlice multi_index) { return scalar; @@ -1342,7 +1343,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto operands = map->operands(); HloComputation* computation = map->to_apply(); - auto result = MakeUnique(map->shape()); + auto result = absl::make_unique(map->shape()); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); TF_RETURN_IF_ERROR(result->Populate( @@ -1456,7 +1457,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { [](const ReturnT& a, const ReturnT& b) { return SafeLess(a, b); }); - auto result_literal = MakeUnique(keys_literal.shape()); + auto result_literal = absl::make_unique(keys_literal.shape()); result_literal->PopulateR1( tensorflow::gtl::ArraySlice(result_data)); VLOG(3) << "HandleSort result_literal: " << result_literal->ToString(); @@ -1468,7 +1469,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } else { // For R2 sort, the desired semantics are to sort each matrix row // independently. 
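[Editor's note] A small illustration, separate from the patch, of the row-independent R2 sort semantics the comment above describes; the container type and function name are assumptions for the sketch, not XLA APIs.

#include <algorithm>
#include <vector>

// Sorts each row of a rank-2 array on its own: rows never exchange
// elements, matching the "sort each matrix row independently" behavior.
void SortRowsIndependently(std::vector<std::vector<float>>* matrix) {
  for (auto& row : *matrix) {
    std::sort(row.begin(), row.end());
  }
}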
- auto result_literal = MakeUnique(keys_literal.shape()); + auto result_literal = absl::make_unique(keys_literal.shape()); int64 r1_length = keys->shape().dimensions(1); for (int64 row = 0; row < keys->shape().dimensions(0); ++row) { TF_ASSIGN_OR_RETURN(auto r1_slice, @@ -1542,7 +1543,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - auto result = MakeUnique(reduce->shape()); + auto result = absl::make_unique(reduce->shape()); // For each resulting dimension, calculate and assign computed value. TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { @@ -1623,7 +1624,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); auto init_scalar = init_literal.Get({}); - auto result = MakeUnique(select_and_scatter->shape()); + auto result = absl::make_unique(select_and_scatter->shape()); // Initialize result array with the init value. TF_RETURN_IF_ERROR(result->Populate( @@ -1759,7 +1760,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - auto result = MakeUnique(reduce_window->shape()); + auto result = absl::make_unique(reduce_window->shape()); // For each resulting dimension, calculate and assign computed value. TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice output_index) { @@ -2412,7 +2413,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::is_same::value || std::is_same::value>::type* = nullptr> Status HandleIota(HloInstruction* iota) { - auto result = MakeUnique(iota->shape()); + auto result = absl::make_unique(iota->shape()); auto data = result->data(); std::iota(data.begin(), data.end(), 0); parent_->evaluated_[iota] = std::move(result); @@ -2494,7 +2495,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } std::vector operand_indices(start.size()); - auto result = MakeUnique(result_shape); + auto result = absl::make_unique(result_shape); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { for (int64 i = 0; i < operand_indices.size(); ++i) { @@ -2580,7 +2581,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - auto result = MakeUnique(shape); + auto result = absl::make_unique(shape); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { @@ -2618,7 +2619,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); - auto result = MakeUnique(shape); + auto result = absl::make_unique(shape); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index f5544017875..de3d7a16775 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include #include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/human_readable_profile_builder.h" @@ -50,7 +51,7 @@ std::unique_ptr CreateHloProfilePrinterData( size_t profile_counters_size = hlo_profile_index_map.total_count(); std::unique_ptr profile_printer_data = - MakeUnique(); + absl::make_unique(); profile_printer_data->set_profile_counters_size(profile_counters_size); profile_printer_data->mutable_computation_infos()->Reserve( hlo_profile_index_map.computation_count()); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 2b812135095..e3d6b2e753b 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -22,10 +22,10 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -225,7 +225,7 @@ StatusOr> HloInstruction::CreateFromProto( Literal::CreateFromProto(proto.literal())); instruction = CreateConstant(std::move(literal)); } else { - instruction = MakeUnique(proto.shape()); + instruction = absl::make_unique(proto.shape()); } break; } @@ -392,7 +392,8 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.has_gather_dimension_numbers()) << "Gather instruction should have GatherDimensionNumbers set."; std::unique_ptr gather_dimension_numbers = - MakeUnique(proto.gather_dimension_numbers()); + absl::make_unique( + proto.gather_dimension_numbers()); std::vector gather_slice_sizes; for (int64 bound : proto.gather_slice_sizes()) { gather_slice_sizes.push_back(bound); @@ -410,15 +411,16 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "Scatter instruction should have 1 called computation but sees " << proto.called_computation_ids_size(); - auto scatter_dimension_numbers = MakeUnique( - proto.scatter_dimension_numbers()); + auto scatter_dimension_numbers = + absl::make_unique( + proto.scatter_dimension_numbers()); instruction = CreateScatter(proto.shape(), operands(0), operands(1), operands(2), computations(0), *scatter_dimension_numbers); break; } default: { - instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); + instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { TF_RET_CHECK(ContainsKey(instruction_map, operand_id)) << "No instruction with id " << operand_id; @@ -449,7 +451,7 @@ StatusOr> HloInstruction::CreateFromProto( if (proto.has_dot_dimension_numbers()) { instruction->dot_dimension_numbers_ = - MakeUnique(proto.dot_dimension_numbers()); + absl::make_unique(proto.dot_dimension_numbers()); } if (proto.has_sharding()) { @@ -463,34 +465,36 @@ StatusOr> HloInstruction::CreateFromProto( /* static */ std::unique_ptr HloInstruction::CreateParameter( int64 parameter_number, const Shape& shape, const string& name) { - return MakeUnique(parameter_number, shape, name); + return 
absl::make_unique(parameter_number, shape, + name); } /* static */ std::unique_ptr HloInstruction::CreateTrace( const string& tag, HloInstruction* operand) { - return MakeUnique(tag, operand); + return absl::make_unique(tag, operand); } /* static */ std::unique_ptr HloInstruction::CreateConstant( std::unique_ptr literal) { - return MakeUnique(std::move(literal)); + return absl::make_unique(std::move(literal)); } /* static */ std::unique_ptr HloInstruction::CreateIota( const Shape& shape) { - return WrapUnique(new HloInstruction(HloOpcode::kIota, shape)); + return absl::WrapUnique(new HloInstruction(HloOpcode::kIota, shape)); } /* static */ std::unique_ptr HloInstruction::CreateGetTupleElement(const Shape& shape, HloInstruction* operand, int64 index) { - return MakeUnique(shape, operand, index); + return absl::make_unique(shape, operand, + index); } /* static */ std::unique_ptr HloInstruction::CreateRng( const Shape& shape, RandomDistribution distribution, tensorflow::gtl::ArraySlice parameters) { - return MakeUnique(shape, distribution, parameters); + return absl::make_unique(shape, distribution, parameters); } /* static */ std::unique_ptr HloInstruction::CreateNary( @@ -500,7 +504,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, // It is impossible to copy an opaque shape, we don't know how big it is. CHECK(!ShapeUtil::IsOpaque(shape)); } - auto instruction = WrapUnique(new HloInstruction(opcode, shape)); + auto instruction = absl::WrapUnique(new HloInstruction(opcode, shape)); for (auto operand : operands) { instruction->AppendOperand(operand); } @@ -605,31 +609,33 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateMap( const Shape& shape, tensorflow::gtl::ArraySlice operands, HloComputation* map_computation) { - return MakeUnique(shape, operands, map_computation); + return absl::make_unique(shape, operands, map_computation); } /* static */ std::unique_ptr HloInstruction::CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count) { - return MakeUnique( + return absl::make_unique( shape, lhs, rhs, window, dimension_numbers, feature_group_count); } /* static */ std::unique_ptr HloInstruction::CreateFft( const Shape& shape, HloInstruction* operand, FftType fft_type, tensorflow::gtl::ArraySlice fft_length) { - return MakeUnique(shape, operand, fft_type, fft_length); + return absl::make_unique(shape, operand, fft_type, + fft_length); } /* static */ std::unique_ptr HloInstruction::CreateDot( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dimension_numbers) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); instruction->AppendOperand(lhs); instruction->AppendOperand(rhs); instruction->dot_dimension_numbers_ = - MakeUnique(dimension_numbers); + absl::make_unique(dimension_numbers); return instruction; } @@ -638,10 +644,12 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2); CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2); - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); instruction->AppendOperand(lhs); instruction->AppendOperand(rhs); - instruction->dot_dimension_numbers_ = MakeUnique(); + 
instruction->dot_dimension_numbers_ = + absl::make_unique(); instruction->dot_dimension_numbers_->add_lhs_contracting_dimensions(1); instruction->dot_dimension_numbers_->add_rhs_contracting_dimensions(0); return instruction; @@ -652,7 +660,7 @@ HloInstruction::CreateReducePrecision(const Shape& shape, HloInstruction* operand, const int exponent_bits, const int mantissa_bits) { - return MakeUnique( + return absl::make_unique( shape, operand, exponent_bits, mantissa_bits); } @@ -663,7 +671,7 @@ HloInstruction::CreateCrossReplicaSum( tensorflow::gtl::ArraySlice replica_group_ids, tensorflow::StringPiece barrier, const tensorflow::gtl::optional& all_reduce_id) { - return MakeUnique( + return absl::make_unique( shape, operands, reduce_computation, replica_group_ids, barrier, all_reduce_id); } @@ -672,28 +680,29 @@ HloInstruction::CreateCrossReplicaSum( const Shape& shape, tensorflow::gtl::ArraySlice operands, const std::vector& replica_groups, tensorflow::StringPiece barrier) { - return MakeUnique(shape, operands, replica_groups, - barrier); + return absl::make_unique(shape, operands, + replica_groups, barrier); } /* static */ std::unique_ptr HloInstruction::CreateInfeed( const Shape& infeed_shape, HloInstruction* token_operand, const string& config) { - return MakeUnique(infeed_shape, token_operand, config); + return absl::make_unique(infeed_shape, token_operand, + config); } /* static */ std::unique_ptr HloInstruction::CreateOutfeed( const Shape& outfeed_shape, HloInstruction* operand, HloInstruction* token_operand, tensorflow::StringPiece outfeed_config) { - return MakeUnique(outfeed_shape, operand, - token_operand, outfeed_config); + return absl::make_unique( + outfeed_shape, operand, token_operand, outfeed_config); } /* static */ std::unique_ptr HloInstruction::CreateSend( HloInstruction* operand, HloInstruction* token, int64 channel_id, bool is_host_transfer) { - return MakeUnique(operand, token, channel_id, - is_host_transfer); + return absl::make_unique(operand, token, channel_id, + is_host_transfer); } /* static */ std::unique_ptr HloInstruction::CreateSendDone( @@ -701,14 +710,15 @@ HloInstruction::CreateCrossReplicaSum( auto send_operand = DynCast(operand); CHECK(send_operand != nullptr) << "SendDone must take the context operand from Send"; - return MakeUnique(send_operand, is_host_transfer); + return absl::make_unique(send_operand, + is_host_transfer); } /* static */ std::unique_ptr HloInstruction::CreateRecv( const Shape& shape, HloInstruction* token, int64 channel_id, bool is_host_transfer) { - return MakeUnique(shape, token, channel_id, - is_host_transfer); + return absl::make_unique(shape, token, channel_id, + is_host_transfer); } /* static */ std::unique_ptr HloInstruction::CreateRecvDone( @@ -716,19 +726,20 @@ HloInstruction::CreateCrossReplicaSum( auto recv_operand = DynCast(operand); CHECK(recv_operand != nullptr) << "RecvDone must take the context operand from Recv"; - return MakeUnique(recv_operand, is_host_transfer); + return absl::make_unique(recv_operand, + is_host_transfer); } /* static */ std::unique_ptr HloInstruction::CreateReverse( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions) { - return MakeUnique(shape, operand, dimensions); + return absl::make_unique(shape, operand, dimensions); } /* static */ std::unique_ptr HloInstruction::CreateAfterAll( tensorflow::gtl::ArraySlice operands) { CHECK(!operands.empty()); - auto instruction = WrapUnique( + auto instruction = absl::WrapUnique( new HloInstruction(HloOpcode::kAfterAll, 
ShapeUtil::MakeTokenShape())); for (auto operand : operands) { instruction->AppendOperand(operand); @@ -737,14 +748,15 @@ HloInstruction::CreateCrossReplicaSum( } /* static */ std::unique_ptr HloInstruction::CreateToken() { - return WrapUnique( + return absl::WrapUnique( new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape())); } /* static */ std::unique_ptr HloInstruction::CreateWhile( const Shape& shape, HloComputation* condition, HloComputation* body, HloInstruction* init) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kWhile, shape)); + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kWhile, shape)); instruction->AppendOperand(init); // Body comes before condition computation in the vector. instruction->called_computations_.push_back(body); @@ -757,7 +769,7 @@ HloInstruction::CreateCrossReplicaSum( HloInstruction* true_computation_arg, HloComputation* true_computation, HloInstruction* false_computation_arg, HloComputation* false_computation) { auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kConditional, shape)); + absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape)); instruction->AppendOperand(pred); instruction->AppendOperand(true_computation_arg); instruction->AppendOperand(false_computation_arg); @@ -774,15 +786,15 @@ HloInstruction::CreateCrossReplicaSum( tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices, tensorflow::gtl::ArraySlice strides) { - return MakeUnique(shape, operand, start_indices, - limit_indices, strides); + return absl::make_unique(shape, operand, start_indices, + limit_indices, strides); } /* static */ std::unique_ptr HloInstruction::CreateDynamicSlice( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, tensorflow::gtl::ArraySlice slice_sizes) { - return MakeUnique(shape, operand, start_indices, - slice_sizes); + return absl::make_unique( + shape, operand, start_indices, slice_sizes); } /* static */ std::unique_ptr @@ -790,8 +802,8 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, HloInstruction* operand, HloInstruction* update, HloInstruction* start_indices) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape)); + auto instruction = absl::WrapUnique( + new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape)); instruction->AppendOperand(operand); instruction->AppendOperand(update); instruction->AppendOperand(start_indices); @@ -801,12 +813,14 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateConcatenate( const Shape& shape, tensorflow::gtl::ArraySlice operands, int64 dimension) { - return MakeUnique(shape, operands, dimension); + return absl::make_unique(shape, operands, + dimension); } /* static */ std::unique_ptr HloInstruction::CreateConvert( const Shape& shape, HloInstruction* operand) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kConvert, shape)); + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kConvert, shape)); instruction->AppendOperand(operand); return instruction; } @@ -815,7 +829,7 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, HloInstruction::CreateBitcastConvert(const Shape& shape, HloInstruction* operand) { auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape)); + absl::WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape)); instruction->AppendOperand(operand); return instruction; } @@ -824,7 
+838,7 @@ HloInstruction::CreateBitcastConvert(const Shape& shape, const Shape& shape, HloInstruction* operand, HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions_to_reduce, HloComputation* reduce_computation) { - auto instruction = WrapUnique(new HloReduceInstruction( + auto instruction = absl::WrapUnique(new HloReduceInstruction( shape, {operand, init_value}, dimensions_to_reduce, reduce_computation)); return std::move(instruction); } @@ -838,15 +852,15 @@ HloInstruction::CreateBitcastConvert(const Shape& shape, all_args.reserve(operands.size() * 2); all_args.insert(all_args.end(), operands.begin(), operands.end()); all_args.insert(all_args.end(), init_values.begin(), init_values.end()); - return MakeUnique(shape, all_args, dimensions_to_reduce, - reduce_computation); + return absl::make_unique( + shape, all_args, dimensions_to_reduce, reduce_computation); } /* static */ std::unique_ptr HloInstruction::CreateReduceWindow( const Shape& shape, HloInstruction* operand, HloInstruction* init_value, const Window& window, HloComputation* reduce_computation) { - return MakeUnique(shape, operand, init_value, - window, reduce_computation); + return absl::make_unique( + shape, operand, init_value, window, reduce_computation); } /* static */ std::unique_ptr @@ -855,7 +869,7 @@ HloInstruction::CreateBatchNormTraining(const Shape& shape, HloInstruction* scale, HloInstruction* offset, float epsilon, int64 feature_index) { - return MakeUnique( + return absl::make_unique( shape, operand, scale, offset, epsilon, feature_index); } @@ -864,7 +878,7 @@ HloInstruction::CreateBatchNormInference( const Shape& shape, HloInstruction* operand, HloInstruction* scale, HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, float epsilon, int64 feature_index) { - return MakeUnique( + return absl::make_unique( shape, operand, scale, offset, mean, variance, epsilon, feature_index); } @@ -874,9 +888,9 @@ HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand, HloInstruction* variance, HloInstruction* grad_output, float epsilon, int64 feature_index) { - return MakeUnique(shape, operand, scale, mean, - variance, grad_output, epsilon, - feature_index); + return absl::make_unique( + shape, operand, scale, mean, variance, grad_output, epsilon, + feature_index); } /* static */ std::unique_ptr @@ -884,15 +898,15 @@ HloInstruction::CreateSelectAndScatter( const Shape& shape, HloInstruction* operand, HloComputation* select, const Window& window, HloInstruction* source, HloInstruction* init_value, HloComputation* scatter) { - return MakeUnique( + return absl::make_unique( shape, operand, select, window, source, init_value, scatter); } /* static */ std::unique_ptr HloInstruction::CreateBroadcast( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice broadcast_dimensions) { - return MakeUnique(shape, operand, - broadcast_dimensions); + return absl::make_unique(shape, operand, + broadcast_dimensions); } /* static */ std::unique_ptr @@ -950,8 +964,8 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr HloInstruction::CreatePad( const Shape& shape, HloInstruction* operand, HloInstruction* padding_value, const PaddingConfig& padding_config) { - return MakeUnique(shape, operand, padding_value, - padding_config); + return absl::make_unique(shape, operand, padding_value, + padding_config); } /* static */ std::unique_ptr HloInstruction::CreateReshape( @@ -960,7 +974,8 @@ HloInstruction::CreateBroadcastSequence( 
ShapeUtil::ElementsIn(operand->shape())) << "shape: " << ShapeUtil::HumanString(shape) << " operand: " << ShapeUtil::HumanString(operand->shape()); - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReshape, shape)); + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kReshape, shape)); instruction->AppendOperand(operand); return instruction; } @@ -968,26 +983,27 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr HloInstruction::CreateTranspose( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions) { - return MakeUnique(shape, operand, dimensions); + return absl::make_unique(shape, operand, dimensions); } /* static */ std::unique_ptr HloInstruction::CreateSort( const Shape& shape, int64 dimension, HloInstruction* keys, HloInstruction* values) { - return MakeUnique(shape, dimension, keys, values); + return absl::make_unique(shape, dimension, keys, values); } /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) { - return MakeUnique(shape, fusion_kind, fused_root); + return absl::make_unique(shape, fusion_kind, + fused_root); } /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, tensorflow::gtl::ArraySlice operands, HloComputation* fusion_computation) { - return MakeUnique(shape, fusion_kind, operands, - fusion_computation); + return absl::make_unique(shape, fusion_kind, operands, + fusion_computation); } void HloInstruction::set_single_sharding(const HloSharding& sharding) { @@ -1045,7 +1061,7 @@ bool HloInstruction::HasSideEffect() const { const Shape& shape, tensorflow::gtl::ArraySlice operands, HloComputation* computation) { std::unique_ptr instruction = - WrapUnique(new HloInstruction(HloOpcode::kCall, shape)); + absl::WrapUnique(new HloInstruction(HloOpcode::kCall, shape)); for (auto operand : operands) { instruction->AppendOperand(operand); } @@ -1056,15 +1072,15 @@ bool HloInstruction::HasSideEffect() const { /* static */ std::unique_ptr HloInstruction::CreateCustomCall( const Shape& shape, tensorflow::gtl::ArraySlice operands, tensorflow::StringPiece custom_call_target) { - return MakeUnique(shape, operands, - custom_call_target); + return absl::make_unique(shape, operands, + custom_call_target); } /* static */ std::unique_ptr HloInstruction::CreateHostCompute( const Shape& shape, tensorflow::gtl::ArraySlice operands, tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) { - return MakeUnique(shape, operands, channel_name, - cost_estimate_ns); + return absl::make_unique( + shape, operands, channel_name, cost_estimate_ns); } /* static */ std::unique_ptr HloInstruction::CreateTuple( @@ -1081,8 +1097,8 @@ bool HloInstruction::HasSideEffect() const { const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, tensorflow::gtl::ArraySlice slice_sizes) { - return MakeUnique(shape, operand, start_indices, - gather_dim_numbers, slice_sizes); + return absl::make_unique( + shape, operand, start_indices, gather_dim_numbers, slice_sizes); } /* static */ std::unique_ptr HloInstruction::CreateScatter( @@ -1090,16 +1106,17 @@ bool HloInstruction::HasSideEffect() const { HloInstruction* scatter_indices, HloInstruction* updates, HloComputation* update_computation, const ScatterDimensionNumbers& scatter_dim_numbers) { - return MakeUnique(shape, operand, scatter_indices, - updates, update_computation, - 
scatter_dim_numbers); + return absl::make_unique( + shape, operand, scatter_indices, updates, update_computation, + scatter_dim_numbers); } /* static */ std::unique_ptr HloInstruction::CreateDomain( const Shape& shape, HloInstruction* operand, std::unique_ptr operand_side_metadata, std::unique_ptr user_side_metadata) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDomain, shape)); + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kDomain, shape)); instruction->operand_side_metadata_ = std::move(operand_side_metadata); instruction->user_side_metadata_ = std::move(user_side_metadata); instruction->AppendOperand(operand); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 8d8f149ee37..30dbabfced0 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -32,6 +32,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" @@ -1052,7 +1053,7 @@ class HloInstruction { // Sets the sharding of this operator. Should only be called by HloModule or // HloComputation methods. void set_sharding(const HloSharding& sharding) { - sharding_ = MakeUnique(sharding); + sharding_ = absl::make_unique(sharding); } void set_single_sharding(const HloSharding& sharding); // Sets a sharding that assigns the current instruction to device. diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 0751aacdd6d..79a5e7481d7 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -90,7 +91,7 @@ HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 3); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], new_operands[1], new_operands[2], epsilon(), feature_index()); } @@ -112,7 +113,7 @@ HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 5); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3], new_operands[4], epsilon(), feature_index()); } @@ -134,7 +135,7 @@ HloBatchNormGradInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 5); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3], new_operands[4], epsilon(), feature_index()); } @@ -176,8 +177,8 @@ std::unique_ptr HloFftInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique(shape, new_operands[0], fft_type_, - fft_length_); + return absl::make_unique(shape, new_operands[0], fft_type_, + fft_length_); } HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode, @@ -231,8 +232,8 @@ std::unique_ptr HloSendInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique(new_operands[0], new_operands[1], - channel_id(), is_host_transfer()); + return absl::make_unique( + new_operands[0], new_operands[1], channel_id(), is_host_transfer()); } HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand, @@ -249,7 +250,7 @@ HloSendDoneInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique( + return absl::make_unique( Cast(new_operands[0]), is_host_transfer()); } @@ -270,7 +271,7 @@ std::unique_ptr HloRecvInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique( + return absl::make_unique( ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id(), is_host_transfer()); } @@ -292,7 +293,7 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique( + return absl::make_unique( Cast(new_operands[0]), is_host_transfer()); } @@ -355,7 +356,7 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* /*context*/) const { - return MakeUnique( + return absl::make_unique( shape, new_operands, to_apply(), replica_group_ids(), cross_replica_sum_barrier(), all_reduce_id()); } @@ -391,7 +392,7 @@ HloAllToAllInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* /*context*/) const { - return 
MakeUnique( + return absl::make_unique( shape, new_operands, replica_groups(), cross_replica_sum_barrier()); } @@ -455,8 +456,8 @@ std::unique_ptr HloReverseInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique(shape, new_operands[0], - dimensions()); + return absl::make_unique(shape, new_operands[0], + dimensions()); } HloConcatenateInstruction::HloConcatenateInstruction( @@ -495,8 +496,8 @@ HloConcatenateInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { - return MakeUnique(shape, new_operands, - dimensions(0)); + return absl::make_unique(shape, new_operands, + dimensions(0)); } HloReduceInstruction::HloReduceInstruction( @@ -540,8 +541,8 @@ std::unique_ptr HloReduceInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique(shape, new_operands, dimensions(), - to_apply()); + return absl::make_unique(shape, new_operands, + dimensions(), to_apply()); } HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension, @@ -581,7 +582,8 @@ std::unique_ptr HloSortInstruction::CloneWithNewOperandsImpl( HloCloneContext* context) const { HloInstruction* keys = new_operands[0]; HloInstruction* values = new_operands.size() == 2 ? new_operands[1] : nullptr; - return MakeUnique(shape, dimensions(0), keys, values); + return absl::make_unique(shape, dimensions(0), keys, + values); } HloTransposeInstruction::HloTransposeInstruction( @@ -634,8 +636,8 @@ HloTransposeInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique(shape, new_operands[0], - dimensions()); + return absl::make_unique(shape, new_operands[0], + dimensions()); } HloBroadcastInstruction::HloBroadcastInstruction( @@ -673,8 +675,8 @@ HloBroadcastInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique(shape, new_operands[0], - dimensions()); + return absl::make_unique(shape, new_operands[0], + dimensions()); } HloMapInstruction::HloMapInstruction( @@ -731,7 +733,7 @@ std::unique_ptr HloMapInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { - return MakeUnique(shape, new_operands, to_apply()); + return absl::make_unique(shape, new_operands, to_apply()); } HloSliceInstruction::HloSliceInstruction( @@ -793,8 +795,8 @@ std::unique_ptr HloSliceInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique(shape, new_operands[0], slice_starts_, - slice_limits_, slice_strides_); + return absl::make_unique( + shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_); } HloConstantInstruction::HloConstantInstruction(std::unique_ptr literal) @@ -846,7 +848,7 @@ HloConstantInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { - return MakeUnique(literal_->CloneToUnique()); + return absl::make_unique(literal_->CloneToUnique()); } string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( @@ -1340,8 +1342,8 @@ std::unique_ptr 
HloFusionInstruction::CloneWithNewOperandsImpl( new_fused_computation = module->AddEmbeddedComputation( fused_instructions_computation()->Clone("clone", context)); } - return MakeUnique(shape, fusion_kind(), new_operands, - new_fused_computation); + return absl::make_unique( + shape, fusion_kind(), new_operands, new_fused_computation); } Status HloFusionInstruction::DeduplicateFusionOperands() { @@ -1400,7 +1402,8 @@ std::unique_ptr HloRngInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { - return MakeUnique(shape, distribution_, new_operands); + return absl::make_unique(shape, distribution_, + new_operands); } HloParameterInstruction::HloParameterInstruction(int64 parameter_number, @@ -1436,7 +1439,8 @@ HloParameterInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { - return MakeUnique(parameter_number_, shape, name()); + return absl::make_unique(parameter_number_, shape, + name()); } HloGetTupleElementInstruction::HloGetTupleElementInstruction( @@ -1472,8 +1476,8 @@ HloGetTupleElementInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique(shape, new_operands[0], - tuple_index()); + return absl::make_unique( + shape, new_operands[0], tuple_index()); } HloReducePrecisionInstruction::HloReducePrecisionInstruction( @@ -1515,7 +1519,7 @@ HloReducePrecisionInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], exponent_bits(), mantissa_bits()); } @@ -1556,8 +1560,8 @@ std::unique_ptr HloInfeedInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique(infeed_shape(), new_operands[0], - infeed_config()); + return absl::make_unique( + infeed_shape(), new_operands[0], infeed_config()); } HloOutfeedInstruction::HloOutfeedInstruction( @@ -1601,8 +1605,8 @@ std::unique_ptr HloOutfeedInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique(outfeed_shape(), new_operands[0], - new_operands[1], outfeed_config()); + return absl::make_unique( + outfeed_shape(), new_operands[0], new_operands[1], outfeed_config()); } HloConvolutionInstruction::HloConvolutionInstruction( @@ -1672,7 +1676,7 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], new_operands[1], window(), convolution_dimension_numbers_, feature_group_count_); } @@ -1717,7 +1721,7 @@ HloReduceWindowInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], new_operands[1], window(), to_apply()); } @@ -1766,7 +1770,7 @@ HloSelectAndScatterInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 3); - return MakeUnique( + return absl::make_unique( shape, 
new_operands[0], select(), window(), new_operands[1], new_operands[2], scatter()); } @@ -1841,8 +1845,8 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { - auto cloned = MakeUnique(shape, new_operands, - custom_call_target()); + auto cloned = absl::make_unique( + shape, new_operands, custom_call_target()); if (window_ != nullptr) { cloned->set_window(*window_); } @@ -1883,7 +1887,7 @@ HloHostComputeInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { - return MakeUnique( + return absl::make_unique( shape, new_operands, channel_name_, cost_estimate_ns_); } @@ -1921,8 +1925,8 @@ std::unique_ptr HloPadInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique(shape, new_operands[0], new_operands[1], - padding_config_); + return absl::make_unique(shape, new_operands[0], + new_operands[1], padding_config_); } HloDynamicSliceInstruction::HloDynamicSliceInstruction( @@ -1961,7 +1965,7 @@ HloDynamicSliceInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], new_operands[1], dynamic_slice_sizes_); } @@ -1973,7 +1977,7 @@ HloGatherInstruction::HloGatherInstruction( AppendOperand(operand); AppendOperand(start_indices); gather_dimension_numbers_ = - MakeUnique(gather_dim_numbers); + absl::make_unique(gather_dim_numbers); absl::c_copy(slice_sizes, std::back_inserter(gather_slice_sizes_)); } @@ -2047,7 +2051,7 @@ std::unique_ptr HloGatherInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], new_operands[1], gather_dimension_numbers(), gather_slice_sizes()); } @@ -2063,7 +2067,7 @@ HloScatterInstruction::HloScatterInstruction( AppendOperand(updates); AppendComputation(update_computation); scatter_dimension_numbers_ = - MakeUnique(scatter_dim_numbers); + absl::make_unique(scatter_dim_numbers); } string HloScatterInstruction::ScatterDimensionNumbersToString() const { @@ -2134,7 +2138,7 @@ std::unique_ptr HloScatterInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 3); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], new_operands[1], new_operands[2], to_apply(), scatter_dimension_numbers()); } diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 803dbeabeb0..19b69c21711 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -18,6 +18,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { @@ -1080,7 +1081,7 @@ class HloCustomCallInstruction : public HloInstruction { } void set_window(const Window& window) override { - window_ = MakeUnique(window); + window_ = absl::make_unique(window); } const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { @@ -1091,7 +1092,7 @@ class HloCustomCallInstruction : public HloInstruction { void set_convolution_dimension_numbers( const ConvolutionDimensionNumbers& dnums) { convolution_dimension_numbers_ = - MakeUnique(dnums); + absl::make_unique(dnums); } const string& custom_call_target() const { return custom_call_target_; } // Returns a serialized representation of this instruction. diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc index 43c41ece6ef..18f17b75aed 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc @@ -17,8 +17,8 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -296,7 +296,7 @@ StatusOr> HloLivenessAnalysis::Run( VLOG(1) << "HloLivenessAnalysis::Run on module " << module.name(); XLA_VLOG_LINES(2, module.ToString()); - auto liveness_analysis = WrapUnique(new HloLivenessAnalysis(module)); + auto liveness_analysis = absl::WrapUnique(new HloLivenessAnalysis(module)); liveness_analysis->RunAnalysis(); diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 76f82360482..d60b76d63f8 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -23,8 +23,8 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -275,7 +275,7 @@ StatusOr> HloModule::CreateFromProto( } TF_RET_CHECK(entry != nullptr); - auto module = MakeUnique(proto.name(), module_config); + auto module = absl::make_unique(proto.name(), module_config); // Sort the computations in the proto id's order. 
std::sort(computations.begin(), computations.end(), @@ -508,7 +508,7 @@ std::vector HloModule::MakeNonfusionComputations() const { std::unique_ptr HloModule::Clone(const string& suffix) const { VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n"; - auto module = MakeUnique(name_ + "-" + suffix, config_); + auto module = absl::make_unique(name_ + "-" + suffix, config_); HloCloneContext context(module.get(), suffix); auto cloned_computation = entry_computation_->Clone(suffix, &context); diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc index 07a8c798dbe..f9708283eb4 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.cc +++ b/tensorflow/compiler/xla/service/hlo_module_config.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/strings/str_util.h" diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index 10bf9ffd6c1..3b512bf0f81 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -59,7 +59,7 @@ string HloModuleGroupMetadata::TrackedInstruction::ToString() const { /* static */ StatusOr> HloModuleGroupMetadata::Build(const std::vector& modules) { - auto metadata = MakeUnique(modules); + auto metadata = absl::make_unique(modules); TF_RETURN_IF_ERROR(metadata->Build()); return std::move(metadata); } @@ -383,7 +383,7 @@ Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1, if (!ContainsKey(companion_set_index_, instruction1) && !ContainsKey(companion_set_index_, instruction2)) { companion_sets_.push_back( - tensorflow::MakeUnique>()); + absl::make_unique>()); auto companion_set = companion_sets_.back().get(); companion_set->insert(instruction1); companion_set->insert(instruction2); diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index 0dc56761482..4f11ce322e6 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -22,7 +22,7 @@ limitations under the License. 
#include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -332,7 +332,7 @@ HloModuleGroupUtil::ComputeReachability( TF_RETURN_IF_ERROR( VisitTopologicalOrder(&visit_states, visit_function, root)); } - auto reachability = MakeUnique(post_order); + auto reachability = absl::make_unique(post_order); for (HloInstruction* hlo : post_order) { reachability->FastSetReachabilityToUnion(GlobalPredecessors(hlo), hlo); } diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 236f4500860..209ad5e58c9 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index e48c9d2c411..3768da8a731 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" @@ -305,7 +306,7 @@ bool HloParser::ParseHloModule() { return false; } - module_ = MakeUnique(name, config_); + module_ = absl::make_unique(name, config_); return ParseComputations(); } @@ -358,7 +359,7 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { if (!ParseName(&name)) { return false; } - auto builder = MakeUnique(name); + auto builder = absl::make_unique(name); LocTy shape_loc = nullptr; Shape shape; @@ -1511,14 +1512,14 @@ bool HloParser::ParseDomain(DomainData* domain) { return false; } if (*kind == ShardingMetadata::KindName()) { - auto entry_sharding_ptr = MakeUnique( + auto entry_sharding_ptr = absl::make_unique( HloSharding::FromProto(*entry_sharding).ValueOrDie()); - auto exit_sharding_ptr = MakeUnique( + auto exit_sharding_ptr = absl::make_unique( HloSharding::FromProto(*exit_sharding).ValueOrDie()); domain->entry_metadata = - MakeUnique(std::move(entry_sharding_ptr)); + absl::make_unique(std::move(entry_sharding_ptr)); domain->exit_metadata = - MakeUnique(std::move(exit_sharding_ptr)); + absl::make_unique(std::move(exit_sharding_ptr)); } else { return TokenError(StrCat("unsupported domain kind: ", *kind)); } @@ -1927,7 +1928,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, tensorflow::int64 rank = ShapeUtil::Rank(shape); - *literal = MakeUnique(shape); + *literal = absl::make_unique(shape); if (!ParseToken(TokKind::kLbrace, "expects '{' at the beginning of a sparse literal")) { diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index 3f3a51215e3..5f0f75c480e 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ 
b/tensorflow/compiler/xla/service/hlo_parser.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_lexer.h" diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h index a42d7e59fed..3bb1342aa37 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index b2725e2918c..8f3ae9c6212 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -19,9 +19,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -233,7 +233,7 @@ StatusOr>> HloRunner::ExecuteReplicated( int64 device = device_assignment(i, 0); TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device)); - streams.push_back(MakeUnique(executor)); + streams.push_back(absl::make_unique(executor)); streams.back()->Init(); service_run_options.emplace_back(GetServiceRunOptionsForDevice( device, streams.back().get(), &device_assignment)); @@ -260,7 +260,7 @@ StatusOr>> HloRunner::ExecuteReplicated( num_threads += options.num_replicas; } if (num_threads > 0) { - pool = MakeUnique( + pool = absl::make_unique( tensorflow::Env::Default(), "infeed_outfeed", /*num_threads=*/num_threads); } @@ -291,7 +291,7 @@ StatusOr>> HloRunner::ExecuteReplicated( VLOG(1) << "Starting outfeed on device " << device; for (int64 step = 1; options.infeed_steps < 0 || step <= options.infeed_steps; ++step) { - auto literal = MakeUnique(); + auto literal = absl::make_unique(); TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed( executor, options.outfeed_shape, literal.get())); if (options.outfeed_values != nullptr) { diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index a2c1d39d0d4..4e19557f829 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -121,9 +122,9 @@ std::unique_ptr CloneShardingForDomain( const HloSharding& sharding) { auto single_sharding = sharding.ExtractSingleSharding(); if (!single_sharding) { - return MakeUnique(sharding); + return absl::make_unique(sharding); } - return MakeUnique(*single_sharding); + return absl::make_unique(*single_sharding); } Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain, @@ -318,9 +319,9 @@ std::unique_ptr CreateDomain(HloInstruction* instruction, : "None"); std::unique_ptr operand_side_metadata = - MakeUnique(std::move(real_operand_sharding)); + absl::make_unique(std::move(real_operand_sharding)); std::unique_ptr user_side_metadata = - MakeUnique(std::move(real_instruction_sharding)); + absl::make_unique(std::move(real_instruction_sharding)); return HloInstruction::CreateDomain(operand->shape(), operand, std::move(operand_side_metadata), std::move(user_side_metadata)); @@ -357,9 +358,9 @@ StatusOr> ExtractOriginalCommonSharding( std::unique_ptr ShardingMetadata::Clone() const { std::unique_ptr sharding; if (sharding_ != nullptr) { - sharding = MakeUnique(*sharding_); + sharding = absl::make_unique(*sharding_); } - return MakeUnique(std::move(sharding)); + return absl::make_unique(std::move(sharding)); } bool ShardingMetadata::Matches(const DomainMetadata& other) const { diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index 7fd99fc9305..14703aaf64b 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index c942fab08e1..9e54b54b26a 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/shape_inference.h" namespace xla { @@ -128,11 +129,11 @@ class HloVerifier : public HloPassInterface { // Uses standard shape inference. explicit HloVerifier() : shape_verifier_factory_( - [] { return MakeUnique(false); }) {} + [] { return absl::make_unique(false); }) {} explicit HloVerifier(bool allow_mixed_precision) : shape_verifier_factory_([allow_mixed_precision] { - return MakeUnique(allow_mixed_precision); + return absl::make_unique(allow_mixed_precision); }) {} // Uses custom shape verification. diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc index 32937b33b37..5695bc24205 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/inliner_test.cc @@ -18,8 +18,8 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 8652599dc6d..581f8d2e92b 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -12,12 +12,11 @@ cc_library( srcs = ["interpreter_transfer_manager.cc"], hdrs = ["interpreter_transfer_manager.h"], deps = [ - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:generic_transfer_manager", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/service/interpreter:platform_id", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], alwayslink = True, # Contains per-platform transfer manager registration ) @@ -32,8 +31,6 @@ cc_library( "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:algebraic_simplifier", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", @@ -54,6 +51,7 @@ cc_library( "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/core:lib", "//tensorflow/stream_executor", + "@com_google_absl//absl/memory", ], alwayslink = True, # Contains compiler registration ) @@ -79,7 +77,6 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo", @@ -91,6 +88,7 @@ cc_library( "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 9f8f4bda875..bb69cb9c47f 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" @@ -69,8 +69,8 @@ StatusOr> InterpreterCompiler::RunBackend( // Create executable from only the Hlo module. 
   std::unique_ptr<Executable> executable =
-      xla::MakeUnique<InterpreterExecutable>(std::move(hlo_module),
-                                             xla::MakeUnique<HloEvaluator>());
+      absl::make_unique<InterpreterExecutable>(
+          std::move(hlo_module), absl::make_unique<HloEvaluator>());

   return std::move(executable);
 }
@@ -103,11 +103,11 @@ HloCostAnalysis::ShapeSizeFunction InterpreterCompiler::ShapeSizeBytesFunction()
 static bool InitModule() {
   xla::Compiler::RegisterCompilerFactory(
       se::interpreter::kXlaInterpreterPlatformId, []() {
-        return xla::MakeUnique<xla::interpreter::InterpreterCompiler>();
+        return absl::make_unique<xla::interpreter::InterpreterCompiler>();
       });
   xla::ComputationPlacer::RegisterComputationPlacer(
       se::interpreter::kXlaInterpreterPlatformId,
-      []() { return xla::MakeUnique<xla::ComputationPlacer>(); });
+      []() { return absl::make_unique<xla::ComputationPlacer>(); });
   return true;
 }
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc
index 8d40c08d555..2259dc1083e 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executable.cc
@@ -21,8 +21,8 @@ limitations under the License.
 #include
 #include

+#include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/interpreter/executor.h"
diff --git a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc
index d27cd7502f1..7955ee5cf37 100644
--- a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc
@@ -17,7 +17,7 @@ limitations under the License.
 #include

-#include "tensorflow/compiler/xla/ptr_util.h"
+#include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/service/interpreter/platform_id.h"
 #include "tensorflow/compiler/xla/service/transfer_manager.h"
@@ -31,7 +31,7 @@ InterpreterTransferManager::InterpreterTransferManager()
 static std::unique_ptr<xla::TransferManager>
 CreateInterpreterTransferManager() {
-  return xla::MakeUnique<xla::InterpreterTransferManager>();
+  return absl::make_unique<xla::InterpreterTransferManager>();
 }

 static bool InitModule() {
diff --git a/tensorflow/compiler/xla/service/interpreter/platform.cc b/tensorflow/compiler/xla/service/interpreter/platform.cc
index 42c2c28997d..e57a9b36723 100644
--- a/tensorflow/compiler/xla/service/interpreter/platform.cc
+++ b/tensorflow/compiler/xla/service/interpreter/platform.cc
@@ -17,6 +17,7 @@ limitations under the License.
 #include

+#include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/service/interpreter/executor.h"
 #include "tensorflow/stream_executor/device_options.h"
 #include "tensorflow/stream_executor/lib/initialize.h"
@@ -70,8 +71,8 @@ port::StatusOr<StreamExecutor*> XlaInterpreterPlatform::GetExecutor(
 port::StatusOr<std::unique_ptr<StreamExecutor>>
 XlaInterpreterPlatform::GetUncachedExecutor(
     const StreamExecutorConfig& config) {
-  auto executor = MakeUnique<StreamExecutor>(
-      this, MakeUnique<XlaInterpreterExecutor>(config.plugin_config));
+  auto executor = absl::make_unique<StreamExecutor>(
+      this, absl::make_unique<XlaInterpreterExecutor>(config.plugin_config));
   auto init_status = executor->Init(config.ordinal, config.device_options);
   if (!init_status.ok()) {
     return port::Status{
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 805fdb2d5bd..c75bffc63d7 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -26,9 +26,9 @@ limitations under the License.
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -137,7 +137,7 @@ PointsToSet::BufferSet* LayoutConstraints::GetBufferSet( } auto& buffer_set = buffer_sets_cache_ - .emplace(instruction, MakeUnique()) + .emplace(instruction, absl::make_unique()) .first->second; const auto& points_to_set = points_to_analysis_.GetPointsToSet(instruction); points_to_set.ForEachElement( @@ -1008,7 +1008,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( // // TODO(jingyue): Other operations, such as kSlice and kConcat, can benefit // from assigning the same layout to input and output. - return MakeUnique(output_layout); + return absl::make_unique(output_layout); } if (instruction->opcode() == HloOpcode::kReshape) { @@ -1031,13 +1031,13 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( *operand_shape.mutable_layout() = LayoutUtil::GetDefaultLayoutForShape(operand_shape); if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) { - return MakeUnique(operand_shape.layout()); + return absl::make_unique(operand_shape.layout()); } if (ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)) { *operand_shape.mutable_layout() = output_layout; if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) { - return MakeUnique(output_layout); + return absl::make_unique(output_layout); } } auto aligned_operand_shape = @@ -1046,7 +1046,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( auto operand_layout = aligned_operand_shape.value().layout(); TF_CHECK_OK( LayoutUtil::ValidateLayoutForShape(operand_layout, operand_shape)); - return MakeUnique(operand_layout); + return absl::make_unique(operand_layout); } } @@ -1062,7 +1062,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( Layout operand_layout = LayoutUtil::MakeLayout(new_minor_to_major); TF_CHECK_OK( LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape())); - return MakeUnique(operand_layout); + return absl::make_unique(operand_layout); } return nullptr; @@ -1080,7 +1080,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( !ShapeUtil::IsScalar(operand->shape()) && ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape())) { // Assign users the same layout as the operand. 
- return MakeUnique(operand_layout); + return absl::make_unique(operand_layout); } if (user->opcode() == HloOpcode::kReshape) { @@ -1103,13 +1103,13 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( *output_shape.mutable_layout() = LayoutUtil::GetDefaultLayoutForShape(output_shape); if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) { - return MakeUnique(output_shape.layout()); + return absl::make_unique(output_shape.layout()); } if (ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)) { *output_shape.mutable_layout() = operand_layout; if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) { - return MakeUnique(operand_layout); + return absl::make_unique(operand_layout); } } auto aligned_user_shape = @@ -1118,7 +1118,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( auto user_layout = aligned_user_shape.value().layout(); TF_CHECK_OK( LayoutUtil::ValidateLayoutForShape(user_layout, output_shape)); - return MakeUnique(user_layout); + return absl::make_unique(user_layout); } } @@ -1134,7 +1134,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( } Layout user_layout = LayoutUtil::MakeLayout(new_minor_to_major); TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape())); - return MakeUnique(user_layout); + return absl::make_unique(user_layout); } return nullptr; diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 5e02096ee50..597a788c5d7 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -19,10 +19,10 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/executable.h" diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index d631fb5ee42..eaa09591b72 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" @@ -89,7 +90,7 @@ void LogicalBufferAnalysis::NewLogicalBuffer(HloInstruction* instruction, const ShapeIndex& index) { CHECK_EQ(logical_buffers_.size(), next_buffer_id_); logical_buffers_.emplace_back( - MakeUnique(instruction, index, next_buffer_id_)); + absl::make_unique(instruction, index, next_buffer_id_)); output_buffers_[std::make_pair(instruction, index)] = logical_buffers_.back().get(); diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index ccb9fb3e3af..7534a3f7e32 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -15,9 +15,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 1dbf540d13d..18d1b7732bb 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -20,10 +20,10 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -245,7 +245,7 @@ StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, const ExecutionOptions* execution_options) { - auto config = MakeUnique(program_shape); + auto config = absl::make_unique(program_shape); ComputationLayout* computation_layout = config->mutable_entry_computation_layout(); if (program_shape.parameters_size() != argument_shapes.size()) { @@ -326,7 +326,7 @@ StatusOr>> Service::BuildExecutables( if (directory_path.empty() && execution_directory_path.empty()) { continue; } - auto hlo_snapshot = MakeUnique(); + auto hlo_snapshot = absl::make_unique(); *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = *module_protos[i]; if (!directory_path.empty()) { string filename = @@ -409,7 +409,8 @@ Service::ExecuteParallelAndRegisterResult( streams.push_back(std::move(stream)); if (replica == 0 && profile != nullptr) { - timers.push_back(MakeUnique(streams.back()->parent())); + timers.push_back( + absl::make_unique(streams.back()->parent())); streams.back() ->InitTimer(timers.back().get()) .ThenStartTimer(timers.back().get()); @@ -800,7 +801,7 @@ StatusOr> Service::BuildExecutable( module_proto.name().c_str()); // Dump computation proto state if flag is set. - auto hlo_snapshot = MakeUnique(); + auto hlo_snapshot = absl::make_unique(); const string& directory_path = module_config->debug_options().xla_dump_computations_to(); const string& execution_directory_path = @@ -954,7 +955,7 @@ namespace { // shape and DeviceMemoryBase values of the clone are identical to the original. std::unique_ptr CloneShapedBufferOnDevice( const ShapedBuffer& shaped_buffer, int device_ordinal) { - auto clone = MakeUnique( + auto clone = absl::make_unique( shaped_buffer.on_host_shape(), shaped_buffer.on_device_shape(), shaped_buffer.platform(), device_ordinal); clone->buffers() = shaped_buffer.buffers(); diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index 7d7dcac10b6..70714ffff06 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -18,8 +18,8 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/shaped_buffer_test.cc b/tensorflow/compiler/xla/service/shaped_buffer_test.cc index 0fc24366791..d69e6362e91 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer_test.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -34,7 +35,7 @@ TEST(ShapedBufferTest, ScopedShapeBufferAsShapedBufferB71629047) { xla::StreamExecutorMemoryAllocator allocator(platform, executors); const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {}); const int kDeviceOrdinal = 0; - auto scoped_buffer = tensorflow::MakeUnique( + auto scoped_buffer = absl::make_unique( shape, shape, &allocator, kDeviceOrdinal); std::unique_ptr buffer = std::move(scoped_buffer); buffer = nullptr; diff --git a/tensorflow/compiler/xla/service/stream_pool.cc b/tensorflow/compiler/xla/service/stream_pool.cc index c0582c6a2d3..5d1cd1c4422 100644 --- a/tensorflow/compiler/xla/service/stream_pool.cc +++ b/tensorflow/compiler/xla/service/stream_pool.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/stream_pool.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -35,7 +35,7 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) { if (!stream) { // Create a new stream. - stream = MakeUnique(executor); + stream = absl::make_unique(executor); stream->Init(); VLOG(1) << stream->DebugStreamPointers() << " StreamPool created new stream"; diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index 32d368a9042..e0f995fd0d7 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -61,7 +62,7 @@ StatusOr> TransferManager::TransferLiteralFromDevice( if (!s.ok()) { return s; } - return MakeUnique(std::move(literal)); + return absl::make_unique(std::move(literal)); } Status TransferManager::TransferLiteralFromDevice( @@ -120,7 +121,7 @@ StatusOr> TransferManager::TransferArrayFromDevice( if (!s.ok()) { return s; } - return MakeUnique(std::move(literal)); + return absl::make_unique(std::move(literal)); } Status TransferManager::TransferArrayToDevice( diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 0447807a41b..0c2f2112af5 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -441,7 +442,7 @@ PointsToSet& TuplePointsToAnalysis::CreateEmptyPointsToSet( PerInstruction* pi = PerInst(instruction); CHECK(pi->points_to_set == nullptr) << "instruction should not have been present in the map."; - auto set = MakeUnique(&instruction->shape()); + auto set = absl::make_unique(&instruction->shape()); pi->points_to_set = std::move(set); // Return *set using the iterator returned by emplace. return *pi->points_to_set; diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index c74dd648add..186c42ed130 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -21,8 +21,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index c4c958be4a1..c8ff55e7845 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_tree.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -242,7 +243,7 @@ TEST_F(ShapeTreeTest, InvalidIndexingNestedTuple) { TEST_F(ShapeTreeTest, ShapeTreeOfNonCopyableType) { ShapeTree> shape_tree{tuple_shape_}; EXPECT_EQ(shape_tree.element({2}).get(), nullptr); - *shape_tree.mutable_element({2}) = MakeUnique(42); + *shape_tree.mutable_element({2}) = absl::make_unique(42); EXPECT_EQ(*shape_tree.element({2}), 42); } diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index eac8f977fa3..4d5c9efe9ba 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -113,7 +113,6 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:backend", @@ -128,6 +127,7 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", ], ) @@ -145,6 +145,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -188,7 +189,6 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", @@ -202,6 +202,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -275,6 +276,7 @@ cc_library( 
"//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", + "@com_google_absl//absl/memory", ], ) @@ -827,7 +829,7 @@ xla_test( timeout = "long", srcs = ["convolution_test.cc"], shard_count = 25, - deps = CONVOLUTION_TEST_DEPS, + deps = CONVOLUTION_TEST_DEPS + ["@com_google_absl//absl/memory"], ) xla_test( @@ -837,7 +839,7 @@ xla_test( backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]}, backends = ["gpu"], shard_count = 25, - deps = CONVOLUTION_TEST_DEPS, + deps = CONVOLUTION_TEST_DEPS + ["@com_google_absl//absl/memory"], ) xla_test( @@ -888,6 +890,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1070,6 +1073,7 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1151,6 +1155,7 @@ xla_test_library( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1214,6 +1219,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1289,6 +1295,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1354,6 +1361,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1415,6 +1423,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1754,6 +1763,7 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1775,6 +1785,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/stream_executor", + "@com_google_absl//absl/memory", "@llvm//:core", ], ) @@ -1826,6 +1837,7 @@ xla_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//third_party/eigen3", + "@com_google_absl//absl/memory", ], ) @@ -1852,6 +1864,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index c7b94b5bbaa..74d4d2eb10c 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -16,8 +16,8 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 59d917054be..2cab3264a7e 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -17,12 +17,12 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -546,7 +546,7 @@ XlaComputation ClientLibraryTestBase::CreateScalarReluSensitivity() { std::unique_ptr> ClientLibraryTestBase::CreatePatternedMatrix( int rows, int cols, float offset) { - auto array = MakeUnique>(rows, cols); + auto array = absl::make_unique>(rows, cols); for (int64 row = 0; row < rows; ++row) { for (int64 col = 0; col < cols; ++col) { (*array)(row, col) = col + (row * 1000.0f) + offset; @@ -561,7 +561,7 @@ ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols, int cols_padded) { CHECK_GE(rows_padded, rows); CHECK_GE(cols_padded, cols); - auto array = MakeUnique>(rows_padded, cols_padded, 0.0); + auto array = absl::make_unique>(rows_padded, cols_padded, 0.0); for (int64 row = 0; row < rows; ++row) { for (int64 col = 0; col < cols; ++col) { (*array)(row, col) = col + (row * 1000.0f); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index b04a3b105ca..24d0325929b 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -30,7 +31,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" @@ -613,7 +613,7 @@ template std::unique_ptr> ClientLibraryTestBase::CreatePseudorandomR2( const int rows, const int cols, NativeT min_value, NativeT max_value, uint32 seed) { - auto result = MakeUnique>(rows, cols); + auto result = absl::make_unique>(rows, cols); PseudorandomGenerator generator(min_value, max_value, seed); for (int y = 0; y < rows; ++y) { for (int x = 0; x < cols; ++x) { diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index 7b6bbc4f571..38b6da4fa96 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -17,11 +17,11 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -88,9 +88,9 @@ TEST_F(ConvolutionDimensionNumbersTest, InvalidOutputDimensionNumbers) { XLA_TEST_F(ConvolutionDimensionNumbersTest, TwoConvsWithDifferentDimensionNumbers) { - auto input_array = MakeUnique>(2, 3, 5, 5); + auto input_array = absl::make_unique>(2, 3, 5, 5); input_array->FillWithMultiples(0.1); - auto weight_array = MakeUnique>(4, 3, 1, 1); + auto weight_array = absl::make_unique>(4, 3, 1, 1); weight_array->FillWithMultiples(0.2); auto weight_data = client_ diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 689928aee44..40658c3b775 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" @@ -26,7 +27,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -71,16 +71,16 @@ class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest { const int kKernelSizeY = 2; const int kOutputActivationSizeZ = 256; const int kMiniBatchSize = 4; - auto alhs = - MakeUnique>(kMiniBatchSize, kInputActivationSizeZ, - kInputActivationSizeY, kInputActivationSizeX); + auto alhs = absl::make_unique>( + kMiniBatchSize, kInputActivationSizeZ, kInputActivationSizeY, + kInputActivationSizeX); alhs->FillWithMultiples(static_cast(1.0f)); ASSERT_EQ(3, alhs->width()); ASSERT_EQ(3, alhs->height()); - auto arhs = - MakeUnique>(kOutputActivationSizeZ, kInputActivationSizeZ, - kKernelSizeY, kKernelSizeX); + auto arhs = absl::make_unique>(kOutputActivationSizeZ, + kInputActivationSizeZ, + kKernelSizeY, kKernelSizeX); Array2D rhs_raster({ {1.0f, 0.0f}, // row 0 {0.0f, 0.0f}, // row 1 diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 5ef273e5a26..50a9ebc1e99 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -16,10 +16,10 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 13c777835eb..6f7fc0e6e52 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 792be0d3fcd..341124170a5 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -22,13 +22,13 @@ limitations under the License. 
 #define EIGEN_USE_THREADS
+#include "absl/memory/memory.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/compiler/xla/array2d.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index b6b8c43bd91..2167d4240e4 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -21,9 +21,9 @@ limitations under the License.
 #include
 #include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_parser.h"
 #include "tensorflow/compiler/xla/service/platform_util.h"
@@ -95,11 +95,11 @@ HloTestBase::HloTestBase(se::Platform* test_platform,
                          bool allow_mixed_precision_in_hlo_verifier)
     : test_runner_(test_platform), reference_runner_(reference_platform) {
   hlo_verifier_ =
-      MakeUnique<HloVerifier>(allow_mixed_precision_in_hlo_verifier);
+      absl::make_unique<HloVerifier>(allow_mixed_precision_in_hlo_verifier);
 }
 
 std::unique_ptr<HloModule> HloTestBase::CreateNewModule(const string& name) {
-  return MakeUnique<HloModule>(name, GetModuleConfigForTest());
+  return absl::make_unique<HloModule>(name, GetModuleConfigForTest());
 }
 
 /* static */
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
index ad1f5b9eed8..a509ee32078 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
@@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 
+#include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/service/hlo_parser.h"
 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -25,7 +26,7 @@ limitations under the License.
 namespace xla {
 
 HloVerifiedTestBase::HloVerifiedTestBase()
-    : shape_verifier_(MakeUnique<ShapeVerifier>()) {}
+    : shape_verifier_(absl::make_unique<ShapeVerifier>()) {}
 
 HloVerifiedTestBase::~HloVerifiedTestBase() {
   // We can't call the ASSERT or EXPECT test macros in destructors, so we
diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
index e719da54d45..8d658695576 100644
--- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
+++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
 
 #include "tensorflow/compiler/xla/service/llvm_compiler.h"
 
+#include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/service/backend.h"
 #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
@@ -125,7 +126,7 @@ class LLVMCompilerTest : public ::testing::Test {
   static std::unique_ptr<HloModule> CreateNewModule() {
     HloModuleConfig config;
     config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
-    return MakeUnique<HloModule>(TestName(), config);
+    return absl::make_unique<HloModule>(TestName(), config);
   }
 };
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc
index eaddf756dbc..948b60061e2 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.cc
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc
@@ -18,11 +18,11 @@ limitations under the License.
 #include
+#include "absl/memory/memory.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
index da8c42d4653..b6035a21a67 100644
--- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
@@ -17,12 +17,12 @@ limitations under the License.
 #include
 #include
+#include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/array2d.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/reference_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -133,7 +133,7 @@ class TestLinspaceMaxParametric
     float from = -128.0, to = 256.0;
     std::unique_ptr<Array2D<T>> alhs =
         MakeLinspaceArray2D(from, to, rows, cols);
-    auto arhs = MakeUnique<Array2D<T>>(rows, cols, static_cast<T>(1.0f));
+    auto arhs = absl::make_unique<Array2D<T>>(rows, cols, static_cast<T>(1.0f));
 
     XlaBuilder builder(
         tensorflow::strings::Printf("max_%lldx%lld_linspace", rows, cols));
diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
index eb06b115daa..cadf1c5523a 100644
--- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
@@ -19,10 +19,10 @@ limitations under the License.
 #include
 #include
+#include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc
index ca21b0b2ba5..cbeddffacfa 100644
--- a/tensorflow/compiler/xla/tests/pad_test.cc
+++ b/tensorflow/compiler/xla/tests/pad_test.cc
@@ -16,12 +16,12 @@ limitations under the License.
 #include
 #include
+#include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/array2d.h"
 #include "tensorflow/compiler/xla/array4d.h"
 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/reference_util.h"
 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
@@ -140,7 +140,7 @@ XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) {
 TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) {
   XlaBuilder b(TestName());
-  auto input = MakeUnique<Array4D<float>>(1, 1, 3, 2);
+  auto input = absl::make_unique<Array4D<float>>(1, 1, 3, 2);
   Array2D<float> input_xy({
       {1.0f, 2.0f},  // row 0
       {3.0f, 4.0f},  // row 1
@@ -151,7 +151,7 @@ TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) {
   Pad(AddParam(*input, &b),
       AddParam(*LiteralUtil::CreateR0<float>(1.5), &b),
       r4_padding_on_dim0_dim1_);
-  auto expected = MakeUnique<Array4D<float>>(2, 3, 3, 2);
+  auto expected = absl::make_unique<Array4D<float>>(2, 3, 3, 2);
   expected->Fill(1.5);
   (*expected)(1, 0, 0, 0) = 1.0f;
   (*expected)(1, 0, 0, 1) = 2.0f;
@@ -171,7 +171,7 @@ TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) {
       AddParam(*LiteralUtil::CreateR0<float>(pad_value), &b),
       r4_padding_on_dim0_dim1_);
-  auto expected = MakeUnique<Array4D<float>>(8, 5, 1, 1);
+  auto expected = absl::make_unique<Array4D<float>>(8, 5, 1, 1);
   expected->Fill(pad_value);
   (*expected)(1, 0, 0, 0) = 1.0f;
   (*expected)(1, 2, 0, 0) = 2.0f;
@@ -269,7 +269,7 @@ XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) {
 XLA_TEST_F(PadTest, Pad4DU8Array) {
   XlaBuilder b(TestName());
-  auto input = MakeUnique<Array4D<uint8>>(1, 1, 3, 2);
+  auto input = absl::make_unique<Array4D<uint8>>(1, 1, 3, 2);
   Array2D<uint8> input_xy({
       {1, 2},  // row 0
       {3, 4},  // row 1
@@ -280,7 +280,7 @@ XLA_TEST_F(PadTest, Pad4DU8Array) {
   Pad(AddParam(*input, &b), ConstantR0<uint8>(&b, 35),
       r4_padding_on_dim0_dim1_);
-  auto expected = MakeUnique<Array4D<uint8>>(2, 3, 3, 2);
+  auto expected = absl::make_unique<Array4D<uint8>>(2, 3, 3, 2);
   expected->Fill(35);
   (*expected)(1, 0, 0, 0) = 1;
   (*expected)(1, 0, 0, 1) = 2;
@@ -301,13 +301,13 @@ XLA_TEST_F(PadTest, Pad4DPredArray) {
   Pad(input, ConstantR0<bool>(&b, false), r4_padding_on_dim0_dim1_);
   // For the same reason, use Select to convert boolean values to int32.
-  auto zeros = MakeUnique<Array4D<int32>>(2, 3, 3, 2);
-  auto ones = MakeUnique<Array4D<int32>>(2, 3, 3, 2);
+  auto zeros = absl::make_unique<Array4D<int32>>(2, 3, 3, 2);
+  auto ones = absl::make_unique<Array4D<int32>>(2, 3, 3, 2);
   zeros->Fill(0);
   ones->Fill(1);
   Select(padded, AddParam(*ones, &b), AddParam(*zeros, &b));
-  auto expected = MakeUnique<Array4D<int32>>(2, 3, 3, 2);
+  auto expected = absl::make_unique<Array4D<int32>>(2, 3, 3, 2);
   expected->Fill(0);
   (*expected)(1, 0, 0, 0) = 1;
   (*expected)(1, 0, 0, 1) = 1;
@@ -321,7 +321,7 @@ XLA_TEST_F(PadTest, Pad4DPredArray) {
 XLA_TEST_P(PadTestFloat, Large2DPad) {
   XlaBuilder b(TestName());
-  auto ones = MakeUnique<Array2D<float>>(4, 4);
+  auto ones = absl::make_unique<Array2D<float>>(4, 4);
   ones->Fill(1.0f);
   auto input = AddParam(*ones, &b);
   PaddingConfig padding_config = MakeNoPaddingConfig(2);
@@ -342,7 +342,7 @@ XLA_TEST_P(PadTestFloat, AllTypes2DPad) {
   constexpr int64 in_rows = 35;
   constexpr int64 in_cols = 35;
-  auto operand = MakeUnique<Array2D<float>>(in_rows, in_cols);
+  auto operand = absl::make_unique<Array2D<float>>(in_rows, in_cols);
   operand->FillUnique(0.0f);
   auto input = AddParam(*operand, &b);
@@ -368,7 +368,7 @@ XLA_TEST_P(PadTestFloat, High2DPad) {
   constexpr int64 low_padding = 0;
   int64 high_padding[2] = {5, 7};
   constexpr int64 interior_padding = 0;
-  auto operand = MakeUnique<Array2D<float>>(in_rows, in_cols);
+  auto operand = absl::make_unique<Array2D<float>>(in_rows, in_cols);
   operand->FillUnique(1.0f);
   auto input = AddParam(*operand, &b);
   PaddingConfig padding_config = MakeNoPaddingConfig(2);
@@ -395,7 +395,7 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) {
   int64 low_padding[2] = {-1, -2};
   int64 high_padding[2] = {-3, 4};
   constexpr int64 interior_padding = 0;
-  auto operand = MakeUnique<Array2D<float>>(in_rows, in_cols);
+  auto operand = absl::make_unique<Array2D<float>>(in_rows, in_cols);
   operand->FillUnique(1.0f);
   auto input = AddParam(*operand, &b);
   PaddingConfig padding_config = MakeNoPaddingConfig(2);
@@ -423,7 +423,7 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) {
   int64 low_padding[2] = {4, -1};
   int64 high_padding[2] = {-2, -4};
   int64 interior_padding[2] = {1, 2};
-  auto operand = MakeUnique<Array2D<float>>(in_rows, in_cols);
+  auto operand = absl::make_unique<Array2D<float>>(in_rows, in_cols);
   operand->FillUnique(1.0f);
   auto input = AddParam(*operand, &b);
   PaddingConfig padding_config = MakeNoPaddingConfig(2);
@@ -446,7 +446,7 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) {
 // Regression test for b/31827337.
 XLA_TEST_P(PadTestFloat, ReducePad) {
   XlaBuilder b(TestName());
-  auto ones = MakeUnique<Array4D<float>>(2, 2, 2, 2);
+  auto ones = absl::make_unique<Array4D<float>>(2, 2, 2, 2);
   ones->Fill(1.0);
   auto input = AddParam(*ones, &b);
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index cae029fd703..09acadb2c27 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include
 #include
+#include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/array2d.h"
 #include "tensorflow/compiler/xla/array3d.h"
 #include "tensorflow/compiler/xla/array4d.h"
@@ -357,7 +358,7 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) {
   std::vector<int64> input_dims(6, 8);
   auto shape = ShapeUtil::MakeShape(F32, input_dims);
-  auto arg_literal = MakeUnique<Literal>(shape);
+  auto arg_literal = absl::make_unique<Literal>(shape);
   arg_literal->PopulateWithValue(1.0f);
   const auto input = CreateConstantFromLiteral(*arg_literal, &builder_);
@@ -368,7 +369,7 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) {
   std::vector<int64> output_dims = {6, 8, 6, 6, 8, 8};
   Shape result_shape =
       ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout);
-  auto expected = MakeUnique<Literal>(result_shape);
+  auto expected = absl::make_unique<Literal>(result_shape);
   expected->PopulateWithValue(27.0f);
   ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec());
 }
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index f05421f8e1e..2f1d97b25d5 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -15,12 +15,13 @@ limitations under the License.
 #include
-#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
 #include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
 
 namespace xla {
@@ -130,7 +131,7 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
   if (engine == nullptr) {
     return Literal::CreateFromShape(shape);
   }
-  auto literal = MakeUnique<Literal>(shape);
+  auto literal = absl::make_unique<Literal>(shape);
   switch (shape.element_type()) {
     case BF16:
       PopulateWithRandomFloatingPointData<bfloat16>(literal.get(), engine,
@@ -383,13 +384,15 @@ StatusOr<std::unique_ptr<Literal>> MakeConstrainedArgument(
 StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape,
                                                    bool pseudo_random) {
-  auto engine = pseudo_random ? MakeUnique<std::minstd_rand0>() : nullptr;
+  auto engine =
+      pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
   return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false);
 }
 
 StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
     HloModule* const module, bool pseudo_random) {
-  auto engine = pseudo_random ? MakeUnique<std::minstd_rand0>() : nullptr;
+  auto engine =
+      pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
   return MakeFakeArguments(module, engine.get());
 }
diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h
index 3a8ad80ed16..1aca1d8ef7e 100644
--- a/tensorflow/compiler/xla/tests/test_utils.h
+++ b/tensorflow/compiler/xla/tests/test_utils.h
@@ -20,9 +20,9 @@ limitations under the License.
 #include
 #include
+#include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc
index 97bbf80aff8..c101cd2d201 100644
--- a/tensorflow/compiler/xla/tests/tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/tuple_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #include
 #include
+#include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/array2d.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -504,7 +505,7 @@ XLA_TEST_F(TupleTest, ComplexTuples) {
       LiteralUtil::CreateR2<complex64>({{{111, 222}, {331, 442}},
                                         {{1011, 2022}, {3031, 4042}},
                                         {{10011, 20022}, {30031, 40042}}});
-  auto prod = MakeUnique<Literal>(sum->shape());
+  auto prod = absl::make_unique<Literal>(sum->shape());
   ASSERT_TRUE(prod->Populate<complex64>(
                   [&sum](tensorflow::gtl::ArraySlice<int64> indexes) {
                     return sum->Get<complex64>(indexes) *
diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc
index 897123d7606..7de2c39b389 100644
--- a/tensorflow/compiler/xla/text_literal_reader.cc
+++ b/tensorflow/compiler/xla/text_literal_reader.cc
@@ -20,8 +20,8 @@ limitations under the License.
 #include
 #include
+#include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/literal.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/types.h"
@@ -102,7 +102,7 @@ StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadAllLines() {
         ShapeUtil::HumanString(shape).c_str());
   }
-  auto result = MakeUnique<Literal>(shape);
+  auto result = absl::make_unique<Literal>(shape);
   const float fill = std::numeric_limits<float>::quiet_NaN();
   result->PopulateWithValue(fill);
   std::vector pieces;
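
Note (illustrative, not part of the patch): every hunk above applies the same mechanical rewrite, dropping the include of tensorflow/compiler/xla/ptr_util.h in favor of absl/memory/memory.h and replacing xla::MakeUnique<T>(...) with absl::make_unique<T>(...). A minimal before/after sketch of the pattern follows; the Foo type and MakeFoo helper are hypothetical and exist only for this example.

    // Before (xla::MakeUnique from tensorflow/compiler/xla/ptr_util.h):
    //   auto foo = MakeUnique<Foo>(1, 2);
    //
    // After: absl::make_unique is a drop-in replacement; both helpers simply
    // forward their arguments to Foo's constructor and return
    // std::unique_ptr<Foo>.
    #include <memory>

    #include "absl/memory/memory.h"

    struct Foo {
      Foo(int a, int b) : a(a), b(b) {}
      int a, b;
    };

    std::unique_ptr<Foo> MakeFoo(int a, int b) {
      return absl::make_unique<Foo>(a, b);
    }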