/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This header file defines common utils used by TFLite transformation
// passes to convert composite LSTM cell ops into fused TFL LSTM ops.

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_

#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Function.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

namespace mlir {
namespace TFL {

constexpr char kTFImplements[] = "tf._implements";
constexpr char kLstmCellSimple[] = "LSTMCellSimple";
constexpr char kLayerNormalizedLstmCellSimple[] =
    "LayerNormalizedLstmCellSimple";
constexpr char kCoupleInputForgetGates[] = "CoupleInputForgetGates";

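// Illustrative only (hypothetical IR, not defined in this header): a FuncOp
// eligible for fusion is expected to carry the attribute named by
// kTFImplements, whose value is one of the composite op names above and which
// InitializeFromFuncAttributes() below is expected to verify, e.g.
//
//   func @lstm_cell_simple(%input, %weight, %projection, %bias)
//       attributes {tf._implements = "LSTMCellSimple"} { ... }
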
// A utility class that enables the conversion of the LSTMCellSimple composite
// op into a fused TFL LSTM op. The fused op is contained within a FuncOp
// that also contains other supporting ops needed to construct the operands for
// the fused op. The caller provides the containing FuncOp as input with
// arguments specifying the input, weight, projection and bias.
// The weight, projection, bias and layer norm scale all need to be
// RankedTensorType.
// This class sets the layer norm coefficients to NoneType.
class ConvertLSTMCellSimpleToFusedLSTM {
 public:
  explicit ConvertLSTMCellSimpleToFusedLSTM(mlir::FuncOp fused_func_op)
      : fused_func_op_(fused_func_op),
        couple_input_forget_gates_(false),
        builder_(fused_func_op.getBody()) {}

  // Not copyable or assignable.
  ConvertLSTMCellSimpleToFusedLSTM(const ConvertLSTMCellSimpleToFusedLSTM&) =
      delete;
  ConvertLSTMCellSimpleToFusedLSTM& operator=(
      const ConvertLSTMCellSimpleToFusedLSTM&) = delete;
  virtual ~ConvertLSTMCellSimpleToFusedLSTM() {}

  virtual llvm::StringRef GetCompositeOpName() { return kLstmCellSimple; }

  // Rewrites the func body with the constructed fused LSTM op.
  LogicalResult RewriteFunc();

  int GetNumInputs() { return n_input_; }

 protected:
  // Verify input func op arguments/attributes and initialize internal state.
  virtual LogicalResult InitializeFromFuncAttributes();
  virtual LogicalResult Initialize();

  void UpdateFuncSignature();
  void GenerateFusedOpOperands();

  void SetWeightForInputToCellGate();
  void SetWeightForInputToInputGate();
  void SetWeightForInputToForgetGate();
  void SetWeightForInputToOutputGate();

  void SetWeightForRecurrentToCellGate();
  void SetWeightForRecurrentToInputGate();
  void SetWeightForRecurrentToForgetGate();
  void SetWeightForRecurrentToOutputGate();

  void SetBiasToCellGate();
  void SetBiasToInputGate();
  void SetBiasToForgetGate();
  void SetBiasToOutputGate();

  void SetProjection();
  void SetProjectionBias();

  void SetInputActivationState();
  void SetInputCellState();

  virtual void SetCellLayerNormCoefficients();
  virtual void SetInputLayerNormCoefficients();
  virtual void SetForgetLayerNormCoefficients();
  virtual void SetOutputLayerNormCoefficients();

  // specified state
  FuncOp fused_func_op_;
  Value input_;
  Value weight_;
  Value bias_;
  Value projection_;
  bool couple_input_forget_gates_;

  // internal state
  Value weight_transposed_;
  Value projection_transposed_;
  RankedTensorType weight_type_;
  RankedTensorType projection_type_;
  int num_gates_;
  int n_cell_;
  int n_output_;
  int n_input_;
  int num_cols_weight_transposed_;
  int num_cols_projection_transposed_;

  // input -> cifg
  Value input2input_;
  Value input2forget_;
  Value input2cell_;
  Value input2output_;

  // recurrent -> cifg
  Value rec2input_;
  Value rec2forget_;
  Value rec2cell_;
  Value rec2output_;

  // bias -> cifg
  Value bias2input_;
  Value bias2forget_;
  Value bias2cell_;
  Value bias2output_;

  // projection
  Value proj_weight_;
  Value proj_bias_;

  // state
  Value input_activation_state_;
  Value input_cell_state_;

  // layer norm coefficients
  Value input_layer_norm_coefficients_;
  Value forget_layer_norm_coefficients_;
  Value cell_layer_norm_coefficients_;
  Value output_layer_norm_coefficients_;

  mlir::TFL::LSTMOp lstm_;

  Value none_;
  SmallVector<int64_t, 1> bias_slice_shape_;
  SmallVector<int64_t, 1> bias_size_values_;
  SmallVector<int64_t, 2> weight_slice_shape_;
  SmallVector<int64_t, 2> weight_slice_size_input_values_;
  SmallVector<int64_t, 2> weight_slice_size_recurrent_values_;
  OpBuilder builder_;
};

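// A minimal usage sketch (illustrative only, not part of this header),
// assuming `func` is a FuncOp whose kTFImplements attribute names
// kLstmCellSimple and whose arguments follow the layout described above:
//
//   ConvertLSTMCellSimpleToFusedLSTM converter(func);
//   if (failed(converter.RewriteFunc())) {
//     // Initialization or rewriting failed; the calling pass should emit a
//     // diagnostic and signal failure.
//   }
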
// A utility class that enables the conversion of the
// LayerNormalizedLSTMCellSimple composite op into a fused TFL LSTM op. The
// fused op is contained within a FuncOp that also contains other supporting ops
// needed to construct the operands for the fused op. The caller provides the
// containing FuncOp as input with arguments specifying the input, weight,
// projection, bias and layer norm scale. The weight, projection, bias and
// layer norm scale all need to be RankedTensorType.
// This class overrides the layer norm coefficient setters from the base class.
class ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM
    : public ConvertLSTMCellSimpleToFusedLSTM {
 public:
  explicit ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM(
      mlir::FuncOp fused_func_op)
      : ConvertLSTMCellSimpleToFusedLSTM(fused_func_op) {}

  // Not copyable or assignable.
  ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM(
      const ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM&) = delete;
  ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM& operator=(
      const ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM&) = delete;
  ~ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM() override {}

  llvm::StringRef GetCompositeOpName() override {
    return kLayerNormalizedLstmCellSimple;
  }

 protected:
  LogicalResult Initialize() override;

  void SetCellLayerNormCoefficients() override;
  void SetInputLayerNormCoefficients() override;
  void SetForgetLayerNormCoefficients() override;
  void SetOutputLayerNormCoefficients() override;

 private:
  // specified state
  Value layer_norm_scale_;

  // internal state
  RankedTensorType layer_norm_scale_type_;
  SmallVector<int64_t, 1> layer_norm_slice_shape_;
  SmallVector<int64_t, 1> layer_norm_size_values_;
};

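// A hypothetical dispatch sketch (illustrative only): a caller would select a
// converter based on the composite op name carried by the kTFImplements
// attribute, where `name` below stands for that attribute's string value:
//
//   if (name == kLayerNormalizedLstmCellSimple) {
//     ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM converter(func);
//     return converter.RewriteFunc();
//   }
//   if (name == kLstmCellSimple) {
//     ConvertLSTMCellSimpleToFusedLSTM converter(func);
//     return converter.RewriteFunc();
//   }
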
// Converts a func op that implements a Keras LSTM layer into a fused TFL LSTM
// op, using the given builder.
LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder);

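// A minimal usage sketch (illustrative only), assuming `func` is the FuncOp
// that implements the Keras LSTM layer:
//
//   OpBuilder builder(func.getBody());
//   if (failed(ConvertKerasLSTMLayer(func, &builder))) {
//     // Conversion failed; the calling pass should emit a diagnostic.
//   }
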
}  // end namespace TFL
}  // end namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_