STT-tensorflow/tensorflow/compiler/xla/service/triangular_solve_expander.h
Peter Hawkins 9bb620cbf9 [XLA] Add a TriangularSolve HLO.
Previously, TriangularSolve() existed in the XLA client library but was immediately lowered into other HLO ops, which prevented us from using well-tuned existing BLAS implementations. This change adds a first-class HLO for TriangularSolve.

The API of TriangularSolve is chosen to closely match the BLAS TRSM API.
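For readers unfamiliar with TRSM: it solves op(A) * X = alpha * B (left side) or X * op(A) = alpha * B (right side) for X, where A is triangular, and overwrites B with X. The sketch below covers only one case (left side, lower triangular, no transpose) with row-major `std::vector<double>` storage; the function name and layout are assumptions for illustration, not the XLA or BLAS interface.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative sketch of one TRSM case: side = left, uplo = lower,
// trans = none. Solves A * X = alpha * B for X, where A is an m x m lower
// triangular matrix and B is m x n, both row-major. As in TRSM, the result
// overwrites B. If unit_diagonal is true, A's diagonal is treated as all ones
// and never read.
void TrsmLeftLowerNoTrans(std::size_t m, std::size_t n, double alpha,
                          const std::vector<double>& a, bool unit_diagonal,
                          std::vector<double>& b) {
  assert(a.size() == m * m);
  assert(b.size() == m * n);
  for (std::size_t j = 0; j < n; ++j) {      // each right-hand-side column
    for (std::size_t i = 0; i < m; ++i) {    // forward substitution, top down
      double sum = alpha * b[i * n + j];
      for (std::size_t k = 0; k < i; ++k) {
        sum -= a[i * m + k] * b[k * n + j];  // b[k][j] already holds x[k][j]
      }
      b[i * n + j] = unit_diagonal ? sum : sum / a[i * m + i];
    }
  }
}
```

For A = [[2, 0], [1, 1]] and B = [[4], [5]] with alpha = 1, this yields X = [[2], [3]]. The other TRSM cases (upper triangular, transposed A, right side) differ only in the substitution order and indexing.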

On the CPU and interpreter backends, the TriangularSolve HLO is immediately expanded to a Call operator that runs the same computation the existing client library would have built. With some cunning, we are able to use XlaBuilder inside a lowering pass, allowing us to keep using the much simpler XlaBuilder API to express the triangular solve computation.

Adds a generic OpExpander pass superclass, and refactors GatherExpander to use it. OpExpander is used as the superclass of the new TriangularSolveExpander.
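The OpExpander superclass follows a template-method shape: the base class owns the loop over instructions, and each subclass supplies only a match predicate and an expansion. The sketch below shows that shape with hypothetical toy types (`ToyInstruction`, `ToyOpExpanderPass`, `ToyTriangularSolveExpander`); it is not XLA's `OpExpanderPass`, which works on `HloInstruction`s and returns `StatusOr` values.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy IR: an instruction is just an opcode name.
struct ToyInstruction {
  std::string opcode;
};

// Template-method base: owns the traversal and replacement loop; subclasses
// decide which instructions to match and how to expand each one.
class ToyOpExpanderPass {
 public:
  virtual ~ToyOpExpanderPass() = default;

  // Rewrites every matching instruction in place. Returns true if anything
  // changed.
  bool Run(std::vector<ToyInstruction>& instructions) {
    bool changed = false;
    for (ToyInstruction& instruction : instructions) {
      if (InstructionMatchesPattern(instruction)) {
        instruction = ExpandInstruction(instruction);
        changed = true;
      }
    }
    return changed;
  }

 protected:
  virtual bool InstructionMatchesPattern(const ToyInstruction& instruction) = 0;
  virtual ToyInstruction ExpandInstruction(
      const ToyInstruction& instruction) = 0;
};

// Toy subclass in the spirit of TriangularSolveExpander: replaces each
// "triangular-solve" with a "call" to an expanded computation.
class ToyTriangularSolveExpander : public ToyOpExpanderPass {
 protected:
  bool InstructionMatchesPattern(const ToyInstruction& instruction) override {
    return instruction.opcode == "triangular-solve";
  }
  ToyInstruction ExpandInstruction(const ToyInstruction&) override {
    return ToyInstruction{"call"};
  }
};
```

Keeping the loop in the base class is what lets GatherExpander and TriangularSolveExpander share one driver while differing only in the two overridden hooks.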

On GPU, adds a direct implementation of TriangularSolve in terms of the cuBLAS TRSM implementation.

PiperOrigin-RevId: 232987494
2019-02-07 18:58:52 -08:00


/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TRIANGULAR_SOLVE_EXPANDER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_TRIANGULAR_SOLVE_EXPANDER_H_
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/op_expander_pass.h"
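namespace xla {

// Expands each TriangularSolve HLO into a Call to a computation that performs
// the solve using other HLO ops.
class TriangularSolveExpander : public OpExpanderPass {
 public:
  absl::string_view name() const override {
    return "triangular_solve_expander";
  }

 protected:
  // Returns true for the instructions this pass expands (TriangularSolve ops).
  bool InstructionMatchesPattern(HloInstruction* instruction) override;

  // Builds (or reuses) the expansion computation for `instruction` and returns
  // the instruction that replaces it.
  StatusOr<HloInstruction*> ExpandInstruction(
      HloInstruction* instruction) override;

 private:
  // Mapping from op signatures to existing computations.
  absl::flat_hash_map<string, HloComputation*> computation_cache_;
};

}  // namespace xla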
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_TRIANGULAR_SOLVE_EXPANDER_H_
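The `computation_cache_` member maps op signatures to computations, so TriangularSolve instructions with the same signature can share one expanded computation rather than each building its own. The sketch below illustrates that memoization idea with hypothetical types (`ToyComputation`, `ToyComputationCache`) and made-up signature strings; the real pass's signature format and ownership of `HloComputation`s are not shown here.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for an HloComputation.
struct ToyComputation {
  std::string body;
};

// Memoizes expansions by op signature: the builder runs only the first time
// a given signature is seen; later lookups return the same computation.
class ToyComputationCache {
 public:
  ToyComputation* GetOrBuild(const std::string& signature,
                             const std::function<ToyComputation()>& build) {
    auto it = cache_.find(signature);
    if (it == cache_.end()) {
      it = cache_.emplace(signature, build()).first;
    }
    // Pointers to unordered_map elements stay valid across rehashing.
    return &it->second;
  }

  std::size_t size() const { return cache_.size(); }

 private:
  std::unordered_map<std::string, ToyComputation> cache_;
};
```

Two lookups with the same signature return the same pointer and invoke the builder once; a different signature builds a second entry.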