Add a pass in MLIR to compress sparse tensors.
PiperOrigin-RevId: 308867556 Change-Id: I67c8f866c5d2fc46092a4592d3da6931c4a15ebe
parent bd1422f12e
commit 3512a5fd32
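For orientation, the encoding this change produces follows the TFLite sparse-tensor format (traversal order, block map, and per-dimension metadata). A small worked example, using a hypothetical 2x4 weight matrix and block size {1, 1} (illustrative values, not taken from the diff):

// Hypothetical dense weight: {{1, 0, 0, 2},
//                             {0, 0, 0, 0}}
// traversal_order = {0, 1}                      // visit rows, then columns
// block_map       = {}                          // no blocking
// dim_metadata[0] = DENSE, dense_size = 2       // 2 rows, stored densely
// dim_metadata[1] = SPARSE_CSR, segments = {0, 2, 2}, indices = {0, 3}
// compressed_data = {1, 2}                      // only nonzero values are kept
// segments are CSR row pointers; indices are the column positions of nonzeros.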
@@ -417,12 +417,14 @@ cc_library(
    ],
    deps = [
        ":tensorflow_lite",
        "//tensorflow/lite/tools/optimize/sparsity:format_converter",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:StandardOps",
    ],
    alwayslink = 1,
)

@@ -191,7 +191,8 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,

static bool IsConst(Operation* op) {
  return isa<mlir::ConstantOp>(op) || isa<mlir::TF::ConstOp>(op) ||
         isa<tfl::ConstOp>(op) || isa<tfl::QConstOp>(op);
         isa<tfl::ConstOp>(op) || isa<tfl::QConstOp>(op) ||
         isa<tfl::SparseConstOp>(op) || isa<tfl::SparseQConstOp>(op);
}

template <typename T>
@@ -455,6 +456,9 @@ class Translator {
  // Returns a unique name for `val`.
  std::string UniqueName(mlir::Value val);

  BufferOffset<tflite::SparsityParameters> BuildSparsityParameters(
      const mlir::TFL::SparsityParameterAttr& s_attr);

  ModuleOp module_;

  tensorflow::OpOrArgNameMapper& name_mapper_;
@@ -501,9 +505,9 @@ Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
  } else if (auto cst = dyn_cast<tfl::QConstOp>(inst)) {
    attr = cst.value();
  } else if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
    attr = cst.value();
    attr = cst.compressed_data();
  } else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {
    attr = cst.value();
    attr = cst.compressed_data();
  } else {
    return empty_buffer_;
  }
@@ -617,11 +621,12 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
    shape_signature = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
  }

  if (inst) {
  BufferOffset<tflite::SparsityParameters> s_params = 0;
  if (auto* inst = value.getDefiningOp()) {
    if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
      // CreateSparsityParameters(cst.s_param());
      s_params = BuildSparsityParameters(cst.s_param());
    } else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {
      // CreateSparsityParameters(cst.s_param());
      s_params = BuildSparsityParameters(cst.s_param());
    }
  }

@@ -666,12 +671,12 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
    return tflite::CreateTensor(
        builder_, builder_.CreateVector(shape), tflite_element_type,
        (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
        /*is_variable=*/is_variable);
        /*is_variable=*/is_variable, s_params);
  } else {
    return tflite::CreateTensor(
        builder_, builder_.CreateVector(shape), tflite_element_type,
        (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
        /*is_variable=*/is_variable, /*sparsity=*/0,
        /*is_variable=*/is_variable, s_params,
        /*shape_signature=*/builder_.CreateVector(shape_signature));
  }
}
@@ -1383,6 +1388,60 @@ Optional<std::string> Translator::TranslateInternal() {
                     builder_.GetSize());
}

BufferOffset<tflite::SparsityParameters> Translator::BuildSparsityParameters(
    const mlir::TFL::SparsityParameterAttr& s_attr) {
  const int dim_size = s_attr.dim_metadata().size();
  std::vector<flatbuffers::Offset<tflite::DimensionMetadata>> fb_dim_metadata(
      dim_size);
  for (int i = 0; i < dim_size; i++) {
    const auto dim_metadata =
        s_attr.dim_metadata()[i].dyn_cast<mlir::TFL::DimensionMetadataAttr>();
    if (dim_metadata.format().getValue() == "DENSE") {
      fb_dim_metadata[i] =
          tflite::CreateDimensionMetadata(builder_, tflite::DimensionType_DENSE,
                                          dim_metadata.dense_size().getInt());

    } else {
      auto segments = dim_metadata.segments();
      std::vector<int> vector_segments(segments.size(), 0);
      for (int j = 0; j < segments.size(); j++) {
        vector_segments[j] = segments[j].dyn_cast<mlir::IntegerAttr>().getInt();
      }
      auto array_segments =
          tflite::CreateInt32Vector(builder_,
                                    builder_.CreateVector(vector_segments))
              .Union();
      auto indices = dim_metadata.indices();
      std::vector<int> vector_indices(indices.size(), 0);
      for (int j = 0; j < indices.size(); j++) {
        vector_indices[j] = indices[j].dyn_cast<mlir::IntegerAttr>().getInt();
      }
      auto array_indices = tflite::CreateInt32Vector(
                               builder_, builder_.CreateVector(vector_indices))
                               .Union();
      fb_dim_metadata[i] = tflite::CreateDimensionMetadata(
          builder_, tflite::DimensionType_SPARSE_CSR, 0,
          tflite::SparseIndexVector_Int32Vector, array_segments,
          tflite::SparseIndexVector_Int32Vector, array_indices);
    }
  }

  std::vector<int> traversal_order(dim_size);
  for (int i = 0; i < dim_size; i++) {
    traversal_order[i] =
        s_attr.traversal_order()[i].dyn_cast<mlir::IntegerAttr>().getInt();
  }
  const int block_map_size = s_attr.block_map().size();
  std::vector<int> block_map(block_map_size);
  for (int i = 0; i < block_map_size; i++) {
    block_map[i] = s_attr.block_map()[i].dyn_cast<mlir::IntegerAttr>().getInt();
  }

  return tflite::CreateSparsityParameters(
      builder_, builder_.CreateVector(traversal_order),
      builder_.CreateVector(block_map), builder_.CreateVector(fb_dim_metadata));
}

}  // namespace

// Translates the given MLIR module in the TFLite dialect to TFLite FlatBuffer

@@ -69,6 +69,14 @@ def TFL_SparseOp : OpInterface<"SparseOpInterface"> {
      [{Returns the indices of sparse operands.}],
      "std::vector<int>", "GetSparseOperands", (ins)
    >,
    InterfaceMethod<
      [{Returns the supported block size of float sparse operands.}],
      "std::vector<std::vector<int>>", "GetFloatBlockSize", (ins)
    >,
    InterfaceMethod<
      [{Returns the supported block size of quantized sparse operands.}],
      "std::vector<std::vector<int>>", "GetQuantizedBlockSize", (ins)
    >,
  ];
}

@@ -722,17 +722,20 @@ def TFL_SparseConstOp : Op<TFL_Dialect, "pseudo_sparse_const", [NoSideEffect,
    an actual operation and it will be lowered to buffer instead.
  }];

  let arguments = (ins ElementsAttr:$value, SparsityParameterAttr:$s_param);
  let arguments = (ins ElementsAttr:$value,
                       SparsityParameterAttr:$s_param,
                       ElementsAttr:$compressed_data);

  let results = (outs AnyTensor:$output);

  let builders = [OpBuilder<
    "OpBuilder &, OperationState &state, Attribute value, "
    "SparsityParameterAttr s_param",
    "SparsityParameterAttr s_param, Attribute compressed_data",
    [{
      state.addTypes(value.getType());
      state.addAttribute("value", value);
      state.addAttribute("s_param", s_param);
      state.addAttribute("compressed_data", compressed_data);
    }]>
  ];
}
@@ -838,6 +841,8 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
    int GetChannelDimIndex() { return 0; }
    // SparseOpInterface:
    std::vector<int> GetSparseOperands() { return {1}; }
    std::vector<std::vector<int>> GetFloatBlockSize() { return {{1, 4}}; }
    std::vector<std::vector<int>> GetQuantizedBlockSize() { return {{1, 16}}; }
  }];
}

@@ -3209,19 +3214,21 @@ def TFL_SparseQConstOp : Op<TFL_Dialect, "pseudo_sparse_qconst", [
  let arguments = (
    ins TensorTypeAttr:$qtype,
        ElementsAttr:$value,
        SparsityParameterAttr:$s_param
        SparsityParameterAttr:$s_param,
        ElementsAttr:$compressed_data
  );

  let results = (outs AnyTensor:$output);

  let builders = [OpBuilder<
    "OpBuilder &, OperationState &state, TypeAttr qtype, "
    "Attribute value, SparsityParameterAttr s_param",
    "Attribute value, SparsityParameterAttr s_param, Attribute compressed_data",
    [{
      state.addTypes(qtype.getValue());
      state.addAttribute("qtype", qtype);
      state.addAttribute("value", value);
      state.addAttribute("s_param", s_param);
      state.addAttribute("compressed_data", compressed_data);
    }]>
  ];
}

@@ -25,7 +25,7 @@ cc_library(
    deps = [
        "//tensorflow/compiler/mlir/lite:common",
        "//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
        "//tensorflow/compiler/mlir/lite:tensorflow_lite_d2s",
        "//tensorflow/compiler/mlir/tensorflow:error_util",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/lite:framework",

@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/framework/types.pb.h"

@@ -57,6 +58,7 @@ TfLiteStatus SparsifyModel(const tflite::ModelT& input_model,
  }

  PassManager pm(module->getContext());
  pm.addPass(TFL::CreateDenseToSparsePass());

  if (failed(pm.run(module.get()))) {
    const std::string& err = statusHandler.ConsumeStatus().error_message();

@@ -16,10 +16,13 @@ limitations under the License.
// This transformation pass converts dense tensors to sparse format.

#include "absl/memory/memory.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/lite/tools/optimize/sparsity/format_converter.h"

//===----------------------------------------------------------------------===//
// The DenseToSparse Pass.
@@ -28,7 +31,226 @@ namespace mlir {
namespace TFL {

namespace {
// If sparsity level is below this threshold, keep the tensor in dense format.
const float kMinSparsityLevel = 0.3;
// Heuristic to check if a block configuration is correct.
const float kBlockOverRandomSparsityRatio = 0.9;

void PopulateEncodingParams(const std::vector<int>& block_size,
                            std::vector<int>* traversal_order,
                            std::vector<TfLiteDimensionType>* format,
                            std::vector<int>* b_map, std::vector<int>* b_size) {
  *traversal_order = {0, 1};
  *format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  *b_map = {};
  *b_size = {};
  int block_rank = 0;
  for (int i = 0; i < 2; i++) {
    if (block_size[i] != 1) {
      traversal_order->push_back(block_rank + 2);
      format->push_back(kTfLiteDimDense);
      block_rank++;
      b_map->push_back(i);
      b_size->push_back(block_size[i]);
    }
  }
}

float CalculateRandomSparsity(const ElementsAttr& attr,
                              const ShapedType& type) {
  int num_elements = 1;
  for (int i = 0; i < 2; i++) {
    num_elements *= type.getDimSize(i);
  }
  int num_zeros = 0;

  if (type.getElementType().isF32()) {
    std::vector<float> data;
    data.reserve(type.getNumElements());
    for (const auto val : attr.getValues<float>()) data.push_back(val);
    for (int i = 0; i < data.size(); i++) {
      if (data[i] == 0) {
        num_zeros++;
      }
    }
  } else if (type.getElementType().isa<quant::QuantizedType>()) {
    std::vector<int8_t> data;
    data.reserve(type.getNumElements());
    for (const auto val : attr.getValues<int8_t>()) data.push_back(val);
    for (int i = 0; i < data.size(); i++) {
      if (data[i] == 0) {
        num_zeros++;
      }
    }
  }

  return 1.0 * num_zeros / num_elements;
}

float CalculateBlockSparsity(const ElementsAttr& attr, const ShapedType& type,
                             const std::vector<int>& block_size) {
  float sparsity = 0;
  std::vector<int> shape(2);
  shape[0] = type.getDimSize(0);
  shape[1] = type.getDimSize(1);

  std::vector<int> traversal_order = {};
  std::vector<TfLiteDimensionType> format = {};
  std::vector<int> b_size = {};
  std::vector<int> b_map = {};
  PopulateEncodingParams(block_size, &traversal_order, &format, &b_map,
                         &b_size);

  if (type.getElementType().isF32()) {
    tflite::optimize::sparsity::FormatConverter<float> format_converter(
        shape, traversal_order, format, b_size, b_map);
    std::vector<float> data;
    data.reserve(type.getNumElements());
    for (const auto val : attr.getValues<float>()) data.push_back(val);
    format_converter.DenseToSparse(data.data());
    sparsity =
        1 - 1.0 * format_converter.GetData().size() / type.getNumElements();
  } else if (type.getElementType().isa<quant::QuantizedType>()) {
    tflite::optimize::sparsity::FormatConverter<int8_t> format_converter(
        shape, traversal_order, format, b_size, b_map);
    std::vector<int8_t> data;
    data.reserve(type.getNumElements());
    for (const auto val : attr.getValues<int8_t>()) data.push_back(val);
    format_converter.DenseToSparse(data.data());
    sparsity =
        1 - 1.0 * format_converter.GetData().size() / type.getNumElements();
  }

  return sparsity;
}

typedef struct InspectResult {
  // Whether the weight tensor is sparse enough to be compressed.
  bool can_compress;
  // If the weight tensor cannot be encoded in a block configuration that the op
  // supports, a Densify() op will be inserted afterwards to fall back to dense
  // execution.
  bool needs_densify;
  // Among the supported block configs of an op, which got selected to encode
  // the sparse weight.
  std::vector<int> selected_block_size;
} InspectResult;

InspectResult InspectWeight(
    Operation* inst,
    const std::vector<std::vector<int>>& supported_block_size) {
  ElementsAttr attr;
  ShapedType type;
  InspectResult result = {};
  if (auto cst = dyn_cast<ConstOp>(inst)) {
    attr = cst.value();
    type = cst.getType().cast<ShapedType>();
  } else if (auto cst = dyn_cast<QConstOp>(inst)) {
    attr = cst.value();
    type = cst.getType().cast<ShapedType>();
  }

  // TODO(b/147449640): Add ability to encode weights more than 2-D, e.g. Conv
  // weights.
  if (type.getRank() != 2) {
    result.can_compress = false;
    return result;
  }

  float random_sparsity = CalculateRandomSparsity(attr, type);
  if (random_sparsity < kMinSparsityLevel) {
    result.can_compress = false;
    return result;
  }

  result.can_compress = true;

  float curr_sparsity = 0;
  std::vector<int> selected_block_size;
  result.needs_densify = true;
  for (const auto& block_size : supported_block_size) {
    curr_sparsity = CalculateBlockSparsity(attr, type, block_size);
    if (curr_sparsity / random_sparsity > kBlockOverRandomSparsityRatio) {
      selected_block_size = block_size;
      result.can_compress = true;
      result.needs_densify = false;
      result.selected_block_size = selected_block_size;
      break;
    }
  }

  return result;
}

template <typename T>
std::vector<T> BuildSparsityParameterAttribute(
    const std::vector<int>& block_size, Operation* inst, OpBuilder* builder,
    SparsityParameterAttr* s_param) {
  ElementsAttr attr;
  ShapedType type;
  if (auto cst = dyn_cast<ConstOp>(inst)) {
    attr = cst.value();
    type = cst.getType().cast<ShapedType>();
  } else if (auto cst = dyn_cast<QConstOp>(inst)) {
    attr = cst.value();
    type = cst.getType().cast<ShapedType>();
  }
  std::vector<int> shape(2);
  shape[0] = type.getDimSize(0);
  shape[1] = type.getDimSize(1);

  std::vector<int> traversal_order = {};
  std::vector<TfLiteDimensionType> format = {};
  std::vector<int> b_size = {};
  std::vector<int> b_map = {};
  PopulateEncodingParams(block_size, &traversal_order, &format, &b_map,
                         &b_size);

  tflite::optimize::sparsity::FormatConverter<T> format_converter(
      shape, traversal_order, format, b_size, b_map);
  std::vector<T> data;
  data.reserve(type.getNumElements());
  for (const auto val : attr.getValues<T>()) data.push_back(val);
  format_converter.DenseToSparse(data.data());
  auto metadata = format_converter.GetDimMetadata();
  auto compressed_data = format_converter.GetData();
  const int dim_size = metadata.size() / 2;
  std::vector<Attribute> dim_metadata(traversal_order.size());
  for (int i = 0; i < dim_size; i++) {
    if (format[i] == kTfLiteDimDense) {
      dim_metadata[i] = DimensionMetadataAttr::get(
          builder->getStringAttr("DENSE"),
          builder->getI32IntegerAttr(metadata[2 * i][0]),
          builder->getArrayAttr({}), builder->getArrayAttr({}),
          builder->getContext());
    } else {
      dim_metadata[i] = DimensionMetadataAttr::get(
          builder->getStringAttr("SPARSE_CSR"), builder->getI32IntegerAttr(0),
          builder->getI32ArrayAttr(metadata[2 * i]),
          builder->getI32ArrayAttr(metadata[2 * i + 1]), builder->getContext());
    }
  }
  *s_param = SparsityParameterAttr::get(
      builder->getI32ArrayAttr(traversal_order),
      builder->getI32ArrayAttr(b_map), builder->getArrayAttr(dim_metadata),
      builder->getContext());

  return compressed_data;
}

// This pass encodes sparse weights in the model in the proper format, and adds
// Densify() op if necessary. The general algorithm is:
// 1. Get list of operands (weights) of an op that can be sparse.
// 2. Get list of supported block configurations of the op.
// 3. Calculate random sparsity of the weight.
//   3.1. If sparsity level is below the encoding threshold, keep in dense.
//   3.2. If sparsity level is above the encoding threshold, go to 4.
// 4. Try to encode the weight with supported block configurations. If the
//    weight was pruned with the same block config, the blocked sparsity level
//    should match the random sparsity.
//   4.1. Return the matching block config if found.
//   4.2. If no matching block config is found, encode the weight with random
//        sparsity, and add Densify() op to fall back to dense execution.
struct DenseToSparse : public PassWrapper<DenseToSparse, FunctionPass> {
  void runOnFunction() override;
};
@@ -39,19 +261,66 @@ void DenseToSparse::runOnFunction() {

  func.walk([&](SparseOpInterface sparse_op) {
    const auto& sparse_operands = sparse_op.GetSparseOperands();
    std::vector<std::vector<int>> supported_block_size;
    for (const int operand : sparse_operands) {
      auto* op = sparse_op.getOperation();
      const auto& value = op->getOperand(operand);
      builder.setInsertionPoint(op);
      if (auto* inst = value.getDefiningOp()) {
        // Replace defining op with SparseConst or SparseQConst.
        // TODO(yunluli): Implement.

      auto* inst = value.getDefiningOp();
      if (!inst) {
        continue;
      }

      // TODO(yunluli): Implement.
      bool needs_densify = false;
      if (isa<ConstOp>(inst)) {
        supported_block_size = sparse_op.GetFloatBlockSize();
      } else if (isa<QConstOp>(inst)) {
        supported_block_size = sparse_op.GetQuantizedBlockSize();
      } else {
        continue;
      }

      if (needs_densify) {
      InspectResult result = InspectWeight(inst, supported_block_size);
      if (!result.can_compress) {
        continue;
      }

      // The weight is not block sparse. Encode with random sparsity.
      if (result.selected_block_size.empty()) {
        result.selected_block_size = {1, 1};
      }

      builder.setInsertionPoint(op);
      SparsityParameterAttr s_param;
      if (auto cst = dyn_cast<ConstOp>(inst)) {
        std::vector<float> compressed_data =
            BuildSparsityParameterAttribute<float>(result.selected_block_size,
                                                   inst, &builder, &s_param);
        auto compressed_data_type = RankedTensorType::get(
            {static_cast<int64_t>(compressed_data.size())},
            builder.getF32Type());
        auto new_value = DenseElementsAttr::get<float>(compressed_data_type,
                                                       compressed_data);
        auto s_const = builder.create<SparseConstOp>(op->getLoc(), cst.value(),
                                                     s_param, new_value);
        value.replaceAllUsesWith(s_const.getResult());
        cst.erase();
      } else if (auto cst = dyn_cast<QConstOp>(inst)) {
        std::vector<int8_t> compressed_data =
            BuildSparsityParameterAttribute<int8_t>(result.selected_block_size,
                                                    inst, &builder, &s_param);
        auto compressed_data_type = RankedTensorType::get(
            {static_cast<int64_t>(compressed_data.size())},
            builder.getIntegerType(8, true));
        auto new_value = DenseElementsAttr::get<int8_t>(compressed_data_type,
                                                        compressed_data);
        auto s_qconst = builder.create<SparseQConstOp>(
            op->getLoc(), cst.qtypeAttr(), cst.value(), s_param, new_value);
        value.replaceAllUsesWith(s_qconst.getResult());
        cst.erase();
      }

      if (result.needs_densify) {
        const auto value = op->getOperand(operand);
        auto densify = builder.create<DensifyOp>(op->getLoc(), value);
        value.replaceAllUsesWith(densify);
        densify.setOperand(value);

@@ -76,3 +76,20 @@ py_test(
        "@six_archive//:six",
    ],
)

py_test(
    name = "sparsify_model_test",
    srcs = ["sparsify_model_test.py"],
    python_version = "PY3",
    srcs_version = "PY2AND3",
    tags = [
        "no_oss",
        "no_pip",
    ],
    deps = [
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform",
    ],
)
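To close the loop, a minimal sketch of how the sparsification step added above is driven from SparsifyModel (it condenses the PassManager wiring in sparsify_model.cc; the fully qualified mlir:: names and the exact error path are assumptions, not verbatim from this change):

// Sketch only: run the new dense-to-sparse encoding over a TFLite-dialect module.
// Assumes `module` is an mlir::OwningModuleRef imported from the input flatbuffer.
mlir::PassManager pm(module->getContext());
pm.addPass(mlir::TFL::CreateDenseToSparsePass());
if (failed(pm.run(module.get()))) {
  // Convert the MLIR diagnostics into an error message, as SparsifyModel does.
  return kTfLiteError;
}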