diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 64ab2157580..736c6089c21 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -342,6 +342,7 @@ filegroup(
         "//tensorflow/tensorboard/components/tf_globals:all_files",
         "//tensorflow/tensorboard/components/tf_globals_d3v4:all_files",
         "//tensorflow/tensorboard/components/tf_graph_common:all_files",
+        "//tensorflow/tensorboard/components/tf_graph_loader:all_files",
         "//tensorflow/tensorboard/components/tf_histogram_dashboard:all_files",
         "//tensorflow/tensorboard/components/tf_histogram_dashboard/demo:all_files",
         "//tensorflow/tensorboard/components/tf_image_dashboard:all_files",
diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD
index c52a56b6428..c12005a4cab 100644
--- a/tensorflow/compiler/aot/BUILD
+++ b/tensorflow/compiler/aot/BUILD
@@ -73,7 +73,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/client:client_library",
-        "//tensorflow/compiler/xla/client:local_client",
+        "//tensorflow/compiler/xla/client:compile_only_client",
         "//tensorflow/compiler/xla/service:compiler",
         "//tensorflow/compiler/xla/service/cpu:cpu_compiler",
         "//tensorflow/core:core_cpu",
diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc
index 4b5534c1648..3955cabedf5 100644
--- a/tensorflow/compiler/aot/compile.cc
+++ b/tensorflow/compiler/aot/compile.cc
@@ -27,7 +27,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/compile_only_client.h"
 #include "tensorflow/compiler/xla/service/compiler.h"
 #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -274,7 +274,8 @@ Status CreateXlaArgs(const Graph& graph,
 
 // Converts the TensorFlow graph into an XLA computation, by executing the
 // graph symbolically, with each op building up the XLA HLO.
-Status ConvertGraphToXla(xla::LocalClient* client, std::unique_ptr<Graph> graph,
+Status ConvertGraphToXla(xla::CompileOnlyClient* client,
+                         std::unique_ptr<Graph> graph,
                          xla::Computation* computation, bool* has_context_arg) {
   // Create a device and context to convert the graph into an XLA computation.
   XlaOpRegistry::RegisterCompilationKernels();
@@ -333,7 +334,8 @@ Status ConvertGraphToXla(xla::LocalClient* client, std::unique_ptr<Graph> graph,
 }
 
 // Compiles the XLA computation into executable code.
-Status CompileXla(xla::LocalClient* client, const xla::Computation& computation,
+Status CompileXla(xla::CompileOnlyClient* client,
+                  const xla::Computation& computation,
                   const xla::cpu::CpuAotCompilationOptions& aot_opts,
                   CompileResult* compile_result) {
   // Retrieves arg and result layouts from the computation.
@@ -350,7 +352,7 @@ Status CompileXla(xla::LocalClient* client, const xla::Computation& computation,
   for (int i = 0; i < pshape->parameters_size(); ++i) {
     arg_layouts.push_back(pshape->mutable_parameters(i));
   }
-  xla::LocalClient::AheadOfTimeComputationInstance instance;
+  xla::CompileOnlyClient::AotComputationInstance instance;
   instance.computation = &computation;
   instance.argument_layouts = std::move(arg_layouts);
   instance.result_layout = &pshape->result();
@@ -365,7 +367,7 @@ Status CompileXla(xla::LocalClient* client, const xla::Computation& computation,
       std::move(aot_or.ValueOrDie().back()));
   compile_result->entry_point = aot_opts.entry_point_name();
   compile_result->pointer_size =
-      xla::LocalClient::PointerSizeForTriple(aot_opts.triple());
+      xla::CompileOnlyClient::PointerSizeForTriple(aot_opts.triple());
   return Status::OK();
 }
 
@@ -394,8 +396,9 @@ Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags,
   namespace gpu = perftools::gputools;
   gpu::Platform* cpu_platform =
       gpu::MultiPlatformManager::PlatformWithName("Host").ValueOrDie();
-  xla::LocalClient* client =
-      xla::ClientLibrary::GetOrCreateLocalClient(cpu_platform).ValueOrDie();
+  xla::CompileOnlyClient* client =
+      xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform)
+          .ValueOrDie();
   xla::Computation computation;
   TF_RETURN_IF_ERROR(ConvertGraphToXla(client, std::move(graph), &computation,
                                        &compile_result->has_context_arg));
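Taken together, the compile.cc changes above reduce to the following calling pattern for ahead-of-time compilation. This is an illustrative sketch rather than code from the patch: computation, arg_shape, result_shape, and aot_opts are hypothetical stand-ins for values tfcompile builds elsewhere.

    // Sketch: driving the new compile-only path (named values above are
    // assumptions, not part of this change).
    namespace gpu = perftools::gputools;
    gpu::Platform* cpu_platform =
        gpu::MultiPlatformManager::PlatformWithName("Host").ValueOrDie();
    xla::CompileOnlyClient* client =
        xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform)
            .ValueOrDie();

    xla::CompileOnlyClient::AotComputationInstance instance;
    instance.computation = &computation;       // hypothetical xla::Computation
    instance.argument_layouts = {&arg_shape};  // hypothetical xla::Shape
    instance.result_layout = &result_shape;    // hypothetical xla::Shape

    // aot_opts is an xla::cpu::CpuAotCompilationOptions describing the target.
    auto results_or = client->CompileAheadOfTime({instance}, aot_opts);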
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index 3e9dfe2a922..2d96128e259 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -99,6 +99,26 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "compile_only_client",
+    srcs = ["compile_only_client.cc"],
+    hdrs = ["compile_only_client.h"],
+    deps = [
+        ":client",
+        ":computation",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/service:compile_only_service",
+        "//tensorflow/compiler/xla/service:compiler",
+        "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:stream_executor_no_cuda",
+        "@llvm//:support",
+    ],
+)
+
 # This target is used to instantiate the XLA service in-process and create
 # a client for it.
 cc_library(
@@ -106,12 +126,14 @@ cc_library(
     srcs = ["client_library.cc"],
     hdrs = ["client_library.h"],
     deps = [
+        ":compile_only_client",
         ":local_client",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla/service:backend",
+        "//tensorflow/compiler/xla/service:compile_only_service",
        "//tensorflow/compiler/xla/service:device_memory_allocator",
        "//tensorflow/compiler/xla/service:local_service",
        "//tensorflow/compiler/xla/service:platform_util",
diff --git a/tensorflow/compiler/xla/client/client_library.cc b/tensorflow/compiler/xla/client/client_library.cc
index 93437023bc8..eb9a7ff2acf 100644
--- a/tensorflow/compiler/xla/client/client_library.cc
+++ b/tensorflow/compiler/xla/client/client_library.cc
@@ -69,8 +69,8 @@ ClientLibrary::~ClientLibrary() = default;
     TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
   }
 
-  auto it = client_library.instances_.find(platform->id());
-  if (it != client_library.instances_.end()) {
+  auto it = client_library.local_instances_.find(platform->id());
+  if (it != client_library.local_instances_.end()) {
     return it->second->client.get();
   }
 
@@ -78,13 +78,13 @@ ClientLibrary::~ClientLibrary() = default;
   service_options.set_platform(platform);
   service_options.set_number_of_replicas(replica_count);
 
-  std::unique_ptr<LocalInstance> instance = MakeUnique<LocalInstance>();
+  auto instance = MakeUnique<LocalInstance>();
   TF_ASSIGN_OR_RETURN(instance->service,
                       LocalService::NewService(service_options));
   instance->client = MakeUnique<LocalClient>(instance->service.get());
   LocalClient* cl = instance->client.get();
 
-  client_library.instances_.insert(
+  client_library.local_instances_.insert(
       std::make_pair(platform->id(), std::move(instance)));
   return cl;
 }
@@ -99,9 +99,35 @@ ClientLibrary::~ClientLibrary() = default;
     perftools::gputools::Platform* platform) {
   ClientLibrary& client_library = Singleton();
   tensorflow::mutex_lock lock(client_library.service_mutex_);
-  auto it = client_library.instances_.find(platform->id());
-  CHECK(it != client_library.instances_.end());
+  auto it = client_library.local_instances_.find(platform->id());
+  CHECK(it != client_library.local_instances_.end());
   return it->second->service.get();
 }
 
+/* static */ StatusOr<CompileOnlyClient*>
+ClientLibrary::GetOrCreateCompileOnlyClient(
+    perftools::gputools::Platform* platform) {
+  ClientLibrary& client_library = Singleton();
+  tensorflow::mutex_lock lock(client_library.service_mutex_);
+
+  if (platform == nullptr) {
+    TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
+  }
+
+  auto it = client_library.compile_only_instances_.find(platform->id());
+  if (it != client_library.compile_only_instances_.end()) {
+    return it->second->client.get();
+  }
+
+  auto instance = MakeUnique<CompileOnlyInstance>();
+  TF_ASSIGN_OR_RETURN(instance->service,
+                      CompileOnlyService::NewService(platform));
+  instance->client = MakeUnique<CompileOnlyClient>(instance->service.get());
+  CompileOnlyClient* cl = instance->client.get();
+
+  client_library.compile_only_instances_.insert(
+      std::make_pair(platform->id(), std::move(instance)));
+  return cl;
+}
+
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/client/client_library.h b/tensorflow/compiler/xla/client/client_library.h
index 2bc319f9333..49f45414378 100644
--- a/tensorflow/compiler/xla/client/client_library.h
+++ b/tensorflow/compiler/xla/client/client_library.h
@@ -26,7 +26,9 @@ limitations under the License.
 #include <unordered_map>
 #include <vector>
 
+#include "tensorflow/compiler/xla/client/compile_only_client.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/service/compile_only_service.h"
 #include "tensorflow/compiler/xla/service/device_memory_allocator.h"
 #include "tensorflow/compiler/xla/service/local_service.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -76,6 +78,13 @@ class ClientLibrary {
   // access user computations from client.
   static LocalService* GetXlaService(perftools::gputools::Platform* platform);
 
+  // Singleton constructor-or-accessor for compile-only clients. Arguments:
+  //
+  //   platform : The platform the underlying XLA service should target. If
+  //       null then the default platform is used.
+  static StatusOr<CompileOnlyClient*> GetOrCreateCompileOnlyClient(
+      perftools::gputools::Platform* platform = nullptr);
+
  private:
   // Returns the singleton instance of ClientLibrary.
   static ClientLibrary& Singleton();
@@ -90,10 +99,21 @@ class ClientLibrary {
     std::unique_ptr<LocalClient> client;
   };
 
+  struct CompileOnlyInstance {
+    // Service that is wrapped by the singleton client object.
+    std::unique_ptr<CompileOnlyService> service;
+    // Singleton client object.
+    std::unique_ptr<CompileOnlyClient> client;
+  };
+
   tensorflow::mutex service_mutex_;  // Guards the singleton creation state.
 
   std::unordered_map<perftools::gputools::Platform::Id,
                      std::unique_ptr<LocalInstance>>
-      instances_ GUARDED_BY(service_mutex_);
+      local_instances_ GUARDED_BY(service_mutex_);
+
+  std::unordered_map<perftools::gputools::Platform::Id,
+                     std::unique_ptr<CompileOnlyInstance>>
+      compile_only_instances_ GUARDED_BY(service_mutex_);
 
   TF_DISALLOW_COPY_AND_ASSIGN(ClientLibrary);
 };
diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc
new file mode 100644
index 00000000000..2ff6f0b300f
--- /dev/null
+++ b/tensorflow/compiler/xla/client/compile_only_client.cc
@@ -0,0 +1,59 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/compile_only_client.h"
+
+#include "external/llvm/include/llvm/ADT/Triple.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+
+StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
+CompileOnlyClient::CompileAheadOfTime(
+    const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
+    const AotCompilationOptions& options) {
+  std::vector<CompileOnlyService::AotComputationInstance> service_instances;
+  service_instances.reserve(computations.size());
+  for (const AotComputationInstance& instance : computations) {
+    service_instances.push_back({});
+    CompileOnlyService::AotComputationInstance& service_instance =
+        service_instances.back();
+    TF_RET_CHECK(instance.computation != nullptr);
+    service_instance.computation = instance.computation->handle();
+    service_instance.argument_layouts = instance.argument_layouts;
+    service_instance.result_layout = instance.result_layout;
+  }
+  return compiler_service_->CompileAheadOfTime(service_instances, options);
+}
+
+int64 CompileOnlyClient::PointerSizeForTriple(
+    tensorflow::StringPiece target_triple) {
+  llvm::Triple triple(
+      llvm::Triple::normalize(llvm_ir::AsStringRef(target_triple)));
+  if (triple.isArch64Bit()) {
+    return 8;
+  } else if (triple.isArch32Bit()) {
+    return 4;
+  } else {
+    CHECK(triple.isArch16Bit());
+    return 2;
+  }
+}
+
+}  // namespace xla
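A quick illustration of the triple-to-pointer-size mapping implemented above; the triples are ordinary LLVM target strings, and the expected values follow directly from the isArch64Bit/isArch32Bit checks:

    // Illustrative only; any 64-bit, 32-bit, or 16-bit triple behaves the same.
    CHECK_EQ(8, xla::CompileOnlyClient::PointerSizeForTriple(
                    "x86_64-unknown-linux-gnu"));
    CHECK_EQ(4, xla::CompileOnlyClient::PointerSizeForTriple(
                    "armv7-none-linux-androideabi"));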
diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h
new file mode 100644
index 00000000000..59000487113
--- /dev/null
+++ b/tensorflow/compiler/xla/client/compile_only_client.h
@@ -0,0 +1,66 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_
+
+#include "tensorflow/compiler/xla/client/client.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/service/compile_only_service.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+
+// An XLA Client specialization for doing ahead-of-time compilation. This does
+// not require (or attempt to instantiate) an execution-capable backend for the
+// relevant platform.
+class CompileOnlyClient : public Client {
+ public:
+  explicit CompileOnlyClient(CompileOnlyService* service)
+      : Client(service), compiler_service_(service) {}
+
+  CompileOnlyClient(const CompileOnlyClient&) = delete;
+  void operator=(const CompileOnlyClient&) = delete;
+
+  // A description of a computation to compile using CompileAheadOfTime.
+  struct AotComputationInstance {
+    const Computation* computation;
+    // Inform the compiler of the expected layout for arguments.
+    std::vector<const Shape*> argument_layouts;
+    // Specifies the expected result layout.
+    const Shape* result_layout;
+  };
+
+  // Compiles a list of computations for ahead-of-time execution. This is
+  // intended for use in static compilation. The |options| parameter describes
+  // the target for which the compiler should emit code.
+  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
+  CompileAheadOfTime(
+      const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
+      const AotCompilationOptions& options);
+
+  // Returns the size of a pointer in bytes for a given triple.
+  static int64 PointerSizeForTriple(tensorflow::StringPiece triple);
+
+ private:
+  CompileOnlyService* compiler_service_;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_
diff --git a/tensorflow/compiler/xla/client/global_data.h b/tensorflow/compiler/xla/client/global_data.h
index eb11d91034b..b7929357d06 100644
--- a/tensorflow/compiler/xla/client/global_data.h
+++ b/tensorflow/compiler/xla/client/global_data.h
@@ -23,13 +23,15 @@ limitations under the License.
 
 namespace xla {
 
-// Wraps a GlobalDataHandle with a lifetime.
+// A GlobalData object represents a globally-accessible allocation of
+// data in the associated XLA service.
 class GlobalData {
  public:
   // Gives ownership of the global data handle to this object.
   GlobalData(ServiceInterface* parent, GlobalDataHandle handle);
 
-  // Unregisters the wrapped handle.
+  // Unregisters the wrapped handle, which causes the service to
+  // deallocate the associated data.
   ~GlobalData();
 
   const GlobalDataHandle& handle() const { return handle_; }
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index bfd14bc1c01..452462287cf 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -176,10 +176,10 @@ StatusOr<std::unique_ptr<ScopedShapedBuffer>> LocalExecutable::Run(
   TF_RETURN_IF_ERROR(ValidateExecutionOptions(arguments, options, *backend_));
 
   ExecutableRunOptions actual_options = options;
-  Backend::StreamPtr stream;
   if (options.stream() == nullptr) {
     TF_ASSIGN_OR_RETURN(
-        stream, BorrowStreamForDevice(options.device_ordinal(), backend_));
+        Backend::StreamPtr stream,
+        BorrowStreamForDevice(options.device_ordinal(), backend_));
     actual_options.set_stream(stream.get());
   }
   if (options.allocator() == nullptr) {
@@ -261,38 +261,6 @@ tensorflow::Status LocalClient::ResolveArguments(
       argument_ptrs);
 }
 
-StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-LocalClient::CompileAheadOfTime(
-    const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
-        computations,
-    const AotCompilationOptions& options) {
-  std::vector<LocalService::AheadOfTimeComputationInstance> service_instances;
-  service_instances.reserve(computations.size());
-  for (const AheadOfTimeComputationInstance& instance : computations) {
-    service_instances.push_back({});
-    LocalService::AheadOfTimeComputationInstance& service_instance =
-        service_instances.back();
-    TF_RET_CHECK(instance.computation != nullptr);
-    service_instance.computation = instance.computation->handle();
-    service_instance.argument_layouts = instance.argument_layouts;
-    service_instance.result_layout = instance.result_layout;
-  }
-  return local_service_->CompileAheadOfTime(service_instances, options);
-}
-
-int64 LocalClient::PointerSizeForTriple(tensorflow::StringPiece target_triple) {
-  llvm::Triple triple(
-      llvm::Triple::normalize(llvm_ir::AsStringRef(target_triple)));
-  if (triple.isArch64Bit()) {
-    return 8;
-  } else if (triple.isArch32Bit()) {
-    return 4;
-  } else {
-    CHECK(triple.isArch16Bit());
-    return 2;
-  }
-}
-
 se::Platform* LocalClient::platform() const {
   return local_service_->backend().platform();
 }
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
index 2c467efcea1..94d56106398 100644
--- a/tensorflow/compiler/xla/client/local_client.h
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -148,7 +148,7 @@ class LocalExecutable {
   const ExecutableBuildOptions& build_options_;
 };
 
-// An XLA service client object for use when the client and service run in
+// An XLA Client specialization for use when the client and service run in
 // the same process.
 class LocalClient : public Client {
  public:
@@ -182,30 +182,6 @@ class LocalClient : public Client {
       const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
       const ExecutableBuildOptions& options);
 
-  // A description of a computation to compile using CompileAheadOfTime.
-  struct AheadOfTimeComputationInstance {
-    const Computation* computation;
-    // Inform the compiler of the expected layout for arguments.
-    std::vector<const Shape*> argument_layouts;
-    // Specifies the expected result layout.
-    const Shape* result_layout;
-  };
-
-  // Compiles a list of computations for ahead-of-time execution. This is
-  // intended for use in static compilation. The |options| parameter describes
-  // the target for which the compiler should emit code.
-  //
-  // TODO(b/31222190): This doesn't really belong in LocalClient. Move it to its
-  // own library.
-  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-  CompileAheadOfTime(
-      const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
-          computations,
-      const AotCompilationOptions& options);
-
-  // Returns the size of a pointer in bytes for a given triple.
-  static int64 PointerSizeForTriple(tensorflow::StringPiece triple);
-
   // Returns the platform that the underlying service targets.
   perftools::gputools::Platform* platform() const;
 
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 03c9e2c9d75..3b3e004864a 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/literal_util.h"
 
 #include <algorithm>
+#include <cstring>
 #include <functional>
 #include <limits>
 #include <numeric>
@@ -308,37 +309,16 @@ template <typename NativeT>
 /* static */ std::unique_ptr<Literal> LiteralUtil::Relayout(
     const Literal& original, const Layout& layout) {
-  // Note: if this were a performance bottleneck, we avoid cloning and just make
-  // an uninitialized array instead, since all values are clobbered below.
   std::unique_ptr<Literal> result = CloneToUnique(original);
   *result->mutable_shape()->mutable_layout() = layout;
-  const PrimitiveType primitive_type = original.shape().element_type();
-  switch (primitive_type) {
-    case F32:
-      LiteralUtil::EachCell<float>(
-          original,
-          [&](tensorflow::gtl::ArraySlice<int64> indices, float value) {
-            LiteralUtil::Set<float>(result.get(), indices, value);
-          });
-      return result;
-    case S32:
-      LiteralUtil::EachCell<int32>(
-          original,
-          [&](tensorflow::gtl::ArraySlice<int64> indices, int32 value) {
-            LiteralUtil::Set<int32>(result.get(), indices, value);
-          });
-      return result;
-    case U32:
-      LiteralUtil::EachCell<uint32>(
-          original,
-          [&](tensorflow::gtl::ArraySlice<int64> indices, uint32 value) {
-            LiteralUtil::Set<uint32>(result.get(), indices, value);
-          });
-      return result;
-    default:
-      LOG(FATAL) << "not yet implemented: "
-                 << PrimitiveType_Name(primitive_type);
-  }
+
+  const Shape& shape = original.shape();
+  std::vector<int64> base(ShapeUtil::Rank(shape), 0);
+  std::vector<int64> copy_size(shape.dimensions().begin(),
+                               shape.dimensions().end());
+
+  TF_CHECK_OK(Copy(original, base, result.get(), base, copy_size));
+  return result;
 }
 
 /* static */ StatusOr<std::unique_ptr<Literal>> LiteralUtil::Reshape(
@@ -346,25 +326,19 @@ template <typename NativeT>
   if (ShapeUtil::IsTuple(input.shape())) {
     return InvalidArgument("Reshape does not support tuples.");
   }
-
+  std::unique_ptr<Literal> output;
   if (!LayoutUtil::IsMonotonicWithDim0Major(input.shape().layout())) {
-    return Unimplemented(
-        "Input shape must have a monotonic layout where dimension 0 is major, "
-        "was: %s",
-        LayoutUtil::HumanString(input.shape().layout()).c_str());
+    std::vector<int64> minor_to_major(ShapeUtil::Rank(input.shape()));
+    std::iota(minor_to_major.rbegin(), minor_to_major.rend(),
+              static_cast<int64>(0));
+    output = Relayout(input, LayoutUtil::MakeLayout(minor_to_major));
+  } else {
+    output = CloneToUnique(input);
   }
-  std::vector<int64> layout(dimensions.size());
-  std::iota(layout.rbegin(), layout.rend(), 0);
 
   // Because the layout is monotonic, we can simply reuse the same sequence of
   // values without changing their order.
-  std::unique_ptr<Literal> output = CloneToUnique(input);
-  output->clear_shape();
-  output->mutable_shape()->set_element_type(input.shape().element_type());
-  for (int64 dimension : dimensions) {
-    output->mutable_shape()->add_dimensions(dimension);
-  }
-  *output->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout(layout);
+  *output->mutable_shape() =
+      ShapeUtil::MakeShape(input.shape().element_type(), dimensions);
 
   int64 elements_before = ShapeUtil::ElementsIn(input.shape());
   int64 elements_after = ShapeUtil::ElementsIn(output->shape());
@@ -378,73 +352,42 @@ template <typename NativeT>
   return std::move(output);
 }
 
-namespace {
-
-template <typename T>
-void TransposeLiteralInternal(const Literal& original,
-                              tensorflow::gtl::ArraySlice<int64> permutation,
-                              Literal* result) {
-  std::vector<int64> new_indices(ShapeUtil::Rank(original.shape()));
-  LiteralUtil::EachCell<T>(
-      original, [&](tensorflow::gtl::ArraySlice<int64> indices, T value) {
-        for (int64 i = 0; i < indices.size(); ++i) {
-          new_indices[i] = indices[permutation[i]];
-        }
-        LiteralUtil::Set<T>(result, new_indices, value);
-      });
-}
-}  // namespace
-
 /* static */ std::unique_ptr<Literal> LiteralUtil::Transpose(
     const Literal& original, tensorflow::gtl::ArraySlice<int64> permutation) {
   CHECK(!ShapeUtil::IsTuple(original.shape()))
-      << "tuple is not supported for transpose";
-  std::vector<int64> dimension_numbers(ShapeUtil::Rank(original.shape()));
-  std::iota(dimension_numbers.begin(), dimension_numbers.end(), 0);
-  CHECK(std::is_permutation(permutation.begin(), permutation.end(),
-                            dimension_numbers.begin()))
-      << "given permutation is not a permutation of dimension numbers";
-  std::vector<int64> new_dimension_sizes;
-  for (const int64 dim : permutation) {
-    new_dimension_sizes.push_back(original.shape().dimensions(dim));
-  }
-  const auto result_shape = ShapeUtil::MakeShape(
-      original.shape().element_type(), new_dimension_sizes);
-  std::unique_ptr<Literal> result = CloneToUnique(original);
-  *result->mutable_shape() = result_shape;
-  const PrimitiveType primitive_type = original.shape().element_type();
-  switch (primitive_type) {
-    case F32:
-      TransposeLiteralInternal<float>(original, permutation, result.get());
-      return result;
-    case F64:
-      TransposeLiteralInternal<double>(original, permutation, result.get());
-      return result;
-    case PRED:
-      TransposeLiteralInternal<bool>(original, permutation, result.get());
-      return result;
-    case S8:
-      TransposeLiteralInternal<int8>(original, permutation, result.get());
-      return result;
-    case U8:
-      TransposeLiteralInternal<uint8>(original, permutation, result.get());
-      return result;
-    case S32:
-      TransposeLiteralInternal<int32>(original, permutation, result.get());
-      return result;
-    case U32:
-      TransposeLiteralInternal<uint32>(original, permutation, result.get());
-      return result;
-    case S64:
-      TransposeLiteralInternal<int64>(original, permutation, result.get());
-      return result;
-    case U64:
-      TransposeLiteralInternal<uint64>(original, permutation, result.get());
-      return result;
-    default:
-      LOG(FATAL) << "not yet implemented: "
-                 << PrimitiveType_Name(primitive_type);
-  }
+      << "Tuple is not supported for transpose";
+  CHECK(IsPermutation(permutation, ShapeUtil::Rank(original.shape())))
+      << "Given permutation is not a permutation of dimension numbers";
+  // To transpose the array, we just permute the dimensions and layout, and
+  // do a straight memory copy of the raw data set.
+  // This is considerably faster than iterating over every array element using
+  // the EachCell<>() and Set<>() APIs.
+  std::vector<int64> inverse_permutation = InversePermutation(permutation);
+  Shape shape =
+      ShapeUtil::PermuteDimensions(inverse_permutation, original.shape());
+  // Replace the layout with one affine to the original shape, such that a
+  // transpose operation can be performed by leaving the flat values
+  // representation intact.
+  // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation.
+  // The shape with affine layout resulting from that operation will be
+  // F32[8,11]{0,1}, since it leaves the original most-minor dimension (the
+  // one of size 8) as the most minor.
+  // Essentially, given MinMaj(Di) the position of the Di dimension within the
+  // minor to major vector, and given T(Di) the index that the original Di
+  // dimension has within the transposed array, a layout is affine if
+  // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major
+  // vector of the affine layout.
+  Layout* layout = shape.mutable_layout();
+  layout->clear_minor_to_major();
+  for (auto index : original.shape().layout().minor_to_major()) {
+    layout->add_minor_to_major(inverse_permutation[index]);
+  }
+  std::unique_ptr<Literal> new_literal = CreateFromShape(shape);
+  DCHECK_GE(ShapeUtil::ByteSizeOf(new_literal->shape()),
+            ShapeUtil::ByteSizeOf(original.shape()));
+  std::memcpy(MutableInternalData(new_literal.get()), InternalData(original),
+              ShapeUtil::ByteSizeOf(original.shape()));
+  return new_literal;
 }
 
 /* static */ std::unique_ptr<Literal> LiteralUtil::Slice(
@@ -793,47 +736,14 @@ void TransposeLiteralInternal(const Literal& original,
     const Literal& literal,
     const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
                              const string& value)>& per_cell) {
-  if (ShapeUtil::Rank(literal.shape()) == 1) {
-    for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) {
-      per_cell({i0}, GetAsString(literal, {i0}));
-    }
+  if (ShapeUtil::HasZeroElements(literal.shape())) {
     return;
   }
-
-  if (ShapeUtil::Rank(literal.shape()) == 2) {
-    for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) {
-      for (int64 i1 = 0; i1 < literal.shape().dimensions(1); ++i1) {
-        per_cell({i0, i1}, GetAsString(literal, {i0, i1}));
-      }
-    }
-    return;
-  }
-
-  if (ShapeUtil::Rank(literal.shape()) == 3) {
-    for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) {
-      for (int64 i1 = 0; i1 < literal.shape().dimensions(1); ++i1) {
-        for (int64 i2 = 0; i2 < literal.shape().dimensions(2); ++i2) {
-          per_cell({i0, i1, i2}, GetAsString(literal, {i0, i1, i2}));
-        }
-      }
-    }
-    return;
-  }
-
-  if (ShapeUtil::Rank(literal.shape()) == 4) {
-    for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) {
-      for (int64 i1 = 0; i1 < literal.shape().dimensions(1); ++i1) {
-        for (int64 i2 = 0; i2 < literal.shape().dimensions(2); ++i2) {
-          for (int64 i3 = 0; i3 < literal.shape().dimensions(3); ++i3) {
-            per_cell({i0, i1, i2, i3}, GetAsString(literal, {i0, i1, i2, i3}));
-          }
-        }
-      }
-    }
-    return;
-  }
-
-  LOG(FATAL) << "unhandled rank: " << ShapeUtil::Rank(literal.shape());
+  std::vector<int64> indices = IndexUtil::LinearIndexToMultidimensionalIndex(
+      literal.shape(), /*linear_index=*/0);
+  do {
+    per_cell(indices, GetAsString(literal, indices));
+  } while (IndexUtil::BumpIndices(literal.shape(), &indices));
 }
 
 namespace {
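The rank-specific loops formerly in EachCellAsString are replaced above by a rank-generic scan: start from the multi-index of linear index 0 and bump indices until the space is exhausted, which also covers rank 0 and ranks above 4. The idiom in isolation, assuming only the IndexUtil calls used above:

    // Visit every element of a non-empty array shape in linear-index order.
    std::vector<int64> indices = IndexUtil::LinearIndexToMultidimensionalIndex(
        shape, /*linear_index=*/0);  // starts at {0, ..., 0}
    do {
      // ... visit the element at `indices` ...
    } while (IndexUtil::BumpIndices(shape, &indices));  // false after wrapping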
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index 3a6d21979e7..8bdf8daff55 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -239,6 +239,11 @@ class LiteralUtil {
   // Clones literal into an owned unique_ptr version.
   static std::unique_ptr<Literal> CloneToUnique(const Literal& literal);
 
+  // Returns the linear index of the given index within the literal's
+  // element_type repeated field.
+  static int64 LinearIndex(const Literal& literal,
+                           tensorflow::gtl::ArraySlice<int64> multi_index);
+
   // Gets or sets an element in the literal at the given index. The index is
   // CHECKed against the dimension sizes.
   template <typename NativeT>
@@ -427,11 +432,6 @@ class LiteralUtil {
         "Cannot map native type to primitive type.");
   }
 
-  // Returns the linear index of the given index within the literal's
-  // element_type repeated field.
-  static int64 LinearIndex(const Literal& literal,
-                           tensorflow::gtl::ArraySlice<int64> multi_index);
-
   // Internal template helper for the Copy() API, matching its arguments one by
   // one.
   //
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc
index dd4d820babe..0f214d7f9ea 100644
--- a/tensorflow/compiler/xla/literal_util_test.cc
+++ b/tensorflow/compiler/xla/literal_util_test.cc
@@ -469,6 +469,26 @@ TEST_F(LiteralUtilTest, ReshapeR4) {
   EXPECT_TRUE(LiteralUtil::Equal(*expected, *reshape));
 }
 
+TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
+  // clang-format off
+  // F32[1x3x2x4]
+  auto original = LiteralUtil::CreateR4WithLayout<float>({{
+    {{10, 11, 12, 13}, {14, 15, 16, 17}},
+    {{18, 19, 20, 21}, {22, 23, 24, 25}},
+    {{26, 27, 28, 29}, {30, 31, 32, 33}},
+  }}, layout_r4_dim0minor_);
+  // F32[3x4x2]
+  auto expected = LiteralUtil::CreateR3WithLayout<float>({
+    {{10, 11}, {12, 13}, {14, 15}, {16, 17}},
+    {{18, 19}, {20, 21}, {22, 23}, {24, 25}},
+    {{26, 27}, {28, 29}, {30, 31}, {32, 33}},
+  }, layout_r3_dim0major_);
+  // clang-format on
+  auto reshape = LiteralUtil::Reshape(*original, {3, 4, 2}).ConsumeValueOrDie();
+
+  EXPECT_TRUE(LiteralUtil::Equal(*expected, *reshape));
+}
+
 TEST_F(LiteralUtilTest, TransposeR0) {
   auto original = LiteralUtil::CreateR0<float>(1.7f);
   auto reshape = LiteralUtil::Transpose(*original, /*permutation=*/{});
@@ -659,15 +679,15 @@ TEST_F(LiteralUtilTest, Copy) {
       primitive_util::NativeToPrimitiveType<uint32>(), dimensions, layout);
   auto blank = LiteralUtil::CreateFromShape(shape);
   auto source = LiteralUtil::CreateFromShape(shape);
-  const int64 sbase[] = {0, 0, 0, 0};
-  const int64 incr[] = {1, 1, 1, 1};
+  const int64 zero_base[] = {0, 0, 0, 0};
+  const int64 step[] = {1, 1, 1, 1};
   uint32 seqnr = 0;
   auto init_proc = [&](const std::vector<int64>& indexes) {
     LiteralUtil::Set(source.get(), indexes, ++seqnr);
     return true;
   };
 
-  ShapeUtil::ForEachIndex(source->shape(), sbase, dimensions, incr,
+  ShapeUtil::ForEachIndex(source->shape(), zero_base, dimensions, step,
                           init_proc);
 
   const int64 src_base[] = {3, 1, 5, 7};
@@ -691,7 +711,7 @@ TEST_F(LiteralUtilTest, Copy) {
         bval == LiteralUtil::Get<uint32>(*source, source_indexes));
     return matched;
   };
-  ShapeUtil::ForEachIndex(source->shape(), sbase, copy_size, incr,
+  ShapeUtil::ForEachIndex(source->shape(), zero_base, copy_size, step,
                           check_proc);
   EXPECT_TRUE(matched);
 }
@@ -710,5 +730,43 @@ TEST_F(LiteralUtilTest, CopyScalars) {
   EXPECT_EQ(LiteralUtil::Get<uint32>(*vect, {4}), 17);
 }
 
+TEST_F(LiteralUtilTest, Populate) {
+  struct PopulateData {
+    std::vector<int64> dimensions;
+    std::vector<int64> layout;
+  } populate_data[] = {
+      {{}, {}},
+      {{16}, {0}},
+      {{4, 16}, {1, 0}},
+      {{21, 12}, {0, 1}},
+      {{6, 11, 17}, {2, 0, 1}},
+      {{6, 11, 5, 17}, {3, 2, 0, 1}},
+  };
+  for (const auto& data : populate_data) {
+    Shape shape = ShapeUtil::MakeShapeWithLayout(
+        primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
+        data.layout);
+    auto literal = LiteralUtil::CreateFromShape(shape);
+    auto generator = [&](tensorflow::gtl::ArraySlice<int64> indexes) -> uint32 {
+      // Offset from the linear index just to avoid R0 literals being
+      // initialized with zero.
+      return LiteralUtil::LinearIndex(*literal, indexes) + 17;
+    };
+    TF_EXPECT_OK(LiteralUtil::Populate<uint32>(literal.get(), generator));
+
+    std::vector<int64> zero_base(data.dimensions.size(), 0);
+    std::vector<int64> step(data.dimensions.size(), 1);
+    bool matched = true;
+    auto check_function = [&](const std::vector<int64>& indexes) {
+      auto value = LiteralUtil::Get<uint32>(*literal, indexes);
+      matched = matched && (value == generator(indexes));
+      return matched;
+    };
+    ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step,
+                            check_function);
+    EXPECT_TRUE(matched);
+  }
+}
+
 }  // namespace
 }  // namespace xla
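The Populate test above also doubles as a usage note for the newly public LiteralUtil::LinearIndex: a generator can derive a distinct per-cell value from the cell's linear position. A minimal sketch outside the test harness, with a made-up shape:

    // Hypothetical shape; any array shape works the same way.
    auto literal =
        LiteralUtil::CreateFromShape(ShapeUtil::MakeShape(U32, {4, 16}));
    TF_CHECK_OK(LiteralUtil::Populate<uint32>(
        literal.get(),
        [&](tensorflow::gtl::ArraySlice<int64> indexes) -> uint32 {
          return LiteralUtil::LinearIndex(*literal, indexes) + 17;
        }));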
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 65b49d99cca..dfb27dd6487 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -406,6 +406,27 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "compile_only_service",
+    srcs = ["compile_only_service.cc"],
+    hdrs = ["compile_only_service.h"],
+    deps = [
+        ":backend",
+        ":compiler",
+        ":computation_layout",
+        ":computation_tracker",
+        ":platform_util",
+        ":service",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:stream_executor_no_cuda",
+    ],
+)
+
 cc_library(
     name = "cpu_plugin",
     deps = [
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
new file mode 100644
index 00000000000..ac1906c88c4
--- /dev/null
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -0,0 +1,131 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/compile_only_service.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/computation_tracker.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+
+/* static */ StatusOr<std::unique_ptr<CompileOnlyService>>
+CompileOnlyService::NewService(perftools::gputools::Platform* platform) {
+  ServiceOptions default_options;
+  default_options.set_platform(platform);
+  return NewService(default_options);
+}
+
+/* static */ StatusOr<std::unique_ptr<CompileOnlyService>>
+CompileOnlyService::NewService(const ServiceOptions& options) {
+  perftools::gputools::Platform* platform = options.platform();
+  if (platform == nullptr) {
+    TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
+  }
+
+  TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform));
+
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
+                      CreateComputeConstantBackend());
+  std::unique_ptr<CompileOnlyService> service(
+      new CompileOnlyService(compiler, std::move(compute_constant_backend)));
+  return std::move(service);
+}
+
+CompileOnlyService::CompileOnlyService(
+    Compiler* compiler, std::unique_ptr<Backend> compute_constant_backend)
+    : Service(/*backend=*/nullptr, std::move(compute_constant_backend)),
+      compiler_(compiler) {
+  runs_in_client_process_ = true;
+}
+
+StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
+CompileOnlyService::CompileAheadOfTime(
+    const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
+    const AotCompilationOptions& options) {
+  std::vector<std::unique_ptr<HloModule>> hlo_modules;
+  std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
+  for (const AotComputationInstance& instance : computations) {
+    TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
+                        computation_tracker_.Resolve(instance.computation));
+    VersionedComputationHandle versioned_handle =
+        user_computation->GetVersionedHandle();
+
+    // Dump computation proto state if flag is set.
+    legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags();
+    const string& directory_path = flags->xla_dump_computations_to;
+    if (!directory_path.empty()) {
+      TF_ASSIGN_OR_RETURN(
+          std::unique_ptr<SessionModule> session_module,
+          computation_tracker_.SnapshotComputation(versioned_handle.handle));
+      string filename = tensorflow::strings::StrCat(
+          "computation_", versioned_handle.handle.handle(), "__",
+          session_module->entry().name(), "__version_",
+          versioned_handle.version);
+      TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename,
+                                                     *session_module));
+    }
+
+    TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
+                        computation_tracker_.BuildHloModule(
+                            versioned_handle,
+                            /*include_unreachable_instructions=*/true));
+    hlo_modules.push_back(std::move(hlo_module));
+
+    TF_ASSIGN_OR_RETURN(
+        std::shared_ptr<const ProgramShape> program_shape,
+        user_computation->ComputeProgramShape(versioned_handle.version));
+
+    module_configs.push_back(MakeUnique<HloModuleConfig>(*program_shape));
+    HloModuleConfig* module_config = module_configs.back().get();
+    auto* computation_layout =
+        module_config->mutable_entry_computation_layout();
+    if (flags->xla_hlo_profile) {
+      module_config->enable_hlo_profiling(true);
+    }
+    for (int i = 0; i < instance.argument_layouts.size(); ++i) {
+      const Shape& argument_layout = *instance.argument_layouts[i];
+      if (ShapeUtil::IsTuple(argument_layout)) {
+        return Unimplemented("tuple arguments not supported yet");
+      }
+      TF_RETURN_IF_ERROR(
+          computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
+              argument_layout));
+    }
+    TF_RETURN_IF_ERROR(
+        computation_layout->mutable_result_layout()->CopyLayoutFromShape(
+            *instance.result_layout));
+  }
+
+  return compiler_->CompileAheadOfTime(std::move(hlo_modules),
+                                       std::move(module_configs),
+                                       MakeHloDumper(), options);
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h
new file mode 100644
index 00000000000..6dae49e3e1a
--- /dev/null
+++ b/tensorflow/compiler/xla/service/compile_only_service.h
@@ -0,0 +1,125 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILE_ONLY_SERVICE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILE_ONLY_SERVICE_H_
+
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/service.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+
+// An XLA Service specialization for ahead-of-time compilation. This only
+// instantiates a Compiler object for the relevant platform; it does not
+// instantiate or require an execution backend.
+class CompileOnlyService : public Service {
+ public:
+  // Factory for creating a CompileOnlyService.
+  // The parameter platform is the platform that the service should target.
+  // If platform is null then the default platform is used.
+  static StatusOr<std::unique_ptr<CompileOnlyService>> NewService(
+      perftools::gputools::Platform* platform);
+  static StatusOr<std::unique_ptr<CompileOnlyService>> NewService(
+      const ServiceOptions& options);
+
+  // A description of a computation to compile using CompileAheadOfTime.
+  struct AotComputationInstance {
+    ComputationHandle computation;
+    std::vector<const Shape*> argument_layouts;
+    const Shape* result_layout = nullptr;
+  };
+
+  // Compiles a list of computations for ahead-of-time execution. This is
+  // intended for use in static compilation. See
+  // |CompileOnlyClient::CompileAheadOfTime| for additional details.
+  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
+  CompileAheadOfTime(
+      const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
+      const AotCompilationOptions& options);
+
+  // Override Service methods that require or imply the existence of an
+  // execute backend. Note that this does not include TransferToClient and
+  // TransferToClientInProcess, as computing constants produces global data
+  // that we may wish to transfer.
+  tensorflow::Status Execute(const ExecuteRequest* arg,
+                             ExecuteResponse* result) override {
+    return Unimplemented("CompileOnlyService does not support execution.");
+  }
+  tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg,
+                                     ExecuteParallelResponse* result) override {
+    return Unimplemented("CompileOnlyService does not support execution.");
+  }
+  tensorflow::Status GetDeviceHandles(
+      const GetDeviceHandlesRequest* arg,
+      GetDeviceHandlesResponse* result) override {
+    return Unimplemented("CompileOnlyService does not support devices.");
+  }
+  tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg,
+                                  ExecuteAsyncResponse* result) override {
+    return Unimplemented("CompileOnlyService does not support execution.");
+  }
+  tensorflow::Status WaitForExecution(
+      const WaitForExecutionRequest* arg,
+      WaitForExecutionResponse* result) override {
+    return Unimplemented("CompileOnlyService does not support execution.");
+  }
+  tensorflow::Status TransferToServer(
+      const TransferToServerRequest* arg,
+      TransferToServerResponse* result) override {
+    return Unimplemented(
+        "CompileOnlyService does not support device data transfers.");
+  }
+  tensorflow::Status TransferToInfeed(
+      const TransferToInfeedRequest* arg,
+      TransferToInfeedResponse* result) override {
+    return Unimplemented(
+        "CompileOnlyService does not support device data transfers.");
+  }
+  tensorflow::Status TransferFromOutfeed(
+      const TransferFromOutfeedRequest* arg,
+      TransferFromOutfeedResponse* result) override {
+    return Unimplemented(
+        "CompileOnlyService does not support device data transfers.");
+  }
+  tensorflow::Status TransferToServerInProcess(
+      const TransferToServerInProcessRequest* arg,
+      TransferToServerInProcessResponse* result) override {
+    return Unimplemented(
+        "CompileOnlyService does not support device data transfers.");
+  }
+  tensorflow::Status ResetDevice(const ResetDeviceRequest* arg,
+                                 ResetDeviceResponse* result) override {
+    return Unimplemented("CompileOnlyService does not support devices.");
+  }
+
+ private:
+  explicit CompileOnlyService(
+      Compiler* compiler, std::unique_ptr<Backend> compute_constant_backend);
+  CompileOnlyService(const CompileOnlyService&) = delete;
+  void operator=(const CompileOnlyService&) = delete;
+
+  // The compiler for the target platform. This is included in place of
+  // the Service::execute_backend_'s compiler, since execute_backend_ is a
+  // nullptr in CompileOnlyService.
+  Compiler* compiler_;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILE_ONLY_SERVICE_H_
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index dc3a289a71b..af5cf8ca4b6 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -188,41 +188,52 @@ tensorflow::Status PrepareHloModuleForIrEmitting(
   return pipeline.Run(hlo_module).status();
 }
 
-// Invokes the ptxas tool on the given PTX string, and dumps its output.
-void DumpPtxasInfo(const string& ptx) {
+// Invokes the ptxas tool on the given PTX string, and stores the resulting
+// SASS in *cubin. If the VLOG level is 2 or greater, runs ptxas with -v and
+// dumps the resulting stderr (which contains register allocation info, etc.)
+// to VLOG(2). If the ptxas binary is not found, *cubin is left empty and an
+// error is returned.
+Status CompilePTX(const string& ptx, int cc_major, int cc_minor,
+                  string* cubin) {
+  *cubin = "";
+
   const string ptxas_path =
       tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin/ptxas");
-  // Do not log PTX stats if ptxas is not found at the given path.
-  if (!tensorflow::Env::Default()->FileExists(ptxas_path).ok()) {
-    LOG(WARNING)
-        << "Failed to dump PTX stats because ptxas is not found at path \""
-        << ptxas_path << "\".";
-    return;
-  }
+  LOG(INFO) << "Invoking ptxas at path \"" << ptxas_path << "\".";
+  TF_RETURN_IF_ERROR(tensorflow::Env::Default()->FileExists(ptxas_path));
 
   // Write `ptx` into a temporary file.
   char tempdir_template[] = "/tmp/ptxXXXXXX";
   char* tempdir_name = mkdtemp(tempdir_template);
   CHECK_NOTNULL(tempdir_name);
   string ptx_path = tensorflow::io::JoinPath(tempdir_name, "ptx");
+
   TF_CHECK_OK(
       tensorflow::WriteStringToFile(tensorflow::Env::Default(), ptx_path, ptx));
   LOG(INFO) << "ptx file written to: " << ptx_path;
 
   // Invoke ptxas and collect its output.
-  tensorflow::SubProcess ptxas_info_dumper;
-  ptxas_info_dumper.SetProgram(ptxas_path, {ptxas_path, ptx_path, "-o",
-                                            "/dev/null", "-v", "-arch=sm_35"});
-  ptxas_info_dumper.SetChannelAction(tensorflow::CHAN_STDERR,
-                                     tensorflow::ACTION_PIPE);
-  CHECK(ptxas_info_dumper.Start());
-  string stderr_output;
-  int exit_status = ptxas_info_dumper.Communicate(
-      /*stdin_input=*/nullptr, /*stdout_output=*/nullptr, &stderr_output);
-  XLA_LOG_LINES(tensorflow::INFO, stderr_output);
-  if (exit_status != 0) {
-    LOG(FATAL) << "Invalid PTX. See the error message above for reasons.";
+  tensorflow::SubProcess ptxas_info;
+  string arch = tensorflow::strings::StrCat("sm_", cc_major, cc_minor);
+  string cubin_path = tensorflow::io::JoinPath(tempdir_name, "cubin");
+
+  if (VLOG_IS_ON(2)) {
+    ptxas_info.SetProgram(ptxas_path, {ptxas_path, "-v", "-o", cubin_path,
+                                       "-arch", arch, ptx_path});
+  } else {
+    ptxas_info.SetProgram(
+        ptxas_path, {ptxas_path, "-o", cubin_path, "-arch", arch, ptx_path});
   }
+  ptxas_info.SetChannelAction(tensorflow::CHAN_STDERR,
+                              tensorflow::ACTION_PIPE);
+  CHECK(ptxas_info.Start());
+  string stderr_output;
+  int ptxas_exit_status = ptxas_info.Communicate(
+      /*stdin_input=*/nullptr, /*stdout_output=*/nullptr, &stderr_output);
+
+  TF_RET_CHECK(ptxas_exit_status == 0);
+  return tensorflow::ReadFileToString(tensorflow::Env::Default(), cubin_path,
+                                      cubin);
 }
 
 }  // namespace
@@ -298,10 +309,14 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
   // Reserve space for the PTX to be generated for this module.
   string* ptx;
+  string* cubin;
   {
     tensorflow::mutex_lock lock(mutex_);
     generated_ptxes_.emplace_back(MakeUnique<string>());
     ptx = generated_ptxes_.back().get();
+
+    generated_cubins_.emplace_back(MakeUnique<string>());
+    cubin = generated_cubins_.back().get();
   }
   int cc_major, cc_minor;
   if (!stream_exec->GetDeviceDescription().cuda_compute_capability(&cc_major,
@@ -318,9 +333,6 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
   XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(llvm_module));
   VLOG(2) << "PTX:";
   XLA_VLOG_LINES(2, *ptx);
-  if (VLOG_IS_ON(2)) {
-    DumpPtxasInfo(*ptx);
-  }
 
   auto thunk_schedule = MakeUnique<ThunkSchedule>(
       ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment),
@@ -328,9 +340,13 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
   VLOG(2) << "Printing the thunk schedule...";
   XLA_VLOG_LINES(2, thunk_schedule->ToString());
 
+  TF_RET_CHECK(CompilePTX(*ptx, cc_major, cc_minor, cubin).ok());
+
   auto* gpu_executable =
-      new GpuExecutable(*ptx, std::move(thunk_schedule), std::move(hlo_module),
+      new GpuExecutable(*cubin, *ptx, {cc_major, cc_minor},
+                        std::move(thunk_schedule), std::move(hlo_module),
                         std::move(module_config), std::move(buffer_assignment));
+
   if (flags->xla_gpu_embed_ir) {
     DCHECK_NE("", ir_module_string_before_opt);
     gpu_executable->set_ir_module_string(ir_module_string_before_opt);
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
index 22f492b4229..99c7ba51999 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
@@ -71,6 +71,7 @@ class GpuCompiler : public Compiler {
   // StreamExecutor (b/24776264).
   tensorflow::mutex mutex_;
   std::vector<std::unique_ptr<string>> generated_ptxes_ GUARDED_BY(mutex_);
+  std::vector<std::unique_ptr<string>> generated_cubins_ GUARDED_BY(mutex_);
 
   // The size in bytes of a pointer. Used for computing ShapeSizeBytes.
   int64 pointer_size_;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index 32f0368b4bc..b4b788162f8 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -107,13 +107,17 @@ class HloExecutionProfiler {
 
 // Implementation note: HLO profiling is always enabled for GPU executables,
 // since we can use timers around thunks.
-GpuExecutable::GpuExecutable(tensorflow::StringPiece ptx,
+GpuExecutable::GpuExecutable(tensorflow::StringPiece cubin,
+                             tensorflow::StringPiece ptx,
+                             std::pair<int, int> compute_capability,
                              std::unique_ptr<ThunkSchedule> thunk_schedule,
                              std::unique_ptr<HloModule> hlo_module,
                              std::unique_ptr<HloModuleConfig> module_config,
                              std::unique_ptr<BufferAssignment> assignment)
     : Executable(std::move(hlo_module), std::move(module_config)),
+      cubin_(cubin),
       ptx_(ptx),
+      compute_capability_(compute_capability),
       thunk_schedule_(std::move(thunk_schedule)),
       assignment_(std::move(assignment)) {}
 
@@ -186,6 +190,13 @@ StatusOr<se::DeviceMemoryBase> GpuExecutable::ExecuteOnStream(
   // false.
   TF_RET_CHECK(!module_config().has_hybrid_result());
 
+  // Ensure the compute capability of the cubin and the stream match.
+  std::pair<int, int> stream_compute_compatibility;
+  stream->parent()->GetDeviceDescription().cuda_compute_capability(
+      &stream_compute_compatibility.first,
+      &stream_compute_compatibility.second);
+  TF_RET_CHECK(stream_compute_compatibility == compute_capability_);
+
   BufferAllocations::Builder buffer_allocations_builder;
   for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size();
        ++i) {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
index e308de79ba5..09a92c4e4c6 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -40,15 +40,17 @@ limitations under the License.
 namespace xla {
 namespace gpu {
-
 // GPU-targeting implementation of the XLA Executable interface.
 //
 // Launches the given CUDA kernel via the StreamExecutor.
-//
-// This is an immutable data type after initialization, and thus thread safe.
+
+// GpuExecutable should eventually be updated to associate a compute
+// capability with the PTX and store multiple cubins, each with its own
+// associated compute capability, rather than including the compute
+// capability as a property of GpuExecutable.
 class GpuExecutable : public Executable {
  public:
-  GpuExecutable(tensorflow::StringPiece ptx,
+  GpuExecutable(tensorflow::StringPiece cubin, tensorflow::StringPiece ptx,
+                std::pair<int, int> compute_capability,
                 std::unique_ptr<ThunkSchedule> thunk_schedule,
                 std::unique_ptr<HloModule> hlo_module,
                 std::unique_ptr<HloModuleConfig> module_config,
                 std::unique_ptr<BufferAssignment> assignment);
@@ -62,7 +64,8 @@ class GpuExecutable : public Executable {
     ir_module_string_ = ir_module_string;
   }
 
-  // Returns the compiled PTX for the computation.
+  // Returns the compiled CUDA binary for the computation.
+  tensorflow::StringPiece cubin() const { return cubin_; }
   tensorflow::StringPiece ptx() const { return ptx_; }
 
   StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteOnStream(
@@ -104,8 +107,10 @@ class GpuExecutable : public Executable {
   // This string should be modified only before ExecuteOnStream.
   string ir_module_string_;
 
-  // The reference to the compiled PTX for the computation.
-  const tensorflow::StringPiece ptx_;
+  // The reference to the compiled PTX & CUDA binary for the computation.
+  tensorflow::StringPiece cubin_;
+  tensorflow::StringPiece ptx_;
+  std::pair<int, int> compute_capability_;
 
   // The thunks to be invoked by this GpuExecutable. They are generated by the
   // IrEmitter.
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
index 69399e36c4c..48ccc63f3d5 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
@@ -41,13 +41,10 @@ tensorflow::Status KernelThunk::Initialize(const GpuExecutable& executable) {
     // Already initialized by another thread.
     return tensorflow::Status::OK();
   }
-  loader_spec_.reset(new se::MultiKernelLoaderSpec(io_buffers_.size() + 1));
-  tensorflow::StringPiece ptx = executable.ptx();
-  // Convert tensorflow::StringPiece to se::port::StringPiece because
-  // StreamExecutor uses the latter.
-  loader_spec_->AddCudaPtxInMemory(
-      se::port::StringPiece(ptx.data(), ptx.size()), kernel_name_);
+
+  tensorflow::StringPiece cubin = executable.cubin();
+  loader_spec_->AddCudaCubinInMemory(cubin.data(), kernel_name_);
 
   return tensorflow::Status::OK();
 }
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
index 21d93a1f27f..a56225da156 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
@@ -195,7 +195,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
   HloInstruction* root = computation->root_instruction();
   EXPECT_THAT(root, op::Constant());
-  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
+  EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), shape));
 
   using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
   bool matched = true;
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 17d7b97b21b..6947c5d2e1d 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -128,70 +128,6 @@ StatusOr<GlobalDataHandle> LocalService::AllocateBufferOnDevice(
                                       allocation_size));
 }
 
-StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-LocalService::CompileAheadOfTime(
-    const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
-        computations,
-    const AotCompilationOptions& options) {
-  std::vector<std::unique_ptr<HloModule>> hlo_modules;
-  std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
-  for (const AheadOfTimeComputationInstance& instance : computations) {
-    TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
-                        computation_tracker_.Resolve(instance.computation));
-    VersionedComputationHandle versioned_handle =
-        user_computation->GetVersionedHandle();
-
-    // Dump computation proto state if flag is set.
-    legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags();
-    const string& directory_path = flags->xla_dump_computations_to;
-    if (!directory_path.empty()) {
-      TF_ASSIGN_OR_RETURN(
-          std::unique_ptr<SessionModule> session_module,
-          computation_tracker_.SnapshotComputation(versioned_handle.handle));
-      string filename = tensorflow::strings::StrCat(
-          "computation_", versioned_handle.handle.handle(), "__",
-          session_module->entry().name(), "__version_",
-          versioned_handle.version);
-      TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename,
-                                                     *session_module));
-    }
-
-    TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
-                        computation_tracker_.BuildHloModule(
-                            versioned_handle,
-                            /*include_unreachable_instructions=*/true));
-    hlo_modules.push_back(std::move(hlo_module));
-
-    TF_ASSIGN_OR_RETURN(
-        std::shared_ptr<const ProgramShape> program_shape,
-        user_computation->ComputeProgramShape(versioned_handle.version));
-
-    module_configs.push_back(MakeUnique<HloModuleConfig>(*program_shape));
-    HloModuleConfig* module_config = module_configs.back().get();
-    auto* computation_layout =
-        module_config->mutable_entry_computation_layout();
-    if (flags->xla_hlo_profile) {
-      module_config->enable_hlo_profiling(true);
-    }
-    for (int i = 0; i < instance.argument_layouts.size(); ++i) {
-      const Shape& argument_layout = *instance.argument_layouts[i];
-      if (ShapeUtil::IsTuple(argument_layout)) {
-        return Unimplemented("tuple arguments not supported yet");
-      }
-      TF_RETURN_IF_ERROR(
-          computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
-              argument_layout));
-    }
-    TF_RETURN_IF_ERROR(
-        computation_layout->mutable_result_layout()->CopyLayoutFromShape(
-            *instance.result_layout));
-  }
-
-  return execute_backend_->compiler()->CompileAheadOfTime(
-      std::move(hlo_modules), std::move(module_configs), MakeHloDumper(),
-      options);
-}
-
 StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
     const ComputationHandle& computation,
     const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h
index df27f0a7a60..a1a2ef98e95 100644
--- a/tensorflow/compiler/xla/service/local_service.h
+++ b/tensorflow/compiler/xla/service/local_service.h
@@ -59,22 +59,6 @@ class LocalService : public Service {
       const Shape& shape, int device_ordinal,
       bool allocate_space_for_deep_copy);
 
-  // A description of a computation to compile using CompileAheadOfTime.
-  struct AheadOfTimeComputationInstance {
-    ComputationHandle computation;
-    std::vector<const Shape*> argument_layouts;
-    const Shape* result_layout = nullptr;
-  };
-
-  // Compiles a list of computations for ahead-of-time execution. This is
-  // intended for use in static compilation. See
-  // |LocalClient::CompileAheadOfTime| for additional details.
-  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-  CompileAheadOfTime(
-      const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
-          computations,
-      const AotCompilationOptions& Options);
-
   // Builds an Executable with the given argument layouts and options. If
   // result_layout is non-null, then the executable is compiled to produce a
   // result of the given layout.
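For callers migrating off the LocalClient/LocalService ahead-of-time entry points deleted above, the replacement added earlier in this patch is field-for-field identical; a hedged before/after sketch (platform, instance contents, and options are placeholders):

    // Before (removed by this patch):
    xla::LocalClient* old_client =
        xla::ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
    xla::LocalClient::AheadOfTimeComputationInstance old_instance;
    auto old_results = old_client->CompileAheadOfTime({old_instance}, options);

    // After (added by this patch): same fields, new names and client class.
    xla::CompileOnlyClient* new_client =
        xla::ClientLibrary::GetOrCreateCompileOnlyClient(platform).ValueOrDie();
    xla::CompileOnlyClient::AotComputationInstance new_instance;
    auto new_results = new_client->CompileAheadOfTime({new_instance}, options);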
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 451bb8c7ead..c001e705deb 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -180,20 +180,24 @@ Service::Service(std::unique_ptr<Backend> execute_backend,
                  std::unique_ptr<Backend> compute_constant_backend)
     : execute_backend_(std::move(execute_backend)),
       compute_constant_backend_(std::move(compute_constant_backend)) {
-  LOG(INFO) << Printf(
-      "XLA service %p executing computations on platform %s. Devices:", this,
-      execute_backend_->platform()->Name().c_str());
-  for (int i = 0; i < execute_backend_->device_count(); ++i) {
-    if (execute_backend_->device_ordinal_supported(i)) {
-      se::StreamExecutor* executor =
-          execute_backend_->stream_executor(i).ValueOrDie();
-      const auto& description = executor->GetDeviceDescription();
-      LOG(INFO) << Printf("  StreamExecutor device (%d): %s, %s", i,
-                          description.name().c_str(),
-                          description.platform_version().c_str());
-    } else {
-      LOG(INFO) << Printf("  StreamExecutor device (%d) not supported", i);
+  if (execute_backend_) {
+    LOG(INFO) << Printf(
+        "XLA service %p executing computations on platform %s. Devices:", this,
+        execute_backend_->platform()->Name().c_str());
+    for (int i = 0; i < execute_backend_->device_count(); ++i) {
+      if (execute_backend_->device_ordinal_supported(i)) {
+        se::StreamExecutor* executor =
+            execute_backend_->stream_executor(i).ValueOrDie();
+        const auto& description = executor->GetDeviceDescription();
+        LOG(INFO) << Printf("  StreamExecutor device (%d): %s, %s", i,
+                            description.name().c_str(),
+                            description.platform_version().c_str());
+      } else {
+        LOG(INFO) << Printf("  StreamExecutor device (%d) not supported", i);
+      }
     }
+  } else {
+    VLOG(1) << "XLA compile-only service constructed";
   }
 }
 
@@ -286,7 +290,7 @@ StatusOr<std::vector<const Allocation*>> Service::ResolveAndValidateArguments(
 
 StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
     const ProgramShape& program_shape,
    tensorflow::gtl::ArraySlice<const Allocation*> arguments,
-    const ExecutionOptions& execution_options) {
+    const ExecutionOptions& execution_options, Backend* backend) {
   auto module_config = MakeUnique<HloModuleConfig>(program_shape);
   auto* computation_layout =
       module_config->mutable_entry_computation_layout();
@@ -326,7 +330,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
     module_config->enable_hlo_profiling(true);
   }
 
-  module_config->set_replica_count(execute_backend_->Replicas().size());
+  module_config->set_replica_count(backend->Replicas().size());
   module_config->set_fast_math_disabled(execution_options.disable_fast_math());
   module_config->set_seed(execution_options.seed());
 
@@ -474,7 +478,7 @@ StatusOr<std::shared_ptr<Executable>> Service::BuildAndCacheExecutable(
         std::unique_ptr<Executable> executable_unique_ptr,
         BuildExecutable(versioned_handle, std::move(module_config),
                         /*executable_for_compute_constant=*/false, arguments,
-                        execute_backend_.get(), executor));
+                        backend, executor));
 
     if (profile != nullptr) {
       uint64 end_micros = tensorflow::Env::Default()->NowMicros();
@@ -575,15 +579,14 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
   perftools::gputools::DeviceMemoryBase result;
   if (backend->Replicas().size() == 1) {
     TF_ASSIGN_OR_RETURN(
-        result,
-        ExecuteOnStreamWrapper<StatusOr<se::DeviceMemoryBase>>(
-            executable, &run_options[0], profile, execute_backend_.get(),
-            [&arguments](Executable* executable,
-                         const ServiceExecutableRunOptions* run_options,
-                         HloExecutionProfile* hlo_execution_profile) {
-              return executable->ExecuteOnStream(run_options, arguments,
-                                                 hlo_execution_profile);
-            }));
+        result, ExecuteOnStreamWrapper<StatusOr<se::DeviceMemoryBase>>(
+                    executable, &run_options[0], profile, backend,
+                    [&arguments](Executable* executable,
+                                 const ServiceExecutableRunOptions* run_options,
+                                 HloExecutionProfile* hlo_execution_profile) {
+                      return executable->ExecuteOnStream(run_options, arguments,
+                                                         hlo_execution_profile);
+                    }));
   } else {
     std::vector<
         tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>>
@@ -666,7 +669,8 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
     // the program and the argument allocations.
     TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
                         CreateModuleConfig(*program_shape, arg_allocations,
-                                           request.execution_options()));
+                                           request.execution_options(),
+                                           execute_backend_.get()));
 
     VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: "
             << module_config->entry_computation_layout().ToString();
@@ -751,9 +755,10 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg,
       ResolveAndValidateArguments(arg->arguments(), execute_backend_.get(),
                                   execute_backend_->default_device_ordinal()));
 
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
-                      CreateModuleConfig(*program_shape, arg_allocations,
-                                         arg->execution_options()));
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<HloModuleConfig> module_config,
+      CreateModuleConfig(*program_shape, arg_allocations,
+                         arg->execution_options(), execute_backend_.get()));
 
   VLOG(3) << "Execute created HloModuleConfig computation layout: "
           << module_config->entry_computation_layout().ToString();
@@ -818,9 +823,10 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
       ResolveAndValidateArguments(arg->arguments(), execute_backend_.get(),
                                   execute_backend_->default_device_ordinal()));
 
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
-                      CreateModuleConfig(*program_shape, arg_allocations,
-                                         arg->execution_options()));
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<HloModuleConfig> module_config,
+      CreateModuleConfig(*program_shape, arg_allocations,
+                         arg->execution_options(), execute_backend_.get()));
 
   VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: "
           << module_config->entry_computation_layout().ToString();
@@ -1141,7 +1147,8 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
   }
 
   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
-                      CreateModuleConfig(program_shape, {}, execution_options));
+                      CreateModuleConfig(program_shape, {}, execution_options,
+                                         compute_constant_backend_.get()));
 
   TF_ASSIGN_OR_RETURN(
       std::shared_ptr<Executable> executable,
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 9600f6989a4..0e0e7c4e21b 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -265,11 +265,11 @@ class Service : public ServiceInterface {
      tensorflow::gtl::ArraySlice<const Allocation*> arguments,
       const Backend* backend, int device_ordinal);
 
-  // Create a Hlo module config foe the given program shape and arguments.
+  // Create a Hlo module config for the given program shape and arguments.
   StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
       const ProgramShape& program_shape,
       tensorflow::gtl::ArraySlice<const Allocation*> arguments,
-      const ExecutionOptions& execution_options);
+      const ExecutionOptions& execution_options, Backend* backend);
 
   // Builds an Executable for the given parameters. If
   // executable_for_compute_constant is true, then the executable is intended to
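The Backend* parameter threaded through CreateModuleConfig above exists because a compile-only service constructs no execute backend, so the old unconditional read of execute_backend_->Replicas() would dereference null. Each caller now names the backend it actually compiles for; condensed from the hunks above:

  // Execute/ExecuteParallel/ExecuteAsync compile for the execute backend...
  CreateModuleConfig(*program_shape, arg_allocations,
                     arg->execution_options(), execute_backend_.get());
  // ...while ComputeConstant compiles for the compute-constant backend.
  CreateModuleConfig(program_shape, {}, execution_options,
                     compute_constant_backend_.get());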
If
  // executable_for_compute_constant is true, then the executable is intended to
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index b558e31ee93..ceb29aaea5b 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -728,9 +728,17 @@ Status ForEachMutableSubshapeHelper(
     new_shape.add_dimensions(dim);
   }
   if (shape.has_layout()) {
-    new_shape.mutable_layout()->clear_minor_to_major();
+    Layout* new_layout = new_shape.mutable_layout();
+    new_layout->clear_minor_to_major();
     for (auto index : Permute(permutation, shape.layout().minor_to_major())) {
-      new_shape.mutable_layout()->add_minor_to_major(index);
+      new_layout->add_minor_to_major(index);
+    }
+    if (shape.layout().padded_dimensions_size() > 0) {
+      new_layout->clear_padded_dimensions();
+      for (auto dim :
+           Permute(permutation, shape.layout().padded_dimensions())) {
+        new_layout->add_padded_dimensions(dim);
+      }
     }
   }
   return new_shape;
@@ -1057,7 +1065,9 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
   DCHECK_EQ(count.size(), base.size());
   const Layout& layout = shape.layout();
   int64 rank = layout.minor_to_major_size();
-  int64 n = 0;
+  // Allows handling R0 arrays, such that the visitor function will be called
+  // once with the proper empty indexes.
+  int64 n = -1;
   std::vector<int64> indexes(base.begin(), base.end());
   while (n < rank && visitor_function(indexes)) {
     // Increments dimensions in minor to major order.
diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
index 7ea83a9e956..52816dc72cc 100644
--- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
+++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
@@ -42,7 +42,7 @@ xla::Computation Doubler(xla::Client* client) {
 int main(int argc, char** argv) {
   tensorflow::port::InitMain(argv[0], &argc, &argv);
 
-  auto client = xla::ClientLibrary::LocalClientOrDie();
+  auto client = xla::ClientLibrary::GetOrCreateCompileOnlyClient().ValueOrDie();
 
   xla::ComputationBuilder builder(client, "aot_test_helper");
   auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape();
@@ -74,7 +74,7 @@ int main(int argc, char** argv) {
   llvm::Triple triple(xla::llvm_ir::AsStringRef(triple_string));
 
   xla::Computation computation = builder.Build().ConsumeValueOrDie();
-  xla::LocalClient::AheadOfTimeComputationInstance instance{
+  xla::CompileOnlyClient::AotComputationInstance instance{
       &computation, /*argument_layouts=*/{&opaque_shape}, &r0f32};
 
   xla::cpu::CpuAotCompilationOptions options(
diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc
index a711b5035d8..0f6bba450ec 100644
--- a/tensorflow/compiler/xla/util.cc
+++ b/tensorflow/compiler/xla/util.cc
@@ -153,16 +153,26 @@ string Reindent(tensorflow::StringPiece original,
                 });
 }
 
+bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank) {
+  if (rank != permutation.size()) {
+    return false;
+  }
+  std::vector<int64> output(permutation.size(), -1);
+  for (auto index : permutation) {
+    CHECK_GE(index, 0);
+    CHECK_LT(index, rank);
+    output[index] = 0;
+  }
+  return std::find(output.begin(), output.end(), -1) == output.end();
+}
+
 std::vector<int64> InversePermutation(
     tensorflow::gtl::ArraySlice<int64> input_permutation) {
+  DCHECK(IsPermutation(input_permutation, input_permutation.size()));
   std::vector<int64> output_permutation(input_permutation.size(), -1);
   for (size_t i = 0; i < input_permutation.size(); ++i) {
     output_permutation[input_permutation[i]] = i;
  }
-  DCHECK_EQ(
-      0, std::count(output_permutation.begin(), output_permutation.end(), -1));
-  DCHECK(std::is_permutation(input_permutation.begin(), input_permutation.end(),
-                             output_permutation.begin()));
   return output_permutation;
 }
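The new IsPermutation makes InversePermutation's precondition an explicit, reusable check instead of the two after-the-fact DCHECKs it replaces, and it CHECK-fails on out-of-range entries before they could index out of bounds. A small usage sketch (illustration only):

  std::vector<xla::int64> perm = {2, 0, 1};
  CHECK(xla::IsPermutation(perm, /*rank=*/3));

  // InversePermutation satisfies inverse[perm[i]] == i for every i.
  std::vector<xla::int64> inverse = xla::InversePermutation(perm);  // {1, 2, 0}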
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h
index 236728f417b..15a6ef404ea 100644
--- a/tensorflow/compiler/xla/util.h
+++ b/tensorflow/compiler/xla/util.h
@@ -177,6 +177,9 @@ Status Unavailable(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
 string Reindent(tensorflow::StringPiece original,
                 tensorflow::StringPiece indentation);
 
+// Checks whether permutation is a permutation of the [0, rank) integer range.
+bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank);
+
 // Applies `permutation` on `input` and returns the permuted array.
 // For each i, output[permutation[i]] = input[i].
 //
@@ -187,12 +190,11 @@ template