Return ::tensorflow::Status in Toco Graph Transformations.
PiperOrigin-RevId: 216392908
parent 931353c5f7
commit 12e164d1e7
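Every transformation's Run method changes signature in the same mechanical way: the old bool return value ("did this pass modify the graph?") moves into a new bool* out-parameter, and the return value becomes a ::tensorflow::Status so transformations can report errors. A minimal sketch of the before/after pattern (SomeTransformation and NotApplicable are placeholder names for illustration, not code from this commit):

    // Before: the return value doubles as the "graph changed" flag.
    bool SomeTransformation::Run(Model* model, std::size_t op_index) {
      if (NotApplicable(model, op_index)) {
        return false;  // no change
      }
      // ... rewrite the graph ...
      return true;
    }

    // After: the flag moves into *modified; the return value carries status.
    ::tensorflow::Status SomeTransformation::Run(Model* model,
                                                 std::size_t op_index,
                                                 bool* modified) {
      *modified = false;
      if (NotApplicable(model, op_index)) {
        return ::tensorflow::Status::OK();  // no change, and no error
      }
      // ... rewrite the graph ...
      *modified = true;
      return ::tensorflow::Status::OK();
    }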
@@ -25,10 +25,13 @@ limitations under the License.
 
 namespace toco {
 
-bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ConvertExpandDimsToReshape::Run(Model* model,
+                                                     std::size_t op_index,
+                                                     bool* modified) {
+  *modified = false;
   auto expand_it = model->operators.begin() + op_index;
   if (expand_it->get()->type != OperatorType::kExpandDims) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
   ExpandDimsOperator* expand_op =
       static_cast<ExpandDimsOperator*>(expand_it->get());
@@ -38,18 +41,18 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) {
   const auto& input_array = model->GetArray(expand_op->inputs[0]);
   if (!input_array.has_shape()) {
     // Yield until input dims have been resolved.
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   const auto& axis_array = model->GetArray(expand_op->inputs[1]);
   if (!axis_array.has_shape()) {
     // Yield until input axis array shape has been resolved.
-    return false;
+    return ::tensorflow::Status::OK();
   }
   CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1);
   if (!axis_array.buffer) {
     // Yield until the input axis array is constant
-    return false;
+    return ::tensorflow::Status::OK();
   }
   int axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
   std::vector<int> reshape_dims(input_array.shape().dims());
@@ -90,7 +93,8 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) {
   CHECK_EQ(expand_it->get(), expand_op);
   model->operators.erase(expand_it);
 
-  return true;
+  *modified = true;
+  return ::tensorflow::Status::OK();
 }
 
 }  // namespace toco
@@ -24,29 +24,32 @@ limitations under the License.
 
 namespace toco {
 
-bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ConvertPureConvToDepthwise::Run(Model* model,
+                                                     std::size_t op_index,
+                                                     bool* modified) {
+  *modified = false;
   auto conv_it = model->operators.begin() + op_index;
   if (conv_it->get()->type != OperatorType::kConv) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
   const auto* conv_op = static_cast<ConvOperator*>(conv_it->get());
   if (conv_op->stride_width != conv_op->stride_height) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
   if ((conv_op->dilation_width_factor != 1) ||
       (conv_op->dilation_height_factor != 1)) {
     // Depthwise conv does not support dilation
-    return false;
+    return ::tensorflow::Status::OK();
   }
   auto& input_array = model->GetArray(conv_op->inputs[0]);
   if (!input_array.has_shape()) {
     // Shapes not propagated yet
-    return false;
+    return ::tensorflow::Status::OK();
   }
   if (input_array.shape().dims(3) != 1) {
     // Not a pure convolution: Conv does accumulation across the depth
     // dimension.
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   const auto& weights_name = conv_op->inputs[1];
@@ -56,15 +59,15 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) {
         "Not changing %s to DepthwiseConv because the weights is consumed by "
         "another op.",
         LogName(*conv_op));
-    return false;
+    return ::tensorflow::Status::OK();
   }
   auto& weights_array = model->GetArray(weights_name);
   if (!weights_array.buffer) {
     // Yield until the weights are resolved as a constant array.
-    return false;
+    return ::tensorflow::Status::OK();
   }
   if (weights_array.data_type != ArrayDataType::kFloat) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
   // At this point we know we have a pure conv. Rewrite it as DepthwiseConv.
   AddMessageF(
@@ -112,7 +115,8 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) {
   }
   *weights_array.mutable_shape()->mutable_dims() = {1, width, height, depth};
   weights_buffer.data = depthwise_conv_weights_data;
-  return true;
+  *modified = true;
+  return ::tensorflow::Status::OK();
 }
 
 }  // namespace toco
@@ -86,9 +86,12 @@ TransposeOperator* CreateTransposeFromReorderAxes(
 
 // Converts ReorderAxes into Transpose and Reshape which are compatible with the
 // TFLite interpreter.
-bool ConvertReorderAxes::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ConvertReorderAxes::Run(Model* model, std::size_t op_index,
+                                             bool* modified) {
+  *modified = false;
   auto reorder_it = model->operators.begin() + op_index;
-  if (reorder_it->get()->type != OperatorType::kReorderAxes) return false;
+  if (reorder_it->get()->type != OperatorType::kReorderAxes)
+    return ::tensorflow::Status::OK();
 
   auto* reorder_op = static_cast<ReorderAxesOperator*>(reorder_it->get());
   CHECK_EQ(reorder_op->inputs.size(), 1);
@@ -113,8 +116,9 @@ bool ConvertReorderAxes::Run(Model* model, std::size_t op_index) {
   // Yield if input array contains constants or if output array size has not
   // been adjusted to reflect the permutations in ReorderAxes. ReorderAxes will
   // be merged into a constant array when possible.
-  if (IsConstantParameterArray(*model, constant_input_array_name)) return false;
-  if (!output_array.has_shape()) return false;
+  if (IsConstantParameterArray(*model, constant_input_array_name))
+    return ::tensorflow::Status::OK();
+  if (!output_array.has_shape()) return ::tensorflow::Status::OK();
 
   const auto input_axes_order = reorder_op->input_axes_order;
   const auto output_axes_order = reorder_op->output_axes_order;
@@ -143,7 +147,8 @@ bool ConvertReorderAxes::Run(Model* model, std::size_t op_index) {
   CHECK_EQ(reorder_it->get(), reorder_op);
   model->operators.erase(reorder_it);
 
-  return true;
+  *modified = true;
+  return ::tensorflow::Status::OK();
 }
 
 }  // namespace toco
@@ -30,10 +30,13 @@ namespace toco {
 // means that the data layout will never change with this op, just the shape.
 // By converting these to reshapes once we have run shape propagation we allow
 // standard reshape optimization transforms to do their magic.
-bool ConvertSqueezeToReshape::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ConvertSqueezeToReshape::Run(Model* model,
+                                                  std::size_t op_index,
+                                                  bool* modified) {
+  *modified = false;
   auto squeeze_it = model->operators.begin() + op_index;
   if (squeeze_it->get()->type != OperatorType::kSqueeze) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
   auto squeeze_op = static_cast<SqueezeOperator*>(squeeze_it->get());
   CHECK_EQ(squeeze_op->inputs.size(), 1);
@@ -42,16 +45,16 @@ bool ConvertSqueezeToReshape::Run(Model* model, std::size_t op_index) {
   const auto& input_array = model->GetArray(squeeze_op->inputs[0]);
   if (!input_array.has_shape()) {
     // Yield until input dims have been resolved.
-    return false;
+    return ::tensorflow::Status::OK();
   }
   if (input_array.shape().dimensions_count() == 0) {
     // Input array cannot be 0-D.
-    return false;
+    return ::tensorflow::Status::OK();
   }
   if (!model->HasArray(squeeze_op->outputs[0]) ||
       !model->GetArray(squeeze_op->outputs[0]).has_shape()) {
     // Yield until shape propagation has set the output shape for us.
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   // We use the output shape that has been calculated by shape propagation.
@@ -59,7 +62,7 @@ bool ConvertSqueezeToReshape::Run(Model* model, std::size_t op_index) {
 
   // Empty shapes will not work as empty data arrays.
   if (output_shape.dimensions_count() == 0) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   auto* reshape_op = new TensorFlowReshapeOperator;
@@ -79,7 +82,8 @@ bool ConvertSqueezeToReshape::Run(Model* model, std::size_t op_index) {
   CHECK_EQ(squeeze_it->get(), squeeze_op);
   model->operators.erase(squeeze_it);
 
-  return true;
+  *modified = true;
+  return ::tensorflow::Status::OK();
 }
 
 }  // namespace toco
@@ -20,10 +20,13 @@ namespace toco {
 
 // This pass will convert an AddN operator with only 2 inputs into a regular Add
 // operator, to which more optimizations may apply.
-bool ConvertTrivialAddNToAdd::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ConvertTrivialAddNToAdd::Run(Model* model,
+                                                  std::size_t op_index,
+                                                  bool* modified) {
+  *modified = false;
   auto addn_it = model->operators.begin() + op_index;
   if (addn_it->get()->type != OperatorType::kAddN) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
   AddNOperator* addn_op = static_cast<AddNOperator*>(addn_it->get());
   CHECK_GE(addn_op->inputs.size(), 2);
@@ -31,7 +34,7 @@ bool ConvertTrivialAddNToAdd::Run(Model* model, std::size_t op_index) {
 
   // We only reduce AddN with N=2 to a regular Add.
   if (addn_op->inputs.size() != 2) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   // Copy inputs & outputs to regular Add.
@@ -45,7 +48,8 @@ bool ConvertTrivialAddNToAdd::Run(Model* model, std::size_t op_index) {
   addn_it = add_it + 1;
   CHECK_EQ(addn_it->get(), addn_op);
   model->operators.erase(addn_it);
-  return true;
+  *modified = true;
+  return ::tensorflow::Status::OK();
 }
 
 }  // namespace toco
@@ -25,27 +25,30 @@ limitations under the License.
 
 namespace toco {
 
-bool ConvertTrivialPackToReshape::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ConvertTrivialPackToReshape::Run(Model* model,
+                                                      std::size_t op_index,
+                                                      bool* modified) {
+  *modified = false;
   auto pack_it = model->operators.begin() + op_index;
   if (pack_it->get()->type != OperatorType::kPack) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
   auto* pack_op = static_cast<PackOperator*>(pack_it->get());
   if (pack_op->inputs.size() > 1) {
     // Not trivial.
-    return false;
+    return ::tensorflow::Status::OK();
   }
   CHECK_EQ(pack_op->outputs.size(), 1);
 
   const auto& input_array = model->GetArray(pack_op->inputs[0]);
   if (!input_array.has_shape()) {
     // Yield until input dims have been resolved.
-    return false;
+    return ::tensorflow::Status::OK();
   }
   if (input_array.shape().dimensions_count() == 0) {
     // Input array cannot be 0-D.
     // (Unsure if this is TF behavior, but was required to get a test to pass.)
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   AddMessageF("Converting trivial %s to a reshape", LogName(*pack_op));
@@ -75,7 +78,8 @@ bool ConvertTrivialPackToReshape::Run(Model* model, std::size_t op_index) {
   CHECK_EQ(pack_it->get(), pack_op);
   model->operators.erase(pack_it);
 
-  return true;
+  *modified = true;
+  return ::tensorflow::Status::OK();
 }
 
 }  // namespace toco
@@ -21,10 +21,13 @@ limitations under the License.
 
 namespace toco {
 
-bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ConvertTrivialTileToConcat::Run(Model* model,
+                                                     std::size_t op_index,
+                                                     bool* modified) {
+  *modified = false;
   auto tile_it = model->operators.begin() + op_index;
   if (tile_it->get()->type != OperatorType::kTile) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
   auto* tile_op = static_cast<TransposeOperator*>(tile_it->get());
 
@@ -34,13 +37,13 @@ bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) {
   if (!input_array.has_shape() || !multiples_array.has_shape() ||
       !output_array.has_shape()) {
     // Yield until PropagateFixedSizes has been run on this op.
-    return false;
+    return ::tensorflow::Status::OK();
   }
   // Note: We can assume we have error checked inputs in PropagateFixedSizes.
 
   if (!multiples_array.buffer) {
     // Yield until the multiples is constant.
-    return false;
+    return ::tensorflow::Status::OK();
   }
   std::vector<int32> const& multiples =
       multiples_array.GetBuffer<ArrayDataType::kInt32>().data;
@@ -59,7 +62,7 @@ bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) {
     // The tile is non-trivial. Good luck.
     AddMessageF("Tile %s is non-trivial (has more than one multiply dimension)",
                 LogName(*tile_op));
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   // The tile is like a concat.
@@ -88,7 +91,8 @@ bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) {
   CHECK_EQ(tile_it->get(), tile_op);
   model->operators.erase(tile_it);
 
-  return true;
+  *modified = true;
+  return ::tensorflow::Status::OK();
 }
 
 }  // namespace toco
@@ -48,10 +48,13 @@ bool TransposeAffectsMemoryOrder(std::vector<int> perm,
 
 }  // namespace
 
-bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ConvertTrivialTransposeToReshape::Run(Model* model,
+                                                           std::size_t op_index,
+                                                           bool* modified) {
+  *modified = false;
   auto transpose_it = model->operators.begin() + op_index;
   if (transpose_it->get()->type != OperatorType::kTranspose) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
   TransposeOperator* transpose_op =
       static_cast<TransposeOperator*>(transpose_it->get());
@@ -60,14 +63,14 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
   const auto& output_array = model->GetArray(transpose_op->outputs[0]);
   if (!input_array.has_shape() || !output_array.has_shape()) {
     // Yield until PropagateFixedSizes has been run on this op.
-    return false;
+    return ::tensorflow::Status::OK();
   }
   // Note: We can assume we have error checked inputs in PropagateFixedSizes.
 
   // Check that the permutation has propogated.
   std::vector<int> const& perm = transpose_op->perm;
   if (perm.empty()) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   // This transpose is trivial if non-unitary dimensions remain in the same
@@ -76,7 +79,7 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
   std::vector<int> const& output_dims = output_array.shape().dims();
 
   if (TransposeAffectsMemoryOrder(perm, input_dims)) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   // This transpose is trivial. Replace it with a Reshape op.
@@ -109,7 +112,8 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
   CHECK_EQ(transpose_it->get(), transpose_op);
   model->operators.erase(transpose_it);
 
-  return true;
+  *modified = true;
+  return ::tensorflow::Status::OK();
 }
 
 }  // namespace toco
@@ -73,18 +73,22 @@ bool ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) {
   return true;
 }
 
-bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status CreateIm2colArrays::Run(Model* model, std::size_t op_index,
+                                             bool* modified) {
+  *modified = false;
   auto it = model->operators.begin() + op_index;
   auto* op = it->get();
 
   switch (op->type) {
     case OperatorType::kConv:
-      return ProcessConvOperator(model, static_cast<ConvOperator*>(op));
+      *modified = ProcessConvOperator(model, static_cast<ConvOperator*>(op));
+      return ::tensorflow::Status::OK();
     case OperatorType::kTransposeConv:
-      return ProcessTransposeConvOperator(
+      *modified = ProcessTransposeConvOperator(
           model, static_cast<TransposeConvOperator*>(op));
+      return ::tensorflow::Status::OK();
     default:
-      return false;
+      return ::tensorflow::Status::OK();
   }
 }
@@ -186,24 +186,27 @@ bool DequantizeArray(const string& array_name,
 
 }  // namespace
 
-bool Dequantize::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status Dequantize::Run(Model* model, std::size_t op_index,
+                                     bool* modified) {
+  *modified = false;
   const auto op_it = model->operators.begin() + op_index;
   auto* op = op_it->get();
 
   if (op->type == OperatorType::kDequantize) {
     auto& input_array = model->GetArray(op->inputs[0]);
     if (input_array.data_type == ArrayDataType::kFloat) {
-      return false;
+      return ::tensorflow::Status::OK();
     }
     if (input_array.final_data_type != ArrayDataType::kFloat) {
-      return false;
+      return ::tensorflow::Status::OK();
     }
     input_array.data_type = ArrayDataType::kFloat;
     input_array.quantization_params = nullptr;
     auto& output_array = model->GetArray(op->outputs[0]);
     output_array.data_type = ArrayDataType::kFloat;
     output_array.quantization_params = nullptr;
-    return RemoveTrivialPassthroughOp(this, model, op_index);
+    *modified = RemoveTrivialPassthroughOp(this, model, op_index);
+    return ::tensorflow::Status::OK();
   }
 
   std::vector<string> arrays;
@@ -220,7 +223,8 @@ bool Dequantize::Run(Model* model, std::size_t op_index) {
     }
   }
 
-  return changed;
+  *modified = changed;
+  return ::tensorflow::Status::OK();
 }
 
 }  // namespace toco
@@ -25,21 +25,23 @@ limitations under the License.
 
 namespace toco {
 
-bool DropFakeQuant::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status DropFakeQuant::Run(Model* model, std::size_t op_index,
+                                        bool* modified) {
+  *modified = false;
   const auto fakequant_it = model->operators.begin() + op_index;
   auto* fakequant_base_op = fakequant_it->get();
   if (fakequant_base_op->type != OperatorType::kFakeQuant) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
   auto* fakequant_op = static_cast<FakeQuantOperator*>(fakequant_base_op);
 
   if (!fakequant_op->minmax) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   const auto& output_array = model->GetArray(fakequant_op->outputs[0]);
   if (!output_array.minmax) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   // Drop min/max inputs
@@ -50,7 +52,8 @@ bool DropFakeQuant::Run(Model* model, std::size_t op_index) {
   }
   fakequant_op->inputs.resize(1);
 
-  return RemoveTrivialPassthroughOp(this, model, op_index);
+  *modified = RemoveTrivialPassthroughOp(this, model, op_index);
+  return ::tensorflow::Status::OK();
 }
 
 }  // namespace toco
@@ -19,15 +19,17 @@ limitations under the License.
 
 namespace toco {
 
-bool DropIm2colArrays::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status DropIm2colArrays::Run(Model* model, std::size_t op_index,
+                                           bool* modified) {
+  *modified = false;
   auto conv_it = model->operators.begin() + op_index;
   if (conv_it->get()->type != OperatorType::kConv) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
   auto* conv_op = static_cast<ConvOperator*>(conv_it->get());
   if (conv_op->outputs.size() < 2) {
     // Conv op does not have im2col.
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   // Drop the im2col array.
@@ -36,7 +38,8 @@ bool DropIm2colArrays::Run(Model* model, std::size_t op_index) {
   conv_op->outputs.resize(1);
   AddMessageF("Dropped an im2col array for %s", LogName(*conv_op));
 
-  return true;
+  *modified = true;
+  return ::tensorflow::Status::OK();
 }
 
 }  // namespace toco
@@ -62,17 +62,20 @@ bool ProcessLinearOperator(Model* model, Operator* op) {
 }
 }  // namespace
 
-bool EnsureBiasVectors::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status EnsureBiasVectors::Run(Model* model, std::size_t op_index,
+                                            bool* modified) {
+  *modified = false;
   auto* op = model->operators[op_index].get();
   if (op->type == OperatorType::kConv ||
       op->type == OperatorType::kDepthwiseConv ||
       op->type == OperatorType::kFullyConnected) {
     if (ProcessLinearOperator(model, op)) {
      AddMessageF("Added bias vector to %s as %s", LogName(*op), op->inputs[2]);
-      return true;
+      *modified = true;
+      return ::tensorflow::Status::OK();
     }
   }
-  return false;
+  return ::tensorflow::Status::OK();
 }
 
 }  // namespace toco
@@ -108,8 +108,9 @@ namespace toco {
 // we can foresee these 'fast int8 kernels' to remain important to have into
 // the 2020s.
 //
-bool EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model,
-                                                   std::size_t op_index) {
+::tensorflow::Status EnsureUint8WeightsSafeForFastInt8Kernels::Run(
+    Model* model, std::size_t op_index, bool* modified) {
+  *modified = false;
   const auto& op = *model->operators[op_index];
   int weights_index = 0;
   switch (op.type) {
@@ -148,16 +149,16 @@ bool EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model,
       // That's why at the moment we only handle operators that use a GEMM
       // (Conv, fully-connected --- note that LSTM merely wraps a
       // fully-connected operator).
-      return false;
+      return ::tensorflow::Status::OK();
   }
 
   const string& name = op.inputs[weights_index];
   auto& array = model->GetArray(name);
   if (!array.buffer) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
   if (array.data_type != ArrayDataType::kUint8) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
   auto& buffer_data = array.GetMutableBuffer<ArrayDataType::kUint8>().data;
 
@@ -212,7 +213,8 @@ bool EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model,
     AddMessageF("Tweaked weights values for %s", LogName(op));
   }
 
-  return changed;
+  *modified = changed;
+  return ::tensorflow::Status::OK();
 }
 
 }  // namespace toco
@@ -25,27 +25,30 @@ limitations under the License.
 
 namespace toco {
 
-bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status FuseActivationFunctions::Run(Model* model,
+                                                  std::size_t op_index,
+                                                  bool* modified) {
+  *modified = false;
   const auto ac_it = model->operators.begin() + op_index;
   const auto* ac_op = ac_it->get();
 
   if (ac_op->type != OperatorType::kRelu6 &&
       ac_op->type != OperatorType::kRelu1 &&
       ac_op->type != OperatorType::kRelu) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   // Find the op producing the array passed to this activation function
   Operator* op = GetOpWithOutput(*model, ac_op->inputs[0]);
 
-  if (!op) return false;
+  if (!op) return ::tensorflow::Status::OK();
 
   if (CountTrueOutputs(*model, *op) > 1) {
     AddMessageF(
         "Not fusing activation function %s into %s because it has more than "
         "one consumed output",
         LogName(*ac_op), LogName(*op));
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   CHECK_EQ(op->outputs[0], ac_op->inputs[0]);
@@ -57,7 +60,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
         "Not fusing activation function into %s because it is consumed by more "
         "than 1 other operator",
         LogName(*ac_op), LogName(*op));
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   if (!IsDiscardableArray(*model, op->outputs[0])) {
@@ -65,7 +68,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
         "Not fusing activation function %s into %s because output %s it is not "
         "discardable",
         LogName(*ac_op), LogName(*op), op->outputs[0]);
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   if (op->fused_activation_function != FusedActivationFunctionType::kNone) {
@@ -73,7 +76,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
         "Not fusing activation function %s into %s because it already has a "
        "fused activation function",
        LogName(*ac_op), LogName(*op));
-    return false;
+    return ::tensorflow::Status::OK();
  }
 
  if (!OperatorSupportsFusedActivation(op->type)) {
@@ -81,7 +84,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
        "Not fusing activation function %s because the %s op doesn't support "
        "it",
        LogName(*ac_op), LogName(*op));
-    return false;
+    return ::tensorflow::Status::OK();
  }
 
  AddMessageF("Fusing activation function %s into the preceding %s",
@@ -98,7 +101,8 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
   model->EraseArray(ac_op->inputs[0]);
   op->outputs[0] = ac_op->outputs[0];
   model->operators.erase(ac_it);
-  return true;
+  *modified = true;
+  return ::tensorflow::Status::OK();
 }
 
 }  // namespace toco
@@ -150,14 +150,17 @@ void FuseMulOrDivParamsIntoFollowingAffine(Model* model, Operator* following_op,
 
 }  // namespace
 
-bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status FuseBinaryIntoFollowingAffine::Run(Model* model,
+                                                        std::size_t op_index,
+                                                        bool* modified) {
+  *modified = false;
   const auto binary_it = model->operators.begin() + op_index;
   auto* binary_op = binary_it->get();
   if (binary_op->type != OperatorType::kAdd &&
       binary_op->type != OperatorType::kMul &&
       binary_op->type != OperatorType::kSub &&
       binary_op->type != OperatorType::kDiv) {
-    return false;
+    return ::tensorflow::Status::OK();
  }
 
  CHECK_EQ(binary_op->inputs.size(), 2);
@@ -175,12 +178,12 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
   };
   if (!is_input_constant[0] && !is_input_constant[1]) {
     // Neither input is constant, so nothing we can fuse into a constant.
-    return false;
+    return ::tensorflow::Status::OK();
   }
   if (is_input_constant[0] && is_input_constant[1]) {
     // Both inputs are constants. That's a job for constants
     // propagation, not for us to handle here.
-    return false;
+    return ::tensorflow::Status::OK();
   }
   const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
   const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
@@ -192,7 +195,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
     if (index_of_constant_input != 1) {
       AddMessageF("Not fusing %s because the denominator is not constant",
                   LogName(*binary_op));
-      return false;
+      return ::tensorflow::Status::OK();
     }
   }
 
@@ -204,7 +207,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
          "Not fusing %s into the following affine op, because we only know "
          "how to do so when the constant operand is a scalar",
          LogName(*binary_op));
-      return false;
+      return ::tensorflow::Status::OK();
    }
  }
 
@@ -212,7 +215,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
       FusedActivationFunctionType::kNone) {
     AddMessageF("Not fusing %s because it has a fused activation function",
                 LogName(*binary_op));
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   Operator* following_op = GetOpWithInput(*model, binary_op->outputs[0]);
@@ -221,7 +224,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
     AddMessageF(
         "Not fusing %s because it is not consumed by exactly one other op",
         LogName(*binary_op));
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   if (following_op->type != OperatorType::kConv &&
@@ -231,14 +234,14 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
         "Not fusing %s because the following %s is not of one of the supported "
        "types",
        LogName(*binary_op), LogName(*following_op));
-    return false;
+    return ::tensorflow::Status::OK();
  }
 
  if (following_op->inputs.size() < 3) {
    AddMessageF(
        "Not fusing %s because the following %s does not have a bias vector",
        LogName(*following_op), LogName(*binary_op));
-    return false;
+    return ::tensorflow::Status::OK();
  }
 
  const auto& weights = model->GetArray(following_op->inputs[1]);
@@ -248,7 +251,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
         "Not fusing %s because the following %s has non-constant weights or "
         "bias arrays",
         LogName(*binary_op), LogName(*following_op));
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   // Try to fuse the binary params into the following op's params
@@ -260,7 +263,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
       AddMessageF(
           "Not fusing %s because the following %s does not use VALID padding",
          LogName(*binary_op), LogName(*following_op));
-      return false;
+      return ::tensorflow::Status::OK();
    }
  }
  if (following_op->type == OperatorType::kDepthwiseConv) {
@@ -269,7 +272,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
      AddMessageF(
          "Not fusing %s because the following %s does not use VALID padding",
          LogName(*binary_op), LogName(*following_op));
-      return false;
+      return ::tensorflow::Status::OK();
    }
  }
  FuseAddOrSubParamsIntoFollowingAffine(model, following_op, binary_op,
@@ -294,7 +297,8 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
     model->EraseArray(old_constant_param_name);
   }
   model->operators.erase(binary_it);
-  return true;
+  *modified = true;
+  return ::tensorflow::Status::OK();
 }
 
 }  // namespace toco
@@ -188,14 +188,17 @@ void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
 }
 }  // namespace
 
-bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status FuseBinaryIntoPrecedingAffine::Run(Model* model,
+                                                        std::size_t op_index,
+                                                        bool* modified) {
+  *modified = false;
   const auto binary_it = model->operators.begin() + op_index;
   const auto* binary_op = binary_it->get();
   if (binary_op->type != OperatorType::kAdd &&
       binary_op->type != OperatorType::kMul &&
       binary_op->type != OperatorType::kSub &&
       binary_op->type != OperatorType::kDiv) {
-    return false;
+    return ::tensorflow::Status::OK();
  }
 
  CHECK_EQ(binary_op->inputs.size(), 2);
@@ -213,12 +216,12 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
   };
   if (!is_input_constant[0] && !is_input_constant[1]) {
     // Neither input is constant, so nothing we can fuse into a constant.
-    return false;
+    return ::tensorflow::Status::OK();
   }
   if (is_input_constant[0] && is_input_constant[1]) {
     // Both inputs are constants. That's a job for constants
     // propagation, not for us to handle here.
-    return false;
+    return ::tensorflow::Status::OK();
   }
   const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
   const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
@@ -230,7 +233,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
     if (index_of_constant_input != 1) {
       AddMessageF("Not fusing %s because the denominator is not constant",
                   LogName(*binary_op));
-      return false;
+      return ::tensorflow::Status::OK();
     }
   }
 
@@ -239,12 +242,12 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
   if (!preceding_op) {
     AddMessageF("Not fusing %s because it is not the output of another op",
                 LogName(*binary_op));
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   for (const string& output_array : model->flags.output_arrays()) {
     if (preceding_op->outputs[0] == output_array) {
-      return false;
+      return ::tensorflow::Status::OK();
     }
   }
 
@@ -255,7 +258,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
         "Not fusing %s because the preceding %s is not of one of the supported "
         "types",
         LogName(*binary_op), LogName(*preceding_op));
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   if (preceding_op->fused_activation_function !=
@@ -264,14 +267,14 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
         "Not fusing %s because the preceding %s has a fused activation "
        "function",
        LogName(*binary_op), LogName(*preceding_op));
-    return false;
+    return ::tensorflow::Status::OK();
  }
 
  if (preceding_op->inputs.size() < 3) {
    AddMessageF(
        "Not fusing %s because the preceding %s does not have a bias vector",
        LogName(*binary_op), LogName(*preceding_op));
-    return false;
+    return ::tensorflow::Status::OK();
  }
 
  const auto& weights_name = preceding_op->inputs[1];
@@ -289,14 +292,14 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
           "Not fusing %s because the preceding %s has a non-constant bias "
          "array",
          LogName(*binary_op), LogName(*preceding_op));
-      return false;
+      return ::tensorflow::Status::OK();
    }
    if (count_ops_consuming_bias > 1) {
      AddMessageF(
          "Not fusing %s because the bias of the preceding %s is consumed by "
          "another op",
          LogName(*binary_op), LogName(*preceding_op));
-      return false;
+      return ::tensorflow::Status::OK();
    }
  } else {
    if (!weights.buffer || !bias.buffer) {
@@ -304,14 +307,14 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
           "Not fusing %s because the preceding %s has non-constant weights or "
          "bias arrays",
          LogName(*binary_op), LogName(*preceding_op));
-      return false;
+      return ::tensorflow::Status::OK();
    }
    if (count_ops_consuming_weights > 1 || count_ops_consuming_bias > 1) {
      AddMessageF(
          "Not fusing %s because the weights or bias of the preceding %s is "
          "consumed by another op",
          LogName(*binary_op), LogName(*preceding_op));
-      return false;
+      return ::tensorflow::Status::OK();
    }
  }
 
@@ -323,7 +326,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
         "Not fusing %s because the output of the preceding %s is consumed by "
         "another op",
         LogName(*binary_op), LogName(*preceding_op));
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   AddMessageF("Fusing %s into the preceding %s", LogName(*binary_op),
@@ -352,7 +355,8 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
     model->EraseArray(old_constant_param_name);
   }
   model->operators.erase(binary_it);
-  return true;
+  *modified = true;
+  return ::tensorflow::Status::OK();
 }
 
 }  // namespace toco
@@ -51,19 +51,22 @@ bool IsBroadcastingOp(const Model& model, Operator* op) {
 // Finds an operation that looks like a broadcast (concat of the same sources
 // along the last dimension) and drops it by relying on the ability of certain
 // binary ops to perform an implicit broadcast.
-bool FuseBroadcastIntoFollowingBinary::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status FuseBroadcastIntoFollowingBinary::Run(Model* model,
+                                                           std::size_t op_index,
+                                                           bool* modified) {
+  *modified = false;
   const auto binary_it = model->operators.begin() + op_index;
   auto* binary_op = binary_it->get();
 
   // Test for binary ops of types that we know how to resolve
   if (binary_op->inputs.size() != 2) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
   if (binary_op->type != OperatorType::kAdd &&
       binary_op->type != OperatorType::kMul &&
       binary_op->type != OperatorType::kSub &&
       binary_op->type != OperatorType::kDiv) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   // NOTE: either of these ops may be nullptr if the input array is constant.
@@ -78,14 +81,14 @@ bool FuseBroadcastIntoFollowingBinary::Run(Model* model, std::size_t op_index) {
   if (!is_op_0_broadcast && !is_op_1_broadcast) {
     // Neither input is a broadcast-looking thing.
     AddMessageF("Neither input looks broadcasty");
-    return false;
+    return ::tensorflow::Status::OK();
   } else if (is_op_0_broadcast && is_op_1_broadcast) {
     AddMessageF(
         "Unable to fuse broadcast into %s as both inputs (%s, %s) are "
         "broadcasts",
         LogName(*binary_op), op[0] ? LogName(*op[0]) : "(?)",
        op[1] ? LogName(*op[1]) : "(?)");
-    return false;
+    return ::tensorflow::Status::OK();
  }
  int broadcast_index = is_op_0_broadcast ? 0 : 1;
 
@@ -96,7 +99,8 @@ bool FuseBroadcastIntoFollowingBinary::Run(Model* model, std::size_t op_index) {
   binary_op->inputs[broadcast_index] = op[broadcast_index]->inputs[0];
 
   // We leave the broadcast op in; it'll get cleaned up if it's not used later.
-  return true;
+  *modified = true;
+  return ::tensorflow::Status::OK();
 }
 
 }  // namespace toco
@@ -142,7 +142,7 @@ bool GraphTransformationsPass(int increment, Model* model,
   for (const auto& transformation : transformations) {
     CHECK(!changed_now);
     CHECK(transformation->Messages().empty());
-    changed_now = transformation->Run(model, op_index);
+    CHECK(transformation->Run(model, op_index, &changed_now).ok());
     const char* made_a_change_msg =
         changed_now ? "made a change" : "did NOT make a change";
     const int log_level =
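At the call site the pass driver still treats any non-OK status as fatal (CHECK(...ok()) above), but the new signature gives a transformation a channel for surfacing a real error instead of crashing deep inside Run(). A hypothetical sketch, assuming a transformation MyTransformation; the error path is illustrative, since no transformation in this commit returns a non-OK status yet:

    ::tensorflow::Status MyTransformation::Run(Model* model,
                                               std::size_t op_index,
                                               bool* modified) {
      *modified = false;
      auto it = model->operators.begin() + op_index;
      if (it->get()->type != OperatorType::kCast) {
        return ::tensorflow::Status::OK();  // not applicable to this op
      }
      if (/* some unsupported configuration */ false) {
        // Propagates to the caller instead of CHECK-failing here.
        return ::tensorflow::errors::Unimplemented("unsupported Cast");
      }
      // ... rewrite the op ...
      *modified = true;
      return ::tensorflow::Status::OK();
    }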
@@ -27,7 +27,8 @@ namespace toco {
 
 class GraphTransformation {
  public:
-  virtual bool Run(Model* model, std::size_t op_index) = 0;
+  virtual ::tensorflow::Status Run(Model* model, std::size_t op_index,
+                                   bool* modified) = 0;
   virtual const char* Name() const = 0;
   virtual ~GraphTransformation() {}
   // Returns the list of messages that this graph transformation
@@ -104,11 +105,12 @@ class GraphTransformationsSet {
 void RunGraphTransformations(Model* model, const string& message,
                              const GraphTransformationsSet& transformations);
 
-#define DECLARE_GRAPH_TRANSFORMATION(GTName)               \
-  class GTName : public GraphTransformation {              \
-   public:                                                 \
-    bool Run(Model* model, std::size_t op_index) override; \
-    const char* Name() const override { return #GTName; }  \
+#define DECLARE_GRAPH_TRANSFORMATION(GTName)                     \
+  class GTName : public GraphTransformation {                    \
+   public:                                                       \
+    ::tensorflow::Status Run(Model* model, std::size_t op_index, \
+                             bool* modified) override;           \
+    const char* Name() const override { return #GTName; }        \
   };
 
 // List of all graph transformations
@@ -200,7 +202,8 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveGatherAttributes)
 
 class PropagateDefaultMinMax : public GraphTransformation {
  public:
-  bool Run(Model* model, std::size_t op_index) override;
+  ::tensorflow::Status Run(Model* model, std::size_t op_index,
+                           bool* modified) override;
   const char* Name() const override { return "PropagateDefaultMinMax"; }
 
   bool has_any_ranges_defined() const { return !type_ranges_.empty(); }
@@ -218,7 +221,8 @@ class PropagateDefaultMinMax : public GraphTransformation {
 
 class RemoveTrivialReshape : public GraphTransformation {
  public:
-  bool Run(Model* model, std::size_t op_index) override;
+  ::tensorflow::Status Run(Model* model, std::size_t op_index,
+                           bool* modified) override;
   const char* Name() const override { return "RemoveTrivialReshape"; }
   bool treat_expand_dims_as_trivial() const {
     return treat_expand_dims_as_trivial_;
@@ -233,7 +237,8 @@ class RemoveTrivialReshape : public GraphTransformation {
 
 class ResolveConstantFakeQuant : public GraphTransformation {
  public:
-  bool Run(Model* model, std::size_t op_index) override;
+  ::tensorflow::Status Run(Model* model, std::size_t op_index,
+                           bool* modified) override;
   const char* Name() const override { return "ResolveConstantFakeQuant"; }
 
   // True if the num_bits should adjust the final data type.
@@ -250,7 +255,8 @@ class ResolveConstantFakeQuant : public GraphTransformation {
 
 class EnsureUint8WeightsSafeForFastInt8Kernels : public GraphTransformation {
  public:
-  bool Run(Model* model, std::size_t op_index) override;
+  ::tensorflow::Status Run(Model* model, std::size_t op_index,
+                           bool* modified) override;
   const char* Name() const override {
     return "EnsureUint8WeightsSafeForFastInt8Kernels";
   }
@@ -267,7 +273,8 @@ class EnsureUint8WeightsSafeForFastInt8Kernels : public GraphTransformation {
 
 class IdentifyDilatedConv : public GraphTransformation {
  public:
-  bool Run(Model* model, std::size_t op_index) override;
+  ::tensorflow::Status Run(Model* model, std::size_t op_index,
+                           bool* modified) override;
   const char* Name() const override { return "IdentifyDilatedConv"; }
   bool identify_depthwise_conv() const { return identify_depthwise_conv_; }
   void set_identify_depthwise_conv(bool val) { identify_depthwise_conv_ = val; }
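Because most subclass declarations are generated by DECLARE_GRAPH_TRANSFORMATION, they pick up the new signature from the macro change alone. Roughly, a declaration such as DECLARE_GRAPH_TRANSFORMATION(ConvertSqueezeToReshape) now expands to:

    class ConvertSqueezeToReshape : public GraphTransformation {
     public:
      ::tensorflow::Status Run(Model* model, std::size_t op_index,
                               bool* modified) override;
      const char* Name() const override { return "ConvertSqueezeToReshape"; }
    };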
@@ -372,7 +372,9 @@ bool HardcodeMinMaxForLstmCell(Model* model, Operator* op) {
 }
 }  // namespace
 
-bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status HardcodeMinMax::Run(Model* model, std::size_t op_index,
+                                         bool* modified) {
+  *modified = false;
   auto it = model->operators.begin() + op_index;
   auto* op = it->get();
   bool changed = false;
@@ -467,7 +469,8 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
   if (changed) {
     AddMessageF("Hardcoded min-max through %s", LogName(*op));
   }
-  return changed;
+  *modified = changed;
+  return ::tensorflow::Status::OK();
 }
 
 }  // namespace toco
@@ -168,7 +168,10 @@ bool ResolveDilatedConv(Model* model, Operator* conv_base_op, Operator* stb_op,
   return true;
 }
 
-bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status IdentifyDilatedConv::Run(Model* model,
+                                              std::size_t op_index,
+                                              bool* modified) {
+  *modified = false;
   const auto it = model->operators.begin() + op_index;
   auto* stb_op = it->get();
 
@@ -176,17 +179,17 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
   // ***************************************************************************
   // SpaceToBatch Op.
   if (stb_op->type != OperatorType::kSpaceToBatchND) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
   if (stb_op->inputs.size() != 3) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
   CHECK_EQ(stb_op->outputs.size(), 1);
   // Extract the dilation factor from Input[1] of SpaceToBatch
   // TODO(mjmatthews): Support 2D dilation factors.
   const auto& block_shape_array = model->GetArray(stb_op->inputs[1]);
   if (!block_shape_array.buffer) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
   CHECK_EQ(block_shape_array.shape().dimensions_count(), 1);
   int dilation_factor =
@@ -195,7 +198,7 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
   // Expand Op
   auto* post_stb_op = GetOpWithInput(*model, stb_op->outputs[0]);
   if (!post_stb_op) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
   bool has_expand_op = false;
   if (post_stb_op->type == OperatorType::kExpandDims) {
@@ -229,7 +232,8 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
     }
   }
 
-  return changed;
+  *modified = changed;
+  return ::tensorflow::Status::OK();
 }
 
 }  // namespace toco
@@ -39,7 +39,10 @@ std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
 }
 }  // namespace
 
-bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status IdentifyL2Normalization::Run(Model* model,
+                                                  std::size_t op_index,
+                                                  bool* modified) {
+  *modified = false;
   const auto div_it = model->operators.begin() + op_index;
   const auto* div_or_mul_op = div_it->get();
   OperatorType expected_op_type_producing_div_or_mul_input;
@@ -48,7 +51,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
   } else if (div_or_mul_op->type == OperatorType::kMul) {
     expected_op_type_producing_div_or_mul_input = OperatorType::kRsqrt;
   } else {
-    return false;
+    return ::tensorflow::Status::OK();
   }
   CHECK_EQ(div_or_mul_op->inputs.size(), 2);
   Operator* op_producing_div_or_mul_input[2] = {
@@ -58,14 +61,14 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
   if (!op_producing_div_or_mul_input[1] ||
       op_producing_div_or_mul_input[1]->type !=
           expected_op_type_producing_div_or_mul_input) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
   Operator* sqrt_or_rsqrt_op = op_producing_div_or_mul_input[1];
   CHECK_EQ(sqrt_or_rsqrt_op->inputs.size(), 1);
   Operator* op_producing_sqrt_or_rsqrt_input =
       GetOpWithOutput(*model, sqrt_or_rsqrt_op->inputs[0]);
   if (!op_producing_sqrt_or_rsqrt_input) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   // There may be an Add or a Maximum here, adding or clamping to a "small"
@@ -105,7 +108,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
           " because the operator producing the input to the square root, %s,"
          ", does not match the expected pattern",
          LogName(*op_producing_sqrt_or_rsqrt_input));
-      return false;
+      return ::tensorflow::Status::OK();
    }
  }
 
@@ -116,7 +119,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
         "Giving up trying to identify L2Normalization subgraph: "
         "expected Sum op, got %s",
         LogName(*sum_op));
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   Operator* square_op = GetOpWithOutput(*model, sum_op->inputs[0]);
@@ -125,7 +128,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
         "Giving up trying to identify L2Normalization subgraph: "
         "expected Square op, got %s",
         LogName(*square_op));
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   CHECK_EQ(square_op->inputs.size(), 1);
@@ -135,7 +138,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
         "Giving up trying to identify L2Normalization subgraph: %s does not "
         "take the same input as the Mul/Div node",
         LogName(*square_op));
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   // Create and emplace the new L2Normalization
@@ -162,7 +165,8 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
   model->operators.erase(FindOperator(model, sqrt_or_rsqrt_op));
   model->EraseArray(div_or_mul_op->inputs[1]);
   model->operators.erase(FindOperator(model, div_or_mul_op));
-  return true;
+  *modified = true;
+  return ::tensorflow::Status::OK();
 }
 
 }  // namespace toco
@@ -38,11 +38,13 @@ std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
 }
 }  // namespace
 
-bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status IdentifyL2Pool::Run(Model* model, std::size_t op_index,
+                                         bool* modified) {
+  *modified = false;
   const auto sqrt_it = model->operators.begin() + op_index;
   const auto* sqrt_op = sqrt_it->get();
   if (sqrt_op->type != OperatorType::kSqrt) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   CHECK_EQ(sqrt_op->inputs.size(), 1);
@@ -56,7 +58,7 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) {
     AddMessageF(
         "Giving up trying to identify L2Pool subgraph: "
         "expected AveragePool op, but Sqrt op has no preceding op");
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   if (prev_to_sqrt_op->type != OperatorType::kAveragePool) {
@@ -64,7 +66,7 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) {
         "Giving up trying to identify L2Pool subgraph: "
         "expected AveragePool op, got %s",
         LogName(*prev_to_sqrt_op));
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   avpool_op = static_cast<const AveragePoolOperator*>(prev_to_sqrt_op);
@@ -77,7 +79,7 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) {
         "Giving up trying to identify L2Pool subgraph: "
         "expected Square op, got %s",
         LogName(*square_op));
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   // Create and emplace L2Pool node.
@@ -107,7 +109,8 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) {
   model->operators.erase(FindOperator(model, avpool_op));
   model->operators.erase(FindOperator(model, sqrt_op));
 
-  return true;
+  *modified = true;
+  return ::tensorflow::Status::OK();
 }
 
 }  // namespace toco
@ -132,7 +132,9 @@ bool MatchOperatorInputs(const Operator& op, const Model& model,
|
||||
|
||||
} // namespace
|
||||
|
||||
bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
|
||||
::tensorflow::Status IdentifyLstmCell::Run(Model* model, std::size_t op_index,
|
||||
bool* modified) {
|
||||
*modified = false;
|
||||
// This LSTM cell identification method is not invariant to commutation of
|
||||
// commutative operator inputs. For example, if input[0] and input[1] of the
|
||||
// final output multiplication were swapped, this method would not identify it
|
||||
@ -143,13 +145,13 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
auto op_it = model->operators.begin() + op_index;
Operator* final_output_mul = op_it->get();
if (final_output_mul->type != OperatorType::kMul) {
return false;
return ::tensorflow::Status::OK();
}
Operator *state_output_tanh, *fc_output_sig;
if (!MatchOperatorInputs(*final_output_mul, *model, OperatorType::kTanh,
&state_output_tanh, OperatorType::kLogistic,
&fc_output_sig)) {
return false;
return ::tensorflow::Status::OK();
}

// State output TanH
@ -158,7 +160,7 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
Operator* state_combine_add;
if (!MatchOperatorInputs(*state_output_tanh, *model, OperatorType::kAdd,
&state_combine_add)) {
return false;
return ::tensorflow::Status::OK();
}

// State forget & remember addition
@ -166,7 +168,7 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
if (!MatchOperatorInputs(*state_combine_add, *model, OperatorType::kMul,
&state_forget_mul, OperatorType::kMul,
&state_remember_mul)) {
return false;
return ::tensorflow::Status::OK();
}
const string prev_state = state_forget_mul->inputs[0];

@ -175,7 +177,7 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
if (!MatchOperatorInputs(*state_forget_mul, *model, OperatorType::kNone,
nullptr, OperatorType::kLogistic,
&state_forget_sig)) {
return false;
return ::tensorflow::Status::OK();
}

// State remember gate
@ -183,40 +185,40 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
if (!MatchOperatorInputs(*state_remember_mul, *model, OperatorType::kLogistic,
&state_remember_sig, OperatorType::kTanh,
&state_info_tanh)) {
return false;
return ::tensorflow::Status::OK();
}

// State remember "information" activation function
Operator* fc_output_split;
if (!MatchOperatorInputs(*state_info_tanh, *model, OperatorType::kSplit,
&fc_output_split)) {
return false;
return ::tensorflow::Status::OK();
}
// State remember gate activation function
Operator* tmp;
if (!MatchOperatorInputs(*state_remember_sig, *model, OperatorType::kSplit,
&tmp) ||
(tmp != fc_output_split)) {
return false;
return ::tensorflow::Status::OK();
}
// State forget gate activation function
if (!MatchOperatorInputs(*state_forget_sig, *model, OperatorType::kSplit,
&tmp) ||
(tmp != fc_output_split)) {
return false;
return ::tensorflow::Status::OK();
}
// Fully connected output activation function
if (!MatchOperatorInputs(*fc_output_sig, *model, OperatorType::kSplit,
&tmp) ||
(tmp != fc_output_split)) {
return false;
return ::tensorflow::Status::OK();
}
// Fully connected output split
Operator* fully_connected;
if (!MatchOperatorInputs(*fc_output_split, *model, OperatorType::kNone,
nullptr, OperatorType::kFullyConnected,
&fully_connected)) {
return false;
return ::tensorflow::Status::OK();
}

// Fully connected op
@ -225,13 +227,13 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
OperatorType::kConcatenation, &concat_inputs,
OperatorType::kNone, nullptr, OperatorType::kNone,
nullptr)) {
return false;
return ::tensorflow::Status::OK();
}

if (static_cast<FullyConnectedOperator*>(fully_connected)->weights_format !=
FullyConnectedWeightsFormat::kDefault) {
// Not yet implemented: experimental shuffled weights in fused LSTM cell.
return false;
return ::tensorflow::Status::OK();
}

// Emplace a new LSTM cell operator
@ -300,7 +302,8 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
model->operators.erase(FindOperator(model, *fully_connected));
DeleteArrayIfUnused(concat_inputs->outputs[0], model);
model->operators.erase(FindOperator(model, *concat_inputs));
return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

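The hunks above all apply one mechanical recipe: `bool Run(Model*, std::size_t)` becomes `::tensorflow::Status Run(Model*, std::size_t, bool*)`, each early `return false;` ("yield, nothing to do here") becomes `return ::tensorflow::Status::OK();` with `*modified` left false, and each final `return true;` becomes `*modified = true; return ::tensorflow::Status::OK();`. A rough self-contained sketch of the new contract follows; the types below are simplified stand-ins, not the real toco/TensorFlow classes:

#include <cstddef>
#include <string>
#include <vector>

// Stand-in for ::tensorflow::Status: an empty message means success.
class Status {
 public:
  static Status OK() { return Status(""); }
  static Status Error(const std::string& msg) { return Status(msg); }
  bool ok() const { return message_.empty(); }
 private:
  explicit Status(const std::string& msg) : message_(msg) {}
  std::string message_;
};

struct Operator { int type = 0; };
struct Model { std::vector<Operator> operators; };

// Old style:  bool Run(Model*, std::size_t)  -- conflated "changed" and "ok".
// New style:  errors travel in the Status, progress in *modified.
Status RunExampleTransformation(Model* model, std::size_t op_index,
                                bool* modified) {
  *modified = false;
  if (op_index >= model->operators.size()) {
    return Status::Error("op_index out of range");  // A genuine failure.
  }
  if (model->operators[op_index].type != 1) {
    return Status::OK();  // Not our op: no error, *modified stays false.
  }
  model->operators[op_index].type = 2;  // Pretend to rewrite the graph.
  *modified = true;                     // Report the change separately.
  return Status::OK();
}
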
@ -25,19 +25,22 @@ limitations under the License.

namespace toco {

bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) {
::tensorflow::Status MergeLstmCellInputs::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
// Find lstm cell.
auto op_it = model->operators.begin() + op_index;
auto src_op = op_it->get();
if (src_op->type != OperatorType::kLstmCell) {
return false;
return ::tensorflow::Status::OK();
}

// Already a compact LstmCell. Do not need to merge cell inputs.
const auto* src_lstm_op = static_cast<LstmCellOperator*>(src_op);
if (src_lstm_op->kernel_type != LstmCellOperator::KERNEL_FULL ||
src_lstm_op->inputs.size() != kExtendedLstmInputCount) {
return false;
return ::tensorflow::Status::OK();
}

// Identify prev_activ_input, prev_state_input as required Op inputs,
@ -45,12 +48,12 @@ bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) {
string prev_activ_input;
if (!GetMatchingRnnArray(model, src_op->outputs[kOutputTensor],
&prev_activ_input)) {
return false;
return ::tensorflow::Status::OK();
}
string prev_state_input;
if (!GetMatchingRnnArray(model, src_op->outputs[kCellStateTensor],
&prev_state_input)) {
return false;
return ::tensorflow::Status::OK();
}

// Get LstmCell's cell, input, output size.
@ -184,7 +187,8 @@ bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) {
DeleteArrayIfUnused(src_op->inputs[kOutputGateBiasTensor], model);
model->operators.erase(FindOp(*model, src_op));

return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -25,19 +25,22 @@ limitations under the License.

namespace toco {

bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) {
::tensorflow::Status SplitLstmCellInputs::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
// Find lstm cell.
auto op_it = model->operators.begin() + op_index;
auto curr_op = op_it->get();
if (curr_op->type != OperatorType::kLstmCell) {
return false;
return ::tensorflow::Status::OK();
}

const auto* curr_lstm_op = static_cast<LstmCellOperator*>(curr_op);
// Already an extended LstmCell. Do not need to split cell inputs.
if (curr_lstm_op->kernel_type != LstmCellOperator::KERNEL_BASIC ||
curr_lstm_op->inputs.size() != LstmCellOperator::NUM_INPUTS) {
return false;
return ::tensorflow::Status::OK();
}

// Make sure the WEIGHTS_INPUT and BIASES_INPUT are constant arrays,
@ -46,13 +49,13 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) {
*model, curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT]) ||
!IsConstantParameterArray(
*model, curr_op->inputs[LstmCellOperator::BIASES_INPUT])) {
return false;
return ::tensorflow::Status::OK();
}

// Make sure propagate_fixed_sizes has defined the size of the output.
if (!model->GetArray(curr_op->outputs[LstmCellOperator::ACTIV_OUTPUT])
.has_shape()) {
return false;
return ::tensorflow::Status::OK();
}

// Emplace a new LstmCell operator with extended inputs (kernel/lstm.cc).
@ -168,7 +171,8 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) {
DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::BIASES_INPUT], model);
model->operators.erase(FindOp(*model, curr_op));

return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -43,13 +43,15 @@ limitations under the License.

namespace toco {

bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
::tensorflow::Status IdentifyPRelu::Run(Model* model, std::size_t op_index,
bool* modified) {
*modified = false;
const auto add_op_it = model->operators.begin() + op_index;
const auto* add_op = add_op_it->get();
if (add_op == nullptr || add_op->type != OperatorType::kAdd ||
add_op->inputs.size() != 2 ||
add_op->fused_activation_function != FusedActivationFunctionType::kNone) {
return false;
return ::tensorflow::Status::OK();
}

const auto* relu_input_op = GetOpWithOutput(*model, add_op->inputs[0]);
@ -57,7 +59,7 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
relu_input_op->inputs.size() != 1 ||
relu_input_op->fused_activation_function !=
FusedActivationFunctionType::kNone) {
return false;
return ::tensorflow::Status::OK();
}

// TODO(ycling): Both Add and Mul are commutative. Support the case where
@ -66,7 +68,7 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
if (mul_op == nullptr || mul_op->type != OperatorType::kMul ||
mul_op->inputs.size() != 2 ||
mul_op->fused_activation_function != FusedActivationFunctionType::kNone) {
return false;
return ::tensorflow::Status::OK();
}

const auto neg_alpha_tensor_name = mul_op->inputs[0];
@ -75,7 +77,7 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {

if (relu_neg_input_op == nullptr ||
relu_neg_input_op->inputs.size() != 1) {
return false;
return ::tensorflow::Status::OK();
}

const Operator* final_input_op;
@ -92,13 +94,13 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
relu_neg_input_op->type != OperatorType::kRelu ||
relu_neg_input_op->fused_activation_function !=
FusedActivationFunctionType::kNone) {
return false;
return ::tensorflow::Status::OK();
}
final_input_op = neg_input_op;
}

if (relu_input_op->inputs[0] != final_input_op->inputs[0]) {
return false;
return ::tensorflow::Status::OK();
}

const auto input_tensor_name = relu_input_op->inputs[0];
@ -128,7 +130,8 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
// intermediate tensors aren't used by other ops, those will be removed by
// other graph transformation rules.
model->operators.erase(FindOp(*model, add_op));
return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -56,13 +56,15 @@ int GetSingleScalarInputIndexOfBinaryOp(Model* model, const Operator* op,
}
} // namespace

bool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
::tensorflow::Status IdentifyRelu1::Run(Model* model, std::size_t op_index,
bool* modified) {
*modified = false;
// Follow sequences of min+max and max+min. First get the leading op.
const auto op_it = model->operators.begin() + op_index;
const auto* op_0 = op_it->get();
if (op_0->type != OperatorType::kMinimum &&
op_0->type != OperatorType::kMaximum) {
return false;
return ::tensorflow::Status::OK();
}

// Get the paired op and ensure it's the counter to the first.
@ -71,17 +73,17 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
(op_1->type != OperatorType::kMinimum &&
op_1->type != OperatorType::kMaximum) ||
op_0->type == op_1->type) {
return false;
return ::tensorflow::Status::OK();
}

const auto* min_op = op_0->type == OperatorType::kMinimum ? op_0 : op_1;
const auto* max_op = op_0->type == OperatorType::kMaximum ? op_0 : op_1;

if (min_op->inputs.size() != 2 || max_op->inputs.size() != 2) {
return false;
return ::tensorflow::Status::OK();
}
if (min_op->outputs.size() != 1 || max_op->outputs.size() != 1) {
return false;
return ::tensorflow::Status::OK();
}

// Get the original input to the min+max pair.
@ -90,7 +92,7 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
int max_scalar_input_index =
GetSingleScalarInputIndexOfBinaryOp(model, max_op, -1.0f);
if (min_scalar_input_index == -1 || max_scalar_input_index == -1) {
return false;
return ::tensorflow::Status::OK();
}
int op_0_scalar_input_index =
op_0 == min_op ? min_scalar_input_index : max_scalar_input_index;
@ -111,7 +113,8 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
model->operators.erase(FindOperator(model, op_0));
model->operators.erase(FindOperator(model, op_1));

return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -97,7 +97,10 @@ bool AddDequantizeOperatorToInput(const string& input_name, const Operator* op,
return true;
}

bool MakeInitialDequantizeOperator::Run(Model* model, std::size_t op_index) {
::tensorflow::Status MakeInitialDequantizeOperator::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
// This is effectively a transformation applied to edges. We iterate over the
// specified node (op) and proceed for input edges.
const auto it = model->operators.begin() + op_index;
@ -114,7 +117,8 @@ bool MakeInitialDequantizeOperator::Run(Model* model, std::size_t op_index) {
}
}
}
return change_made;
*modified = change_made;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -102,18 +102,19 @@ std::vector<int32> ReshapeToTranspose(const Model& model,
// to be merged if the reshape does not affect memory ordering and does not
// affect the number of dimensions. This only occurs when only unary dimensions
// are shifting position.
bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
std::size_t op_index) {
::tensorflow::Status MergeReshapeIntoPrecedingTranspose::Run(
Model* model, std::size_t op_index, bool* modified) {
*modified = false;
auto it = model->operators.begin() + op_index;
auto* reshape_op = ConvertOperator<TensorFlowReshapeOperator*>(
it->get(), OperatorType::kReshape);

if (reshape_op == nullptr) {
return false;
return ::tensorflow::Status::OK();
}

if (!OperatorReady(*model, reshape_op) || reshape_op->shape.empty()) {
return false;
return ::tensorflow::Status::OK();
}

const string intermediate_name = reshape_op->inputs[0];
@ -121,13 +122,13 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,

// Guarantee the input is only consumed by the reshape.
if (CountOpsWithInput(*model, intermediate_name) != 1) {
return false;
return ::tensorflow::Status::OK();
}

// Check for the parent operator.
const auto& transpose_it = FindOpWithOutput(*model, intermediate_name);
if (transpose_it == model->operators.end()) {
return false;
return ::tensorflow::Status::OK();
}

// Find the parent operator and guarantee it is a transpose.
@ -135,16 +136,16 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
transpose_it->get(), OperatorType::kTranspose);

if (transpose_op == nullptr) {
return false;
return ::tensorflow::Status::OK();
}

if (!OperatorReady(*model, transpose_op) || transpose_op->perm.empty()) {
return false;
return ::tensorflow::Status::OK();
}

if (!ReshapeIsEquivalentToTranspose(*model, reshape_op,
false /*allow_extra_unary_dimensions*/)) {
return false;
return ::tensorflow::Status::OK();
}

// Check that the intermediate is not an output array.
@ -153,7 +154,7 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
"Cannot fuse %s and %s as it would invalidate the transpose "
"output array.",
LogName(*transpose_op), LogName(*reshape_op));
return false;
return ::tensorflow::Status::OK();
}

AddMessageF("Merging operations %s and %s", LogName(*transpose_op),
@ -172,7 +173,7 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,

// Remove the reshape as passthrough operation.
if (!RemoveTrivialPassthroughOp(this, model, op_index)) {
return false;
return ::tensorflow::Status::OK();
}

// Update transpose_op's constant buffer to contain the new permutation.
@ -184,7 +185,8 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
// transpose_op's shape will likely have changed.
model->GetArray(transpose_op->outputs[0]).clear_shape();

return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -54,7 +54,10 @@ bool IsTailOfShape(const Shape& tail, const Shape& shape) {
//
// Note we are testing for one particular case of a broader set of possible
// binary-reshape op transformations. This transformation could be generalized.
bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
::tensorflow::Status MoveBinaryOperatorBeforeReshape::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto binary_it = model->operators.begin() + op_index;
Operator* binary_op = binary_it->get();
if (binary_op->type != OperatorType::kAdd &&
@ -69,7 +72,7 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
binary_op->type != OperatorType::kLessEqual &&
binary_op->type != OperatorType::kGreater &&
binary_op->type != OperatorType::kGreaterEqual) {
return false;
return ::tensorflow::Status::OK();
}

// BINARY OP INPUT CHECKS
@ -81,11 +84,11 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
if (!input_is_const[0] && !input_is_const[1]) {
// To limit our scope, we require one constant input. Though there's no
// reason this transformation wouldn't work with all variable inputs.
return false;
return ::tensorflow::Status::OK();
}
if (input_is_const[0] && input_is_const[1]) {
// Both inputs are constants. Leave this for constants propagation.
return false;
return ::tensorflow::Status::OK();
}
const int constant_input_idx = input_is_const[0] ? 0 : 1;
const int variable_input_idx = input_is_const[0] ? 1 : 0;
@ -98,13 +101,13 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
AddMessageF(
"Not moving %s because its non-constant input shape is not resolved.",
LogName(*binary_op));
return false;
return ::tensorflow::Status::OK();
}
if (!IsTailOfShape(
model->GetArray(binary_op->inputs[constant_input_idx]).shape(),
model->GetArray(binary_op->inputs[variable_input_idx]).shape())) {
// Constant array shape must be the latter part of the variable shape.
return false;
return ::tensorflow::Status::OK();
}

// RESHAPE OP CHECKS
@ -113,13 +116,13 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
if (reshape_it == model->operators.end()) {
AddMessageF("Not moving %s because its variable input is not connected.",
LogName(*binary_op));
return false;
return ::tensorflow::Status::OK();
}
Operator* reshape_op = reshape_it->get();
if (reshape_op->type != OperatorType::kReshape) {
AddMessageF("Not moving %s because the preceding %s is not a reshape op",
LogName(*binary_op), LogName(*reshape_op));
return false;
return ::tensorflow::Status::OK();
}
const auto& reshape_input_array = model->GetArray(reshape_op->inputs[0]);
if (!reshape_input_array.has_shape()) {
@ -127,14 +130,14 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
"Not moving %s because its non-constant input shape is not resolved "
"yet",
LogName(*binary_op));
return false;
return ::tensorflow::Status::OK();
}
if (!IsTailOfShape(
model->GetArray(binary_op->inputs[constant_input_idx]).shape(),
model->GetArray(reshape_op->outputs[0]).shape())) {
// Constant array shape must be the latter part of the binary op output
// shape.
return false;
return ::tensorflow::Status::OK();
}

// EXTRA CHECKS ON CONNECTING ARRAY
@ -143,7 +146,7 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
AddMessageF(
"Not moving %s because the output of reshape op %s is an output op.",
LogName(*binary_op), LogName(*reshape_op));
return false;
return ::tensorflow::Status::OK();
}
}
int count_ops_consuming_output =
@ -154,7 +157,7 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
"Not moving %s because the output of reshape op %s is consumed by "
"another op",
LogName(*binary_op), LogName(*reshape_op));
return false;
return ::tensorflow::Status::OK();
}

// SWAP ORDER OF BINARY AND RESHAPE OPS
@ -172,7 +175,8 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
// Clear binary output shape so it will be re-propagated
model->GetArray(binary_op->outputs[0]).clear_shape();

return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -26,20 +26,21 @@ limitations under the License.

namespace toco {

bool PropagateActivationFunctionIntoConstants::Run(Model* model,
std::size_t op_index) {
::tensorflow::Status PropagateActivationFunctionIntoConstants::Run(
Model* model, std::size_t op_index, bool* modified) {
*modified = false;
const auto ac_it = model->operators.begin() + op_index;
const auto* ac_op = ac_it->get();
if (ac_op->type != OperatorType::kRelu6 &&
ac_op->type != OperatorType::kRelu1 &&
ac_op->type != OperatorType::kRelu) {
return false;
return ::tensorflow::Status::OK();
}

// Find the op producing the array passed to this activation function.
auto* src_op = GetOpWithOutput(*model, ac_op->inputs[0]);
if (!src_op) {
return false;
return ::tensorflow::Status::OK();
}

// Ensure the src_op is not used without the activation function applied.
@ -57,7 +58,7 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model,
src_op_input = src_op->inputs[0];
break;
default:
return false;
return ::tensorflow::Status::OK();
}
CHECK_EQ(src_op->outputs[0], ac_op->inputs[0]);

@ -69,7 +70,7 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model,
"Not propagating activation function %s into %s:%s because it is not "
"constant",
LogName(*ac_op), LogName(*src_op), src_op_input);
return false;
return ::tensorflow::Status::OK();
}

// Get the array we'll be working with and ensure it's a compatible type.
@ -79,7 +80,7 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model,
"Not propagating activation function %s into %s:%s because it is "
"non-float data",
LogName(*ac_op), LogName(*src_op), src_op_input);
return false;
return ::tensorflow::Status::OK();
}
auto& const_array_data =
const_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
@ -108,14 +109,15 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model,
}
default:
LOG(FATAL) << "Unsupported activation function " << LogName(*ac_op);
return false;
return ::tensorflow::Status::OK();
}
const_array_data[i] = new_value;
}

AddMessageF("Propagated activation function %s into %s:%s", LogName(*ac_op),
LogName(*src_op), src_op_input);
return RemoveTrivialPassthroughOp(this, model, op_index);
*modified = RemoveTrivialPassthroughOp(this, model, op_index);
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -32,7 +32,10 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op,
}
} // namespace

bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status PropagateArrayDataTypes::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto it = model->operators.begin() + op_index;
auto* op = it->get();

@ -40,7 +43,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
for (const auto& input : op->inputs) {
if (!model->IsOptionalArray(input) &&
model->GetArray(input).data_type == ArrayDataType::kNone) {
return false;
return ::tensorflow::Status::OK();
}
}
// Record data types of output before processing, so we can see at the
@ -131,7 +134,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
auto* rand_op = static_cast<RandomUniformOperator*>(op);
// The output type of RandomUniform is specified with an attribute
if (rand_op->dtype == ArrayDataType::kNone) {
return false;
return ::tensorflow::Status::OK();
}
CHECK_EQ(op->outputs.size(), 1);
SetDataTypeForAllOutputs(model, op, rand_op->dtype);
@ -153,7 +156,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
// This can make unsupported_op->output_data_types have more elements than
// op->outputs.
if (unsupported_op->output_data_types.size() < op->outputs.size()) {
return false;
return ::tensorflow::Status::OK();
}
for (int i = 0; i < op->outputs.size(); ++i) {
const string& output = op->outputs[i];
@ -164,7 +167,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
}
case OperatorType::kExpandDims: {
// Yield on ExpandDim until it is converted to Reshape
return false;
return ::tensorflow::Status::OK();
}
case OperatorType::kSelect: {
// Select produces outputs with the same type as their 2nd input
@ -248,10 +251,11 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
// Return true if any output data type changed, false if none changed.
for (const auto& output : op->outputs) {
if (old_output_data_types[output] != model->GetArray(output).data_type) {
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
}
return false;
return ::tensorflow::Status::OK();
}

} // namespace toco

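PropagateArrayDataTypes above shows why the out-parameter matters: a pass may legitimately run and change nothing. With the new signature a caller can separate "no match" (OK with *modified false) from a hard failure. A hypothetical driver loop under the same simplified stand-in types as the earlier sketch (this is not the actual toco pass manager):

#include <cstddef>
#include <functional>
#include <vector>

struct Model { std::size_t num_ops = 0; };
struct Status {
  bool ok_ = true;
  bool ok() const { return ok_; }
};
using Transformation = std::function<Status(Model*, std::size_t, bool*)>;

// Repeatedly apply every pass to every op until a full sweep changes nothing.
Status RunUntilFixedPoint(Model* model,
                          const std::vector<Transformation>& passes) {
  bool any_changed = true;
  while (any_changed) {
    any_changed = false;
    for (std::size_t i = 0; i < model->num_ops; ++i) {
      for (const Transformation& pass : passes) {
        bool modified = false;
        Status s = pass(model, i, &modified);
        if (!s.ok()) return s;    // A real error aborts the whole run.
        any_changed |= modified;  // OK + !modified just means "no match".
      }
    }
  }
  return Status{};
}
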
@ -39,7 +39,10 @@ bool SupportsMinMax(const Array& array) {
// When provided a set of min/max values for uint8 arrays this will rescale
// the values for other data types as required, preserving the floating point
// range within the new type.
bool PropagateDefaultMinMax::Run(Model* model, std::size_t op_index) {
::tensorflow::Status PropagateDefaultMinMax::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto it = model->operators.begin() + op_index;
const auto* op = it->get();

@ -61,7 +64,8 @@ bool PropagateDefaultMinMax::Run(Model* model, std::size_t op_index) {
}
}

return did_change;
*modified = did_change;
return ::tensorflow::Status::OK();
}

// Sets the min/max on the given array, adjusting the reference_minmax for the

@ -277,11 +277,14 @@ bool RecursivelyForwardPropagateDataType(GraphTransformation* transformation,
// nice logging and integration with the graphviz video dumping mode.
// In general you should not copy this style of transformation; stick to
// local-only changes as seen in the other transformations.
bool PropagateFakeQuantNumBits::Run(Model* model, std::size_t op_index) {
::tensorflow::Status PropagateFakeQuantNumBits::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto it = model->operators.begin() + op_index;
auto* op = it->get();
if (op->type != OperatorType::kFakeQuant) {
return false;
return ::tensorflow::Status::OK();
}
auto* fakequant_op = static_cast<FakeQuantOperator*>(op);

@ -290,7 +293,7 @@ bool PropagateFakeQuantNumBits::Run(Model* model, std::size_t op_index) {
&quantized_data_type)) {
AddMessageF("FakeQuant op %s num_bits=%d is out of range, ignoring",
LogName(*op), fakequant_op->num_bits);
return false;
return ::tensorflow::Status::OK();
}
const auto& final_minmax = *fakequant_op->minmax;

@ -311,7 +314,8 @@ bool PropagateFakeQuantNumBits::Run(Model* model, std::size_t op_index) {
did_change |=
RecursivelyForwardPropagateDataType(this, model, op, quantized_data_type);

return did_change;
*modified = did_change;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -1622,7 +1622,10 @@ void ProcessUnpackOperator(Model* model, UnpackOperator* op) {

} // namespace

bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status PropagateFixedSizes::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto it = model->operators.begin() + op_index;
auto* op = it->get();
std::unordered_map<string, std::vector<int>> old_output_dims;
@ -1836,7 +1839,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
static_cast<TensorFlowUnsupportedOperator*>(op);
// The attribute may be unspecified; if so, ignore it.
if (unsupported_op->output_shapes.size() < op->outputs.size()) {
return false;
return ::tensorflow::Status::OK();
}
for (int i = 0; i < op->outputs.size(); ++i) {
const string& output = op->outputs[i];
@ -1886,10 +1889,11 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
(old_output_dims[output] != model->GetArray(output).shape().dims())) {
AddMessageF("Set shape of %s to [%s]", output,
absl::StrJoin(model->GetArray(output).shape().dims(), ","));
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
}
return false;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -439,7 +439,9 @@ void FixMinMaxPostQuantization(GraphTransformation* transformation,

} // namespace

bool Quantize::Run(Model* model, std::size_t op_index) {
::tensorflow::Status Quantize::Run(Model* model, std::size_t op_index,
bool* modified) {
*modified = false;
// Our general "quantization" graph transformation consists in replacing
// QuantizedInputArrays[] ->
// DequantizeOperators[] ->
@ -460,7 +462,7 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
auto& op = *model->operators[op_index];
if (op.type == OperatorType::kDequantize ||
op.type == OperatorType::kFakeQuant) {
return false;
return ::tensorflow::Status::OK();
}

// Our assumption here is that the input arrays are already quantized -
@ -497,7 +499,7 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
if (!array.minmax && !array.buffer) {
LOG(ERROR) << "Can't quantize input array " << input
<< " because it lacks min/max info";
return false;
return ::tensorflow::Status::OK();
}
const auto* other_op = GetOpWithOutput(*model, input);
if (other_op && other_op->type != OperatorType::kDequantize) {
@ -507,7 +509,7 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
"which means that we should yield and let other ops "
"get quantized first",
LogName(op), input);
return false;
return ::tensorflow::Status::OK();
}
}
}
@ -672,7 +674,8 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
}
}

return changed;
*modified = changed;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -51,18 +51,19 @@ bool ApplyAttrsToArray(GraphTransformation* transformation, Model* model,

} // end namespace

bool ReadArrayMinmaxAndNarrowRangeFromFakeQuant::Run(Model* model,
std::size_t op_index) {
::tensorflow::Status ReadArrayMinmaxAndNarrowRangeFromFakeQuant::Run(
Model* model, std::size_t op_index, bool* modified) {
*modified = false;
const auto fakequant_it = model->operators.begin() + op_index;
auto* fakequant_base_op = fakequant_it->get();
if (fakequant_base_op->type != OperatorType::kFakeQuant) {
return false;
return ::tensorflow::Status::OK();
}
auto* fq_op = static_cast<FakeQuantOperator*>(fakequant_base_op);

if (!fq_op->minmax) {
// Needs to be resolved first by ResolveFakeQuantArgsFromVars.
return false;
return ::tensorflow::Status::OK();
}

// At this point, this FakeQuantOperator should have a MinMax
@ -74,7 +75,8 @@ bool ReadArrayMinmaxAndNarrowRangeFromFakeQuant::Run(Model* model,
bool changed = false;
changed |= ApplyAttrsToArray(this, model, *fq_op, fq_op->inputs[0]);
changed |= ApplyAttrsToArray(this, model, *fq_op, fq_op->outputs[0]);
return changed;
*modified = changed;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -25,11 +25,14 @@ limitations under the License.

namespace toco {

bool RemoveFinalDequantizeOp::Run(Model* model, std::size_t op_index) {
::tensorflow::Status RemoveFinalDequantizeOp::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto dequantize_it = model->operators.begin() + op_index;
const auto* dequantize_op = dequantize_it->get();
if (dequantize_op->type != OperatorType::kDequantize) {
return false;
return ::tensorflow::Status::OK();
}
const auto& output = dequantize_op->outputs[0];
// We can remove any dequantize op whose output is not consumed by
@ -38,7 +41,7 @@ bool RemoveFinalDequantizeOp::Run(Model* model, std::size_t op_index) {
// in the middle of the graph might be designated as an output
// array.
if (CountOpsWithInput(*model, output)) {
return false;
return ::tensorflow::Status::OK();
}

// If one of the model's output arrays was actually the Dequantize op's
@ -53,7 +56,8 @@ bool RemoveFinalDequantizeOp::Run(Model* model, std::size_t op_index) {
AddMessageF("Removed final %s", LogName(*dequantize_op));
model->EraseArray(output);
model->operators.erase(dequantize_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -23,11 +23,14 @@ limitations under the License.

namespace toco {

bool RemoveTensorFlowAssert::Run(Model* model, std::size_t op_index) {
::tensorflow::Status RemoveTensorFlowAssert::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto assert_it = model->operators.begin() + op_index;
const auto* assert_op = assert_it->get();
if (assert_op->type != OperatorType::kAssert) {
return false;
return ::tensorflow::Status::OK();
}

bool changed = false;
@ -54,7 +57,8 @@ bool RemoveTensorFlowAssert::Run(Model* model, std::size_t op_index) {

// That's it. We can stop here, no need to duplicate the work that
// RemoveUnusedOp will do removing this now-unused node.
return changed;
*modified = changed;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -25,14 +25,18 @@ limitations under the License.

namespace toco {

bool RemoveTensorFlowIdentity::Run(Model* model, std::size_t op_index) {
::tensorflow::Status RemoveTensorFlowIdentity::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto passthru_it = model->operators.begin() + op_index;
const auto* passthru_op = passthru_it->get();
if (passthru_op->type != OperatorType::kIdentity) {
return false;
return ::tensorflow::Status::OK();
}

return RemoveTrivialPassthroughOp(this, model, op_index);
*modified = RemoveTrivialPassthroughOp(this, model, op_index);
return ::tensorflow::Status::OK();
}

} // namespace toco

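RemoveTensorFlowIdentity above shows the third recurring shape in this commit: helpers such as RemoveTrivialPassthroughOp still return a bool meaning "did I change the graph?", so the transformation forwards that bool into *modified and keeps the Status channel for errors only. In the same stand-in types used in the earlier sketches, that pattern reduces to:

#include <cstddef>

struct Model {};
struct Status { static Status OK() { return Status{}; } };

// Stand-in for a bool-returning helper like RemoveTrivialPassthroughOp.
bool RemovePassthroughOp(Model* /*model*/, std::size_t /*op_index*/) {
  return true;  // Pretend the trivial op was removed.
}

Status RunIdentityRemoval(Model* model, std::size_t op_index, bool* modified) {
  *modified = RemovePassthroughOp(model, op_index);
  return Status::OK();  // Success either way; *modified says if work happened.
}
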
@ -46,14 +46,17 @@ bool AreAllBufferElementsEqualTo(const std::vector<Scalar>& buffer_data,
// For example, an Add operator is trivial if
// one of its operands is constant 0, a Mul operator is trivial
// if one of its operands is constant 1, etc.
bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
::tensorflow::Status RemoveTrivialBinaryOperator::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto binary_it = model->operators.begin() + op_index;
auto* binary_op = binary_it->get();
if (binary_op->type != OperatorType::kAdd &&
binary_op->type != OperatorType::kMul &&
binary_op->type != OperatorType::kSub &&
binary_op->type != OperatorType::kDiv) {
return false;
return ::tensorflow::Status::OK();
}

CHECK_EQ(binary_op->inputs.size(), 2);
@ -66,12 +69,12 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
};
if (!is_input_constant[0] && !is_input_constant[1]) {
// Neither input is constant, so nothing we can resolve here.
return false;
return ::tensorflow::Status::OK();
}
if (is_input_constant[0] && is_input_constant[1]) {
// Both inputs are constants. That's a job for constants
// propagation, not for us to handle here.
return false;
return ::tensorflow::Status::OK();
}
const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
@ -84,7 +87,7 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
const auto& input_array_1 = model->GetArray(binary_op->inputs[1]);
if (!input_array_0.has_shape() || !input_array_1.has_shape()) {
// Both input shapes must be known.
return false;
return ::tensorflow::Status::OK();
}
if (input_array_0.shape().dimensions_count() ==
input_array_1.shape().dimensions_count() &&
@ -94,7 +97,7 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
"(lhs %s, rhs %s)",
LogName(*binary_op), ShapeToString(input_array_0.shape()),
ShapeToString(input_array_1.shape()));
return false;
return ::tensorflow::Status::OK();
}

// Now check if the constant operand makes this binary
@ -103,7 +106,7 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
model->GetArray(binary_op->inputs[index_of_constant_input]);
// For now, we only handle floats here.
if (constant_input_array.data_type != ArrayDataType::kFloat) {
return false;
return ::tensorflow::Status::OK();
}
const auto& constant_input_float_data =
constant_input_array.GetBuffer<ArrayDataType::kFloat>().data;
@ -121,12 +124,13 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
}

if (!is_trivial) {
return false;
return ::tensorflow::Status::OK();
}

// Now we know that this node is trivial, so we can remove it.
AddMessageF("Removing trivial %s", LogName(*binary_op));
return RemoveTrivialPassthroughOp(this, model, op_index);
*modified = RemoveTrivialPassthroughOp(this, model, op_index);
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -25,16 +25,20 @@ limitations under the License.

namespace toco {

bool RemoveTrivialConcatenation::Run(Model* model, std::size_t op_index) {
::tensorflow::Status RemoveTrivialConcatenation::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto concat_it = model->operators.begin() + op_index;
auto* concat_op = concat_it->get();
if (concat_op->type != OperatorType::kConcatenation) {
return false;
return ::tensorflow::Status::OK();
}
if (concat_op->inputs.size() != 1) {
return false;
return ::tensorflow::Status::OK();
}
return RemoveTrivialPassthroughOp(this, model, op_index);
*modified = RemoveTrivialPassthroughOp(this, model, op_index);
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -25,7 +25,10 @@ limitations under the License.

namespace toco {

bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) {
::tensorflow::Status RemoveTrivialConcatenationInput::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
// TensorFlow allows Concatenation nodes to have 0-D inputs,
// and they are then treated as empty i.e. omitted from concatenation,
// in violation of the notion that 0-D is equivalent to 1x1x1x1.
@ -36,7 +39,7 @@ bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) {
const auto concat_it = model->operators.begin() + op_index;
auto* concat_op = concat_it->get();
if (concat_op->type != OperatorType::kConcatenation) {
return false;
return ::tensorflow::Status::OK();
}
std::vector<string> trivial_inputs;
std::vector<string> nontrivial_inputs;
@ -52,7 +55,7 @@ bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) {
}

if (trivial_inputs.empty()) {
return false;
return ::tensorflow::Status::OK();
}

// Drop trivial inputs.
@ -63,7 +66,8 @@ bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) {
}
}
concat_op->inputs = nontrivial_inputs;
return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -64,23 +64,27 @@ bool IsFakeQuantTrivial(GraphTransformation* transformation, const Model& model,
} // namespace

// Removes FakeQuant ops that are trivial (have no effect, are redundant, etc).
bool RemoveTrivialFakeQuant::Run(Model* model, std::size_t op_index) {
::tensorflow::Status RemoveTrivialFakeQuant::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto op_it = model->operators.begin() + op_index;
auto* op = op_it->get();
if (op->type != OperatorType::kFakeQuant) {
return false;
return ::tensorflow::Status::OK();
}
auto* fakequant_op = static_cast<FakeQuantOperator*>(op);

if (!IsFakeQuantTrivial(this, *model, *fakequant_op)) {
AddMessageF("%s is not trivial", LogName(*fakequant_op));
return false;
return ::tensorflow::Status::OK();
}

AddMessageF("Removing trivial %s", LogName(*fakequant_op));

CHECK_EQ(fakequant_op->inputs.size(), 1);
return RemoveTrivialPassthroughOp(this, model, op_index);
*modified = RemoveTrivialPassthroughOp(this, model, op_index);
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -94,12 +94,13 @@ bool IsTrivialFusedActivationFunc(
// Attempts to remove both fused and unfused activation functions if the
// quantization params indicate that the representable values fall inside the
// activation range.
bool RemoveTrivialQuantizedActivationFunc::Run(Model* model,
std::size_t op_index) {
::tensorflow::Status RemoveTrivialQuantizedActivationFunc::Run(
Model* model, std::size_t op_index, bool* modified) {
*modified = false;
const auto it = model->operators.begin() + op_index;
auto* op = it->get();
if (op->inputs.empty()) {
return false;
return ::tensorflow::Status::OK();
}

if (IsTrivialUnfusedActivationFunc(this, *model, op->type, op->inputs[0])) {
@ -107,7 +108,8 @@ bool RemoveTrivialQuantizedActivationFunc::Run(Model* model,
"Removing trivial unfused activation function %s because the input "
"minmax imply at least as tight a clamp anyway.",
LogName(*op));
return RemoveTrivialPassthroughOp(this, model, op_index);
*modified = RemoveTrivialPassthroughOp(this, model, op_index);
return ::tensorflow::Status::OK();
}
if (IsTrivialFusedActivationFunc(this, *model, op->fused_activation_function,
op->outputs[0])) {
@ -117,9 +119,10 @@ bool RemoveTrivialQuantizedActivationFunc::Run(Model* model,
"because the output quantization parameters imply at least as tight "
"a clamp anyway.",
LogName(*op));
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
return false;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -69,22 +69,26 @@ bool IsTrivialMinMax(GraphTransformation* transformation, const Model& model,

// Attempts to remove min/max functions if the quantization params indicate that
// the representable values fall inside the clip range.
bool RemoveTrivialQuantizedMinMax::Run(Model* model, std::size_t op_index) {
::tensorflow::Status RemoveTrivialQuantizedMinMax::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto it = model->operators.begin() + op_index;
auto* op = it->get();
if ((op->type != OperatorType::kMinimum &&
op->type != OperatorType::kMaximum) ||
op->inputs.size() != 2) {
return false;
return ::tensorflow::Status::OK();
}
if (IsTrivialMinMax(this, *model, op->type, op->inputs[0], op->inputs[1])) {
AddMessageF(
"Removing trivial min/max %s because the quantization parameters imply "
"at least as tight a clamp anyway.",
LogName(*op));
return RemoveTrivialPassthroughOp(this, model, op_index);
*modified = RemoveTrivialPassthroughOp(this, model, op_index);
return ::tensorflow::Status::OK();
}
return false;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -81,22 +81,26 @@ bool IsReshapeTrivial(const Model& model, const Operator& op,

} // namespace

bool RemoveTrivialReshape::Run(Model* model, std::size_t op_index) {
::tensorflow::Status RemoveTrivialReshape::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto reshape_it = model->operators.begin() + op_index;
auto* reshape_op = reshape_it->get();
if (reshape_op->type != OperatorType::kReshape) {
return false;
return ::tensorflow::Status::OK();
}

if (!IsReshapeTrivial(*model, *reshape_op, this)) {
AddMessageF("%s is not trivial", LogName(*reshape_op));
return false;
return ::tensorflow::Status::OK();
}

AddMessageF("Removing trivial %s", LogName(*reshape_op));

CHECK_EQ(reshape_op->inputs.size(), 2);
return RemoveTrivialPassthroughOp(this, model, op_index);
*modified = RemoveTrivialPassthroughOp(this, model, op_index);
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -49,21 +49,24 @@ bool IsSliceTrivial(const Model& model, const Operator& op,

} // namespace

bool RemoveTrivialSlice::Run(Model* model, std::size_t op_index) {
::tensorflow::Status RemoveTrivialSlice::Run(Model* model, std::size_t op_index,
bool* modified) {
*modified = false;
const auto reshape_it = model->operators.begin() + op_index;
auto* slice_op = reshape_it->get();
if (slice_op->type != OperatorType::kSlice) {
return false;
return ::tensorflow::Status::OK();
}

if (!IsSliceTrivial(*model, *slice_op, this)) {
return false;
return ::tensorflow::Status::OK();
}

AddMessageF("Removing trivial %s", LogName(*slice_op));

CHECK_EQ(slice_op->inputs.size(), 3);
return RemoveTrivialPassthroughOp(this, model, op_index);
*modified = RemoveTrivialPassthroughOp(this, model, op_index);
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -25,7 +25,9 @@ limitations under the License.

namespace toco {

bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
::tensorflow::Status RemoveUnusedOp::Run(Model* model, std::size_t op_index,
bool* modified) {
*modified = false;
const auto it = model->operators.begin() + op_index;
const auto* op = it->get();

@ -58,7 +60,7 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
}
for (const string& output_array : model->flags.output_arrays()) {
if (output == output_array) {
return false;
return ::tensorflow::Status::OK();
}
}
for (const auto& rnn_state : model->flags.rnn_states()) {
@ -67,19 +69,19 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
if (!IsDiscardableArray(*model, rnn_state.back_edge_source_array()) ||
!IsDiscardableArray(*model, rnn_state.state_array()) ||
CountOpsWithInput(*model, rnn_state.state_array())) {
return false;
return ::tensorflow::Status::OK();
}
}
}
if (CountOpsWithInput(*model, output)) {
return false;
return ::tensorflow::Status::OK();
}
}

if (op->unresolved_outputs) {
AddMessageF("Not discarding %s because it has unresolved outputs.",
LogName(*op));
return false;
return ::tensorflow::Status::OK();
}

AddMessageF("Discarding %s because none of its outputs is used.",
@ -105,7 +107,8 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
}
}
model->operators.erase(it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -63,29 +63,32 @@ bool IsMoveOperator(OperatorType optype) {

// Swap elementwise operators such that all value operators occur before all
// element move operators, e.g. negation then transpose.
bool ReorderElementwiseUnary::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ReorderElementwiseUnary::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto element_op_it = model->operators.begin() + op_index;
std::unique_ptr<Operator>& element_op = *element_op_it;
if (!IsElementwiseOperator(element_op->type)) {
return false;
return ::tensorflow::Status::OK();
}

const string intermediate_name = element_op->inputs[0];
auto it = FindOpWithOutput(*model, intermediate_name);
if (it == model->operators.end()) {
AddMessageF("No preceding operator");
return false;
return ::tensorflow::Status::OK();
}

std::unique_ptr<Operator>& move_op = *it;
if (!IsMoveOperator(move_op->type)) {
AddMessageF("Preceding operator is not a move operator");
return false;
return ::tensorflow::Status::OK();
}

if (CountOpsWithInput(*model, intermediate_name) != 1) {
AddMessageF("Input %s used elsewhere", intermediate_name);
return false;
return ::tensorflow::Status::OK();
}

// Check that the intermediate is discardable.
@ -94,7 +97,7 @@ bool ReorderElementwiseUnary::Run(Model* model, std::size_t op_index) {
"Cannot swap elementwise as it would invalidate %s which is "
"an output array.",
intermediate_name);
return false;
return ::tensorflow::Status::OK();
}

// op->inputs may change so we need to keep a value by copy.
@ -147,7 +150,8 @@ bool ReorderElementwiseUnary::Run(Model* model, std::size_t op_index) {
// Swap the order of the operators.
element_op.swap(move_op);

return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -101,37 +101,40 @@ std::vector<int> ComputeNewPerm(std::vector<int> input_dims,

// Swaps reshape-transpose to transpose-reshape whenever possible. This is
// possible when the reshape does not affect memory ordering.
bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ReorderReshapeTranspose::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto transpose_it = model->operators.begin() + op_index;

TransposeOperator* transpose_op = ConvertOperator<TransposeOperator*>(
transpose_it->get(), OperatorType::kTranspose);

if (transpose_op == nullptr) {
return false;
return ::tensorflow::Status::OK();
}

if (!OperatorReady(*model, transpose_op) || transpose_op->perm.empty()) {
// Wait for values to propagate.
return false;
return ::tensorflow::Status::OK();
}

// Find the operator that produces the transpose op.
auto reshape_it = FindOpWithOutput(*model, transpose_op->inputs[0]);
if (reshape_it == model->operators.end()) {
return false;
return ::tensorflow::Status::OK();
}

TensorFlowReshapeOperator* reshape_op =
ConvertOperator<TensorFlowReshapeOperator*>(reshape_it->get(),
OperatorType::kReshape);
if (reshape_op == nullptr) {
return false;
return ::tensorflow::Status::OK();
}

// Ignore if the reshape is uninitialized.
if (!OperatorReady(*model, reshape_op) || reshape_op->shape.empty()) {
return false;
return ::tensorflow::Status::OK();
}

// Need to copy to keep static if permutated.
@ -142,7 +145,7 @@ bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) {
// Intermediate should not be consumed by any other operators.
if (CountOpsWithInput(*model, intermediate_name) != 1) {
AddMessageF("Input %s used elsewhere", intermediate_name);
return false;
return ::tensorflow::Status::OK();
}

// Check that the intermediate is not an output array.
@ -151,7 +154,7 @@ bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) {
"Cannot reorder reshape-transpose as it would invalidate %s which is "
"an output array.",
intermediate_name);
return false;
return ::tensorflow::Status::OK();
}

// Get the arrays.
@ -173,7 +176,7 @@ bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) {
// dimensions then it can be moved between the transpose.
if (!ReshapeIsEquivalentToTranspose(*model, reshape_op,
true /*allow_extra_unary_dims*/)) {
return false;
return ::tensorflow::Status::OK();
}

if (!IsDiscardableArray(*model, output_name)) {
@ -242,7 +245,8 @@ bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) {
// Swap the order of the operators.
transpose_it->swap(*reshape_it);

return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -25,10 +25,13 @@ limitations under the License.

namespace toco {

bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveBatchNormalization::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto bn_it = model->operators.begin() + op_index;
if (bn_it->get()->type != OperatorType::kBatchNormalization) {
return false;
return ::tensorflow::Status::OK();
}
const auto* bn_op =
static_cast<const BatchNormalizationOperator*>(bn_it->get());
@ -53,7 +56,7 @@ bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) {
// so we need to exit early if these buffers don't exist (i.e. if the params
// haven't yet been resolved as constants).
if (!mean_array.buffer || !multiplier_array.buffer || !offset_array.buffer) {
return false;
return ::tensorflow::Status::OK();
}

// Create the new Mul, Add operators
@ -142,7 +145,8 @@ bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) {
DCHECK_EQ(bn_it->get(), bn_op);
model->operators.erase(bn_it);

return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -24,31 +24,35 @@ limitations under the License.

namespace toco {

bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveBatchToSpaceNDAttributes::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto op_it = model->operators.begin() + op_index;
if (op_it->get()->type != OperatorType::kBatchToSpaceND) return false;
if (op_it->get()->type != OperatorType::kBatchToSpaceND)
return ::tensorflow::Status::OK();

auto* op = static_cast<BatchToSpaceNDOperator*>(op_it->get());

// The attributes are resolved only when the 3 attributes (block_shape,
// before_crops, after_crops) are all constant.
if (!op->block_shape.empty()) {
return false;
return ::tensorflow::Status::OK();
}

CHECK_EQ(op->inputs.size(), 3);
if (!IsConstantParameterArray(*model, op->inputs[1]) ||
!IsConstantParameterArray(*model, op->inputs[2]))
return false;
return ::tensorflow::Status::OK();

// Handle crops
const auto& crops_array = model->GetArray(op->inputs[2]);
if (!crops_array.has_shape()) return false;
if (!crops_array.has_shape()) return ::tensorflow::Status::OK();
const std::vector<int>& crops_dims = crops_array.shape().dims();
if (crops_dims.size() != 2) {
// Code only handles crops of 2 dimensions. Perhaps another transformation
// will delete this op.
return false;
return ::tensorflow::Status::OK();
}
const std::vector<int>& crops_buffer =
crops_array.GetBuffer<ArrayDataType::kInt32>().data;
@ -59,7 +63,7 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) {

// Handle block_shape
const auto& block_shape_array = model->GetArray(op->inputs[1]);
if (!block_shape_array.has_shape()) return false;
if (!block_shape_array.has_shape()) return ::tensorflow::Status::OK();
const std::vector<int>& block_shape_dims = block_shape_array.shape().dims();
CHECK_EQ(block_shape_dims.size(), 1);
const std::vector<int>& block_shape_buffer =
@ -68,7 +72,8 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) {
op->block_shape.push_back(block_shape_buffer[i]);
}

return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco


@ -188,7 +188,10 @@ void EvaluateBinaryOperatorOnConstantInputs(Model* model,
}
} // namespace

bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantBinaryOperator::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto binary_it = model->operators.begin() + op_index;
const auto* binary_op = binary_it->get();
// Test for binary ops of types that we know how to resolve
@ -204,7 +207,7 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
binary_op->type != OperatorType::kLessEqual &&
binary_op->type != OperatorType::kGreater &&
binary_op->type != OperatorType::kGreaterEqual) {
return false;
return ::tensorflow::Status::OK();
}
CHECK_EQ(binary_op->inputs.size(), 2);

@ -212,13 +215,13 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
const auto& input1_array = model->GetArray(binary_op->inputs[1]);
// Check if both inputs are constant parameters.
if (!input0_array.buffer || !input1_array.buffer) {
return false;
return ::tensorflow::Status::OK();
}

auto& output_array = model->GetArray(binary_op->outputs[0]);
// Yield until the output array dims have been resolved.
if (!output_array.has_shape()) {
return false;
return ::tensorflow::Status::OK();
}

// At the moment we don't want to care about fused activation functions.
@ -229,7 +232,7 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
AddMessageF(
"Not resolving constant %s because it has a fused activation function",
LogName(*binary_op));
return false;
return ::tensorflow::Status::OK();
}

// Check that input data types agree.
@ -253,7 +256,8 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
AddMessageF("Resolved constant %s to the equivalent constant array",
LogName(*binary_op));
model->operators.erase(binary_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -135,11 +135,14 @@ void SetMinMaxForConcatenedArray(GraphTransformation* transformation,
} // namespace

// Resolves the concatenation operator if all its inputs are constant arrays.
bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantConcatenation::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto concat_it = model->operators.begin() + op_index;
const auto* concat_base_op = concat_it->get();
if (concat_base_op->type != OperatorType::kConcatenation) {
return false;
return ::tensorflow::Status::OK();
}
const auto* concat_op =
static_cast<const ConcatenationOperator*>(concat_base_op);
@ -149,11 +152,15 @@ bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) {
// We also make sure the shapes of the input arrays are known and they are
// all discardable.
const Operator* input_op = GetOpWithOutput(*model, input_name);
if (input_op) return false;
if (!IsConstantParameterArray(*model, input_name)) return false;
if (!model->GetArray(input_name).has_shape()) return false;
if (model->GetArray(input_name).quantization_params) return false;
if (!IsDiscardableArray(*model, input_name)) return false;
if (input_op) return ::tensorflow::Status::OK();
if (!IsConstantParameterArray(*model, input_name))
return ::tensorflow::Status::OK();
if (!model->GetArray(input_name).has_shape())
return ::tensorflow::Status::OK();
if (model->GetArray(input_name).quantization_params)
return ::tensorflow::Status::OK();
if (!IsDiscardableArray(*model, input_name))
return ::tensorflow::Status::OK();
}

const int concatenation_axis = concat_op->axis;
@ -205,7 +212,8 @@ bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) {

// Remove concatenate operator.
model->operators.erase(concat_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -59,11 +59,14 @@ void GetBoundsForQuantizedDataType(ArrayDataType quantized_data_type,
}
}

bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantFakeQuant::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto fakequant_it = model->operators.begin() + op_index;
const auto* fakequant_base_op = fakequant_it->get();
if (fakequant_base_op->type != OperatorType::kFakeQuant) {
return false;
return ::tensorflow::Status::OK();
}

const auto* fakequant_op =
@ -71,12 +74,12 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {

// Yield until the fakequant MinMax has been resolved.
if (!fakequant_op->minmax) {
return false;
return ::tensorflow::Status::OK();
}

// This transformation only applies when the input array is constant.
if (!IsConstantParameterArray(*model, fakequant_op->inputs[0])) {
return false;
return ::tensorflow::Status::OK();
}

const auto& input_array = model->GetArray(fakequant_op->inputs[0]);
@ -87,7 +90,7 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
if (!InferQuantizedDataTypeFromFakeQuant(*fakequant_op,
&quantized_data_type)) {
AddMessageF("Unsupported FakeQuant num_bits=%d", fakequant_op->num_bits);
return false;
return ::tensorflow::Status::OK();
}

AddMessageF("Resolving constant %s", LogName(*fakequant_op));
@ -136,7 +139,8 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
}
model->operators.erase(fakequant_it);

return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -41,11 +41,14 @@ bool ComputeFillArray(Model* model, FillOperator* op) {
return true;
}

bool ResolveConstantFill::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantFill::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto fill_it = model->operators.begin() + op_index;
auto* base_op = fill_it->get();
if (base_op->type != OperatorType::kFill) {
return false;
return ::tensorflow::Status::OK();
}
auto* op = static_cast<FillOperator*>(base_op);

@ -55,44 +58,44 @@ bool ResolveConstantFill::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
return false;
return ::tensorflow::Status::OK();
}

if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes
return false;
return ::tensorflow::Status::OK();
}

const auto& val_array = model->GetArray(op->inputs[1]);
if (!val_array.has_shape()) {
// Yield until the value shape has been resolved.
return false;
return ::tensorflow::Status::OK();
}
if (!IsConstantParameterArray(*model, op->inputs[1])) {
// Yield until the value is constant.
return false;
return ::tensorflow::Status::OK();
}
CHECK_EQ(RequiredBufferSizeForShape(val_array.shape()), 1);

switch (output_array.data_type) {
case ArrayDataType::kFloat:
if (!ComputeFillArray<ArrayDataType::kFloat>(model, op)) {
return false;
return ::tensorflow::Status::OK();
}
break;
case ArrayDataType::kUint8:
if (!ComputeFillArray<ArrayDataType::kUint8>(model, op)) {
return false;
return ::tensorflow::Status::OK();
}
break;
case ArrayDataType::kInt32:
if (!ComputeFillArray<ArrayDataType::kInt32>(model, op)) {
return false;
return ::tensorflow::Status::OK();
}
break;
case ArrayDataType::kInt64:
if (!ComputeFillArray<ArrayDataType::kInt64>(model, op)) {
return false;
return ::tensorflow::Status::OK();
}
break;
default:
@ -114,7 +117,8 @@ bool ResolveConstantFill::Run(Model* model, std::size_t op_index) {
// Erase the operator
model->operators.erase(fill_it);

return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -61,11 +61,14 @@ inline void Gather(const Array& input_array, int input_rank,
// Resolves a constant Gather operation.
// This simply performs the gather and produces the output array with the
// appropriate values.
bool ResolveConstantGather::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantGather::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kGather) {
return false;
return ::tensorflow::Status::OK();
}
const auto* op = static_cast<const GatherOperator*>(base_op);

@ -74,28 +77,28 @@ bool ResolveConstantGather::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes.
return false;
return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes.
return false;
return ::tensorflow::Status::OK();
}

if (!op->axis) {
// Yield until axis has been set by ResolveGatherAttributes.
return false;
return ::tensorflow::Status::OK();
}
if (op->axis.value() != 0) {
// Only handling axis=0 for now.
AddMessageF("%s has axis %d; only axis=0 is supported", LogName(*op),
op->axis.value());
return false;
return ::tensorflow::Status::OK();
}

// We require constant inputs.
if (!IsConstantParameterArray(*model, op->inputs[0]) ||
!IsConstantParameterArray(*model, op->inputs[1])) {
return false;
return ::tensorflow::Status::OK();
}
const Array& input_array = model->GetArray(op->inputs[0]);
const Array& coords_array = model->GetArray(op->inputs[1]);
@ -142,7 +145,8 @@ bool ResolveConstantGather::Run(Model* model, std::size_t op_index) {

// Erase the operator.
model->operators.erase(it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -49,11 +49,14 @@ void Pack(Model* model, PackOperator const& op) {

} // namespace

bool ResolveConstantPack::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantPack::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kPack) {
return false;
return ::tensorflow::Status::OK();
}
const auto* op = static_cast<const PackOperator*>(base_op);

@ -62,18 +65,18 @@ bool ResolveConstantPack::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
return false;
return ::tensorflow::Status::OK();
}

if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes
return false;
return ::tensorflow::Status::OK();
}

for (const auto& input : op->inputs) {
if (!IsConstantParameterArray(*model, input)) {
// Yield if any input is mutable
return false;
return ::tensorflow::Status::OK();
}
}

@ -111,7 +114,8 @@ bool ResolveConstantPack::Run(Model* model, std::size_t op_index) {

// Erase the operator
model->operators.erase(it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -59,11 +59,14 @@ bool ComputeRandomUniformArray(Model* model, RandomUniformOperator* op) {
return true;
}

bool ResolveConstantRandomUniform::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantRandomUniform::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto it = model->operators.begin() + op_index;
auto* base_op = it->get();
if (base_op->type != OperatorType::kRandomUniform) {
return false;
return ::tensorflow::Status::OK();
}
auto* op = static_cast<RandomUniformOperator*>(base_op);

@ -73,12 +76,12 @@ bool ResolveConstantRandomUniform::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
return false;
return ::tensorflow::Status::OK();
}

if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes
return false;
return ::tensorflow::Status::OK();
}

if ((op->seed == 0) && (op->seed2 == 0)) {
@ -86,13 +89,13 @@ bool ResolveConstantRandomUniform::Run(Model* model, std::size_t op_index) {
<< "\" is truly random (using /dev/random system entropy). "
"Therefore, cannot resolve as constant. Set \"seed\" or "
"\"seed2\" attr non-zero to fix this";
return false;
return ::tensorflow::Status::OK();
}

switch (output_array.data_type) {
case ArrayDataType::kFloat:
if (!ComputeRandomUniformArray<ArrayDataType::kFloat>(model, op)) {
return false;
return ::tensorflow::Status::OK();
}
break;
// For future support of double or half.
@ -110,7 +113,8 @@ bool ResolveConstantRandomUniform::Run(Model* model, std::size_t op_index) {
// Erase the operator
model->operators.erase(it);

return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -19,11 +19,14 @@ limitations under the License.

namespace toco {

bool ResolveConstantRange::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantRange::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto it = model->operators.begin() + op_index;
auto* base_op = it->get();
if (base_op->type != OperatorType::kRange) {
return false;
return ::tensorflow::Status::OK();
}
auto* op = static_cast<RangeOperator*>(base_op);

@ -31,23 +34,23 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) {
const auto& start_array = model->GetArray(op->inputs[0]);
if (!start_array.has_shape()) {
// Yield until all input dims have been resolved.
return false;
return ::tensorflow::Status::OK();
}
const auto& limit_array = model->GetArray(op->inputs[1]);
if (!limit_array.has_shape()) {
// Yield until all input dims have been resolved.
return false;
return ::tensorflow::Status::OK();
}
const auto& delta_array = model->GetArray(op->inputs[2]);
if (!delta_array.has_shape()) {
// Yield until all input dims have been resolved.
return false;
return ::tensorflow::Status::OK();
}

for (const auto& input : op->inputs) {
if (!IsConstantParameterArray(*model, input)) {
// yield if any input is mutable
return false;
return ::tensorflow::Status::OK();
}
}

@ -55,7 +58,7 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
return false;
return ::tensorflow::Status::OK();
}

CHECK_EQ(RequiredBufferSizeForShape(start_array.shape()), 1)
@ -101,7 +104,8 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) {
// Delete the operator
model->operators.erase(it);

return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -22,11 +22,14 @@ limitations under the License.
namespace toco {

// Resolves a constant reshape operation by copying the buffer.
bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantReshape::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kReshape) {
return false;
return ::tensorflow::Status::OK();
}
const auto* op = static_cast<const TensorFlowReshapeOperator*>(base_op);

@ -36,17 +39,17 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {
// We require constant inputs.
if (!IsConstantParameterArray(*model, op->inputs[0]) ||
!IsConstantParameterArray(*model, op->inputs[1])) {
return false;
return ::tensorflow::Status::OK();
}

auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes.
return false;
return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes.
return false;
return ::tensorflow::Status::OK();
}

const Array& input_array = model->GetArray(op->inputs[0]);
@ -54,7 +57,7 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {
AddMessageF("Constant reshape is non-trivial (%s -> %s)",
ShapeToString(input_array.shape()),
ShapeToString(output_array.shape()));
return false;
return ::tensorflow::Status::OK();
}

CHECK(!output_array.buffer);
@ -95,7 +98,7 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {
default:
LOG(FATAL) << "Unsupported data type: "
<< ArrayDataTypeName(input_array.data_type);
return false;
return ::tensorflow::Status::OK();
}

AddMessageF("Resolving constant reshape of %s", LogName(*op));
@ -112,7 +115,8 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {

// Erase the operator.
model->operators.erase(it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -27,11 +27,14 @@ namespace toco {
// This implementation is looking strictly for all-or-nothing on the select
// condition. It's possible to enhance this by looking per-element and possibly
// producing a Mul op.
bool ResolveConstantSelect::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantSelect::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kSelect) {
return false;
return ::tensorflow::Status::OK();
}
const auto* op = static_cast<const SelectOperator*>(base_op);

@ -40,23 +43,23 @@ bool ResolveConstantSelect::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes.
return false;
return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes.
return false;
return ::tensorflow::Status::OK();
}

// We require the cond input to be constant.
if (!IsConstantParameterArray(*model, op->inputs[0])) {
return false;
return ::tensorflow::Status::OK();
}
const Array& cond_array = model->GetArray(op->inputs[0]);
CHECK(cond_array.data_type == ArrayDataType::kBool)
<< "Only bool conditions are supported";
const auto& cond_data = cond_array.GetBuffer<ArrayDataType::kBool>().data;
if (cond_data.empty()) {
return false;
return ::tensorflow::Status::OK();
}

// Check if the condition is the same for all elements.
@ -67,12 +70,14 @@ bool ResolveConstantSelect::Run(Model* model, std::size_t op_index) {
"Cannot resolve %s as constant; cond_array has differing "
"per-element values",
LogName(*op));
return false;
return ::tensorflow::Status::OK();
}
}

// Pass-through the selected input.
return RemoveTrivialPassthroughOp(this, model, op_index, cond_value ? 1 : 2);
*modified =
RemoveTrivialPassthroughOp(this, model, op_index, cond_value ? 1 : 2);
return ::tensorflow::Status::OK();
}

} // namespace toco
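
ResolveConstantSelect above shows the one wrinkle in the otherwise mechanical rewrite: when the old code returned a bool helper's result directly, as with RemoveTrivialPassthroughOp, that bool is a did-modify flag rather than an error, so it is routed into *modified while the Status channel stays OK. Reduced to its essentials, with SomeBoolHelper standing in for any such helper:

#include "tensorflow/core/lib/core/status.h"

bool SomeBoolHelper();  // Stand-in: returns true iff it changed the graph.

::tensorflow::Status RunTail(bool* modified) {
  // Before the migration this would have been: return SomeBoolHelper();
  // The helper's bool feeds the out-parameter; the transformation itself
  // still succeeded, so the returned status is OK.
  *modified = SomeBoolHelper();
  return ::tensorflow::Status::OK();
}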

@ -19,29 +19,32 @@ limitations under the License.

namespace toco {

bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantShapeOrRank::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto it = model->operators.begin() + op_index;
const auto* op = it->get();
if (!(op->type == OperatorType::kShape || op->type == OperatorType::kRank)) {
return false;
return ::tensorflow::Status::OK();
}

CHECK_EQ(op->outputs.size(), 1);
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been resolved
return false;
return ::tensorflow::Status::OK();
}

const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until the input array's shape has been resolved.
return false;
return ::tensorflow::Status::OK();
}

if (!output_array.has_shape()) {
// Yield until the output shape has been resolved.
return false;
return ::tensorflow::Status::OK();
}

// Compute the output
@ -65,7 +68,8 @@ bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) {
}

model->operators.erase(it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -86,11 +86,14 @@ bool Slice(SliceOperator const& op, Array const& input_array,

} // namespace

bool ResolveConstantSlice::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantSlice::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kSlice) {
return false;
return ::tensorflow::Status::OK();
}

const SliceOperator* op = static_cast<const SliceOperator*>(base_op);
@ -99,49 +102,49 @@ bool ResolveConstantSlice::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes.
return false;
return ::tensorflow::Status::OK();
}

if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes.
return false;
return ::tensorflow::Status::OK();
}

if (op->begin.empty() || op->size.empty()) {
// Attributes have not resolved yet.
return false;
return ::tensorflow::Status::OK();
}

const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until the value shape has been resolved.
return false;
return ::tensorflow::Status::OK();
}
if (!IsConstantParameterArray(*model, op->inputs[0])) {
// Yield until the value is constant.
return false;
return ::tensorflow::Status::OK();
}

CHECK(!output_array.buffer);
switch (output_array.data_type) {
case ArrayDataType::kFloat:
if (!Slice<ArrayDataType::kFloat>(*op, input_array, &output_array)) {
return false;
return ::tensorflow::Status::OK();
}
break;
case ArrayDataType::kUint8:
if (!Slice<ArrayDataType::kUint8>(*op, input_array, &output_array)) {
return false;
return ::tensorflow::Status::OK();
}
break;
case ArrayDataType::kInt32:
if (!Slice<ArrayDataType::kInt32>(*op, input_array, &output_array)) {
return false;
return ::tensorflow::Status::OK();
}
break;
case ArrayDataType::kInt64:
if (!Slice<ArrayDataType::kInt64>(*op, input_array, &output_array)) {
return false;
return ::tensorflow::Status::OK();
}
break;
default:
@ -159,7 +162,8 @@ bool ResolveConstantSlice::Run(Model* model, std::size_t op_index) {
// Erase the operator
model->operators.erase(it);

return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -103,11 +103,14 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array,

} // anonymous namespace

bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantStridedSlice::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kStridedSlice) {
return false;
return ::tensorflow::Status::OK();
}

const StridedSliceOperator* op =
@ -117,28 +120,28 @@ bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
return false;
return ::tensorflow::Status::OK();
}

if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes
return false;
return ::tensorflow::Status::OK();
}

if (op->start_indices.empty() || op->stop_indices.empty() ||
op->strides.empty()) {
// Attributes have not resolved yet.
return false;
return ::tensorflow::Status::OK();
}

const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until the value shape has been resolved.
return false;
return ::tensorflow::Status::OK();
}
if (!IsConstantParameterArray(*model, op->inputs[0])) {
// Yield until the value is constant.
return false;
return ::tensorflow::Status::OK();
}

CHECK(!output_array.buffer);
@ -164,7 +167,8 @@ bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) {

DeleteOpAndArraysIfUnused(model, it->get());

return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -97,11 +97,14 @@ inline void Tile(const Array& input_array, const Array& multiples_array,
} // namespace

// Resolves a constant Tile operation.
bool ResolveConstantTile::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantTile::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kTile) {
return false;
return ::tensorflow::Status::OK();
}
const auto* op = static_cast<const TensorFlowTileOperator*>(base_op);

@ -110,17 +113,17 @@ bool ResolveConstantTile::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes.
return false;
return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes.
return false;
return ::tensorflow::Status::OK();
}

// We require constant inputs.
if (!IsConstantParameterArray(*model, op->inputs[0]) ||
!IsConstantParameterArray(*model, op->inputs[1])) {
return false;
return ::tensorflow::Status::OK();
}
const Array& input_array = model->GetArray(op->inputs[0]);
const Array& multiples_array = model->GetArray(op->inputs[1]);
@ -159,7 +162,8 @@ bool ResolveConstantTile::Run(Model* model, std::size_t op_index) {

// Erase the operator.
model->operators.erase(it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -101,11 +101,14 @@ void Transpose(Model* model, const Array& input_array,

} // namespace

bool ResolveConstantTranspose::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantTranspose::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kTranspose) {
return false;
return ::tensorflow::Status::OK();
}
const auto* op = static_cast<const TransposeOperator*>(base_op);

@ -114,17 +117,17 @@ bool ResolveConstantTranspose::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes.
return false;
return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes.
return false;
return ::tensorflow::Status::OK();
}

// We require constant inputs.
if (!IsConstantParameterArray(*model, op->inputs[0]) ||
!IsConstantParameterArray(*model, op->inputs[1])) {
return false;
return ::tensorflow::Status::OK();
}
const Array& input_array = model->GetArray(op->inputs[0]);

@ -132,7 +135,7 @@ bool ResolveConstantTranspose::Run(Model* model, std::size_t op_index) {

if (op->perm.empty()) {
// Yield until perm has been populated by ResolveTransposeAttributes.
return false;
return ::tensorflow::Status::OK();
}

// We currently only support 1-4 dimensions.
@ -174,7 +177,8 @@ bool ResolveConstantTranspose::Run(Model* model, std::size_t op_index) {

// Erase the operator.
model->operators.erase(it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -112,7 +112,10 @@ bool CopyMinMaxFromFirstInput(const Operator& op, Model* model) {
return true;
}

bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantUnaryOperator::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto unary_it = model->operators.begin() + op_index;
const auto* unary_op = unary_it->get();
// Test for unary ops of types that we know how to resolve.
@ -133,28 +136,28 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
case OperatorType::kRelu:
break;
default:
return false;
return ::tensorflow::Status::OK();
}

// Check if the input is a constant parameter.
if (!IsConstantParameterArray(*model, unary_op->inputs[0])) {
return false;
return ::tensorflow::Status::OK();
}

// if the unary op involves a tensor required by a rnn state, ignore it
for (const auto& rnn_state : model->flags.rnn_states()) {
if (unary_op->inputs[0] == rnn_state.back_edge_source_array()) {
return false;
return ::tensorflow::Status::OK();
}
if (unary_op->inputs[0] == rnn_state.state_array()) {
return false;
return ::tensorflow::Status::OK();
}
}

auto& output_array = model->GetArray(unary_op->outputs[0]);
if (!output_array.has_shape()) {
// Yield until the output array dims have been resolved.
return false;
return ::tensorflow::Status::OK();
}

// At the moment we don't want to care about fused activation functions.
@ -166,7 +169,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
"Not resolving constant %s "
" because it has a fused activation function",
LogName(*unary_op));
return false;
return ::tensorflow::Status::OK();
}

// The min-max is only copied for ops that copy data without arithmetic.
@ -187,7 +190,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
"Not resolving constant %s because we currently only support casting "
"to float",
LogName(*unary_op));
return false;
return ::tensorflow::Status::OK();
}
if (cast_op->src_data_type != input_array.buffer->type) {
AddMessageF(
@ -197,7 +200,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
}
} else {
if (input_array.buffer->type != ArrayDataType::kFloat) {
return false;
return ::tensorflow::Status::OK();
}
input_float_data = &(input_array.GetBuffer<ArrayDataType::kFloat>().data);
}
@ -239,7 +242,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
CHECK_EQ(unary_op->inputs.size(), 2) << "Sum needs 2 inputs";
if (!IsConstantParameterArray(*model, unary_op->inputs[1])) {
AddMessageF("Axis input is non-constant");
return false;
return ::tensorflow::Status::OK();
}
auto& axis_array = model->GetArray(unary_op->inputs[1]);
CHECK(axis_array.data_type == ArrayDataType::kInt32);
@ -336,7 +339,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
default:
LOG(FATAL) << "Unsupported activation function "
<< LogName(*unary_op);
return false;
return ::tensorflow::Status::OK();
}
output_float_data[i] = new_value;
}
@ -351,7 +354,8 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
AddMessageF("Resolved constant %s to the equivalent constant array",
LogName(*unary_op));
model->operators.erase(unary_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco
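
One detail worth noticing in the unary resolver above: the default: branches keep a "return ::tensorflow::Status::OK();" after LOG(FATAL), which never returns, purely to satisfy the compiler. With a Status channel now available, such dead ends could in principle report a recoverable error instead. The commit does not make that change; a sketch of what it might look like:

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

// Hypothetical alternative to LOG(FATAL) in an unsupported default: branch.
::tensorflow::Status HandleKind(int kind) {
  switch (kind) {
    case 0:
      return ::tensorflow::Status::OK();
    default:
      // Propagate the failure to the caller instead of crashing the tool.
      return ::tensorflow::errors::Unimplemented("Unsupported kind: ", kind);
  }
}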

@ -25,17 +25,20 @@ limitations under the License.

namespace toco {

bool ResolveFakeQuantArgsFromVars::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveFakeQuantArgsFromVars::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto fakequant_it = model->operators.begin() + op_index;
auto* fakequant_base_op = fakequant_it->get();
if (fakequant_base_op->type != OperatorType::kFakeQuant) {
return false;
return ::tensorflow::Status::OK();
}
auto* fakequant_op = static_cast<FakeQuantOperator*>(fakequant_base_op);

if (fakequant_op->minmax) {
// Already resolved.
return false;
return ::tensorflow::Status::OK();
}

CHECK_EQ(fakequant_op->inputs.size(), 3);
@ -43,7 +46,7 @@ bool ResolveFakeQuantArgsFromVars::Run(Model* model, std::size_t op_index) {
// resolved to constant arrays.
for (int i = 1; i <= 2; i++) {
if (!IsConstantParameterArray(*model, fakequant_op->inputs[i])) {
return false;
return ::tensorflow::Status::OK();
}
}

@ -74,7 +77,8 @@ bool ResolveFakeQuantArgsFromVars::Run(Model* model, std::size_t op_index) {
DeleteArrayIfUsedOnce(fakequant_op->inputs[i], model);
}
fakequant_op->inputs.resize(1);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -24,20 +24,25 @@ limitations under the License.

namespace toco {

bool ResolveGatherAttributes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveGatherAttributes::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto* gather_op = model->operators[op_index].get();
if (gather_op->type != OperatorType::kGather) return false;
if (gather_op->type != OperatorType::kGather)
return ::tensorflow::Status::OK();
auto* op = static_cast<GatherOperator*>(gather_op);

if (op->axis) {
// Attributes already resolved
return false;
return ::tensorflow::Status::OK();
}
if (op->inputs.size() != 3) return false;
if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
if (op->inputs.size() != 3) return ::tensorflow::Status::OK();
if (!IsConstantParameterArray(*model, op->inputs[2]))
return ::tensorflow::Status::OK();

const auto& indices_array = model->GetArray(op->inputs[2]);
if (!indices_array.has_shape()) return false;
if (!indices_array.has_shape()) return ::tensorflow::Status::OK();
const auto& axis_data = indices_array.GetBuffer<ArrayDataType::kInt32>().data;
CHECK_EQ(axis_data.size(), 1)
<< "Multidimensional gather not supported on " << LogName(*op);
@ -47,7 +52,8 @@ bool ResolveGatherAttributes::Run(Model* model, std::size_t op_index) {
DeleteArrayIfUsedOnce(op->inputs[2], model);
op->inputs.resize(2);

return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -51,27 +51,30 @@ void FillArrayWithZeros(Array* array) {
// Removes a multiplication by array of constant zeros by making the output
// array an array of constant zeros and removing the input arrays if they are no
// longer needed.
bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveMultiplyByZero::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto mul_it = model->operators.begin() + op_index;
auto* mul_op = mul_it->get();
if (mul_op->type != OperatorType::kMul) {
return false;
return ::tensorflow::Status::OK();
}
const auto& output_array_name = mul_op->outputs[0];
auto& output_array = model->GetArray(output_array_name);

if (!IsDiscardableArray(*model, output_array_name)) {
return false;
return ::tensorflow::Status::OK();
}

if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
return false;
return ::tensorflow::Status::OK();
}

// Yield if the output shape is not known yet.
if (!output_array.has_shape()) {
return false;
return ::tensorflow::Status::OK();
}

// This transformation only handles the case where one operand is all 0's and
@ -83,12 +86,12 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
};
if (!is_input_constant[0] && !is_input_constant[1]) {
// Neither input is constant, so nothing we can resolve here.
return false;
return ::tensorflow::Status::OK();
}
if (is_input_constant[0] && is_input_constant[1]) {
// Both inputs are constants. That's a job for constants propagation, not
// for us to handle here.
return false;
return ::tensorflow::Status::OK();
}
const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
@ -105,7 +108,7 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
constant_input_array.GetBuffer<ArrayDataType::kFloat>().data;
if (!AreAllBufferElementsZero<DataType<ArrayDataType::kFloat>>(
constant_input_data)) {
return false;
return ::tensorflow::Status::OK();
}
FillArrayWithZeros<ArrayDataType::kFloat>(&output_array);
} break;
@ -114,7 +117,7 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
constant_input_array.GetBuffer<ArrayDataType::kUint8>().data;
if (!AreAllBufferElementsZero<DataType<ArrayDataType::kUint8>>(
constant_input_data)) {
return false;
return ::tensorflow::Status::OK();
}
FillArrayWithZeros<ArrayDataType::kUint8>(&output_array);
} break;
@ -123,7 +126,7 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
constant_input_array.GetBuffer<ArrayDataType::kInt32>().data;
if (!AreAllBufferElementsZero<DataType<ArrayDataType::kInt32>>(
constant_input_data)) {
return false;
return ::tensorflow::Status::OK();
}
FillArrayWithZeros<ArrayDataType::kInt32>(&output_array);
} break;
@ -132,14 +135,14 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
constant_input_array.GetBuffer<ArrayDataType::kInt64>().data;
if (!AreAllBufferElementsZero<DataType<ArrayDataType::kInt64>>(
constant_input_data)) {
return false;
return ::tensorflow::Status::OK();
}
FillArrayWithZeros<ArrayDataType::kInt64>(&output_array);
} break;
default:
AddMessageF(
"Cannot resolve multiply by 0 because of unsupported data type\n");
return false;
return ::tensorflow::Status::OK();
}

// Erase input arrays to the multiply if no longer used
@ -149,7 +152,8 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
// Erase the multiply operator.
model->operators.erase(mul_it);

return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -24,19 +24,23 @@ limitations under the License.

namespace toco {

bool ResolvePadAttributes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolvePadAttributes::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto pad_it = model->operators.begin() + op_index;
auto* pad_op = pad_it->get();
if (pad_op->type != OperatorType::kPad) return false;
if (pad_op->type != OperatorType::kPad) return ::tensorflow::Status::OK();

auto* op = static_cast<PadOperator*>(pad_op);
if (!op->left_padding.empty()) return false;
if (!op->left_padding.empty()) return ::tensorflow::Status::OK();

CHECK_EQ(op->inputs.size(), 2);
if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
if (!IsConstantParameterArray(*model, op->inputs[1]))
return ::tensorflow::Status::OK();

const auto& array = model->GetArray(op->inputs[1]);
if (!array.has_shape()) return false;
if (!array.has_shape()) return ::tensorflow::Status::OK();

const std::vector<int>& dims = array.shape().dims();
CHECK_EQ(dims.size(), 2);
@ -50,6 +54,7 @@ bool ResolvePadAttributes::Run(Model* model, std::size_t op_index) {

// TODO(dkalenichenko): Delete the extra input?

return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

@ -24,19 +24,23 @@ limitations under the License.

namespace toco {

bool ResolvePadV2Attributes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolvePadV2Attributes::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto pad_it = model->operators.begin() + op_index;
auto* pad_op = pad_it->get();
if (pad_op->type != OperatorType::kPadV2) return false;
if (pad_op->type != OperatorType::kPadV2) return ::tensorflow::Status::OK();

auto* op = static_cast<PadV2Operator*>(pad_op);
if (!op->left_padding.empty()) return false;
if (!op->left_padding.empty()) return ::tensorflow::Status::OK();

CHECK_EQ(op->inputs.size(), 3);
if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
if (!IsConstantParameterArray(*model, op->inputs[1]))
return ::tensorflow::Status::OK();

const auto& array = model->GetArray(op->inputs[1]);
if (!array.has_shape()) return false;
if (!array.has_shape()) return ::tensorflow::Status::OK();

const std::vector<int>& dims = array.shape().dims();
CHECK_EQ(dims.size(), 2);
@ -50,6 +54,7 @@ bool ResolvePadV2Attributes::Run(Model* model, std::size_t op_index) {

// TODO(dkalenichenko): Delete the extra input?

return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

@ -39,23 +39,37 @@ bool ResolveAttributes(Model* model, T* op) {
return true;
}

bool ResolveReduceAttributes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveReduceAttributes::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
Operator* op = model->operators[op_index].get();
switch (op->type) {
case OperatorType::kMean:
return ResolveAttributes(model, static_cast<MeanOperator*>(op));
*modified = ResolveAttributes(model, static_cast<MeanOperator*>(op));
return ::tensorflow::Status::OK();
case OperatorType::kSum:
return ResolveAttributes(model, static_cast<TensorFlowSumOperator*>(op));
*modified =
ResolveAttributes(model, static_cast<TensorFlowSumOperator*>(op));
return ::tensorflow::Status::OK();
case OperatorType::kReduceProd:
return ResolveAttributes(model, static_cast<TensorFlowProdOperator*>(op));
*modified =
ResolveAttributes(model, static_cast<TensorFlowProdOperator*>(op));
return ::tensorflow::Status::OK();
case OperatorType::kReduceMin:
return ResolveAttributes(model, static_cast<TensorFlowMinOperator*>(op));
*modified =
ResolveAttributes(model, static_cast<TensorFlowMinOperator*>(op));
return ::tensorflow::Status::OK();
case OperatorType::kReduceMax:
return ResolveAttributes(model, static_cast<TensorFlowMaxOperator*>(op));
*modified =
ResolveAttributes(model, static_cast<TensorFlowMaxOperator*>(op));
return ::tensorflow::Status::OK();
case OperatorType::kAny:
return ResolveAttributes(model, static_cast<TensorFlowMaxOperator*>(op));
*modified =
ResolveAttributes(model, static_cast<TensorFlowMaxOperator*>(op));
return ::tensorflow::Status::OK();
default:
return false;
return ::tensorflow::Status::OK();
}
}

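
ResolveReduceAttributes repeats the same two-line adaptation once per reduce type, storing the bool from ResolveAttributes and then returning OK. (Note that the kAny case casts to TensorFlowMaxOperator in the old and new code alike, so the rewrite preserves that behavior unchanged.) If the duplication ever grows irksome, a small adapter could centralize the plumbing; the helper below is hypothetical and not part of this commit.

#include "tensorflow/core/lib/core/status.h"

// Hypothetical adapter from a bool "did modify" result to the new
// Status-plus-out-parameter convention.
inline ::tensorflow::Status OkWithModified(bool did_modify, bool* modified) {
  *modified = did_modify;
  return ::tensorflow::Status::OK();
}

// A case label would then collapse to, e.g.:
//   case OperatorType::kMean:
//     return OkWithModified(
//         ResolveAttributes(model, static_cast<MeanOperator*>(op)), modified);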

@ -78,11 +78,13 @@ void ReorderAxes(AxesOrder input_axes_order, AxesOrder output_axes_order,
}
}

bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveReorderAxes::Run(Model* model, std::size_t op_index,
bool* modified) {
*modified = false;
auto it = model->operators.begin() + op_index;
auto* op = it->get();
if (op->type != OperatorType::kReorderAxes) {
return false;
return ::tensorflow::Status::OK();
}
auto* reorder_op = static_cast<ReorderAxesOperator*>(op);

@ -93,11 +95,11 @@ bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
auto& input_array = model->GetArray(input_array_name);
auto& output_array = model->GetArray(output_array_name);
if (!input_array.buffer) {
return false;
return ::tensorflow::Status::OK();
}
// Yield until output dims have been resolved.
if (!output_array.has_shape()) {
return false;
return ::tensorflow::Status::OK();
}
// Reorder the input array dims and buffer data
if (input_array.buffer->type == ArrayDataType::kFloat) {
@ -120,7 +122,8 @@ bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
DeleteOpAndArraysIfUnused(model, op);
RenameArray(model, output_array_name, input_array_name);

return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -25,25 +25,29 @@ limitations under the License.

namespace toco {

bool ResolveReshapeAttributes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveReshapeAttributes::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto reshape_it = model->operators.begin() + op_index;
auto* reshape_op = reshape_it->get();
if (reshape_op->type != OperatorType::kReshape) {
return false;
return ::tensorflow::Status::OK();
}

auto* op = static_cast<TensorFlowReshapeOperator*>(reshape_op);

if (!op->shape.empty()) return false;
if (!op->shape.empty()) return ::tensorflow::Status::OK();

if (IsConstantParameterArray(*model, reshape_op->inputs[1])) {
const auto& constant_input_array = model->GetArray(reshape_op->inputs[1]);
op->shape = constant_input_array.GetBuffer<ArrayDataType::kInt32>().data;
}

if (op->shape.empty()) return false;
if (op->shape.empty()) return ::tensorflow::Status::OK();

return true;
*modified = true;
return ::tensorflow::Status::OK();
}

} // namespace toco

@ -24,29 +24,35 @@ limitations under the License.

namespace toco {

bool ResolveSliceAttributes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveSliceAttributes::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto slice_it = model->operators.begin() + op_index;
auto* slice_op = slice_it->get();
if (slice_op->type != OperatorType::kSlice) return false;
if (slice_op->type != OperatorType::kSlice) return ::tensorflow::Status::OK();

auto* op = static_cast<SliceOperator*>(slice_op);
if (!op->begin.empty()) return false;
if (!op->begin.empty()) return ::tensorflow::Status::OK();

CHECK_EQ(op->inputs.size(), 3);
if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
if (!IsConstantParameterArray(*model, op->inputs[1]))
return ::tensorflow::Status::OK();
if (!IsConstantParameterArray(*model, op->inputs[2]))
return ::tensorflow::Status::OK();

const auto& begin_array = model->GetArray(op->inputs[1]);
if (!begin_array.has_shape()) return false;
if (!begin_array.has_shape()) return ::tensorflow::Status::OK();

const auto& size_array = model->GetArray(op->inputs[2]);
if (!size_array.has_shape()) return false;
if (!size_array.has_shape()) return ::tensorflow::Status::OK();

op->begin = begin_array.GetBuffer<ArrayDataType::kInt32>().data;
op->size = size_array.GetBuffer<ArrayDataType::kInt32>().data;

// TODO(dkalenichenko): Delete the extra inputs?

return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco
@ -24,16 +24,20 @@ limitations under the License.

namespace toco {

bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveSpaceToBatchNDAttributes::Run(Model* model,
                                                          std::size_t op_index,
                                                          bool* modified) {
  *modified = false;
  const auto op_it = model->operators.begin() + op_index;
  if (op_it->get()->type != OperatorType::kSpaceToBatchND) return false;
  if (op_it->get()->type != OperatorType::kSpaceToBatchND)
    return ::tensorflow::Status::OK();

  auto* op = static_cast<SpaceToBatchNDOperator*>(op_it->get());

  // The attributes are resolved only when the 3 attributes (block_shape,
  // before_paddings, after_paddings) are all constant.
  if (!op->block_shape.empty()) {
    return false;
    return ::tensorflow::Status::OK();
  }

  const int block_shape_index = 1;
@ -42,16 +46,16 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) {
  CHECK_EQ(op->inputs.size(), 3);
  if (!IsConstantParameterArray(*model, op->inputs[block_shape_index]) ||
      !IsConstantParameterArray(*model, op->inputs[paddings_index]))
    return false;
    return ::tensorflow::Status::OK();

  // Handle paddings.
  const auto& paddings_array = model->GetArray(op->inputs[paddings_index]);
  if (!paddings_array.has_shape()) return false;
  if (!paddings_array.has_shape()) return ::tensorflow::Status::OK();
  const std::vector<int>& paddings_dims = paddings_array.shape().dims();
  if (paddings_dims.size() != 2) {
    // Code only handles padding of 2 dimensions. Perhaps another transformation
    // will delete this op.
    return false;
    return ::tensorflow::Status::OK();
  }
  const std::vector<int>& paddings_buffer =
      paddings_array.GetBuffer<ArrayDataType::kInt32>().data;
@ -63,7 +67,7 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) {
  // Handle block_shape.
  const auto& block_shape_array =
      model->GetArray(op->inputs[block_shape_index]);
  if (!block_shape_array.has_shape()) return false;
  if (!block_shape_array.has_shape()) return ::tensorflow::Status::OK();
  const std::vector<int>& block_shape_dims = block_shape_array.shape().dims();
  CHECK_EQ(block_shape_dims.size(), 1);
  const std::vector<int>& block_shape_buffer =
@ -72,7 +76,8 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) {
    op->block_shape.push_back(block_shape_buffer[i]);
  }

  return true;
  *modified = true;
  return ::tensorflow::Status::OK();
}

} // namespace toco

@ -25,10 +25,13 @@ limitations under the License.

namespace toco {

bool ResolveSqueezeAttributes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveSqueezeAttributes::Run(Model* model,
                                                   std::size_t op_index,
                                                   bool* modified) {
  *modified = false;
  auto* squeeze_op = model->operators[op_index].get();
  if (squeeze_op->type != OperatorType::kSqueeze) {
    return false;
    return ::tensorflow::Status::OK();
  }
  DCHECK_EQ(squeeze_op->inputs.size(), 1);
  DCHECK_EQ(squeeze_op->outputs.size(), 1);
@ -42,10 +45,11 @@ bool ResolveSqueezeAttributes::Run(Model* model, std::size_t op_index) {
                  "Reshape op",
                  LogName(*squeeze_op));

      return RemoveTrivialPassthroughOp(this, model, op_index);
      *modified = RemoveTrivialPassthroughOp(this, model, op_index);
      return ::tensorflow::Status::OK();
    }
  }
  return false;
  return ::tensorflow::Status::OK();
}

} // namespace toco

@ -37,40 +37,47 @@ int PadAttributeArray(Array* attribute_array, std::vector<int> pad_values,
  return mask;
}

bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveStridedSliceAttributes::Run(Model* model,
                                                        std::size_t op_index,
                                                        bool* modified) {
  *modified = false;
  const auto slice_it = model->operators.begin() + op_index;
  auto* slice_op = slice_it->get();
  if (slice_op->type != OperatorType::kStridedSlice) return false;
  if (slice_op->type != OperatorType::kStridedSlice)
    return ::tensorflow::Status::OK();

  auto* op = static_cast<StridedSliceOperator*>(slice_op);
  if (!op->start_indices.empty()) {
    // We have already resolved these attributes
    return false;
    return ::tensorflow::Status::OK();
  }

  CHECK_EQ(op->inputs.size(), 4);
  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.has_shape()) {
    // We require the dimensionality of the input to pad the indices
    return false;
    return ::tensorflow::Status::OK();
  }

  auto& start_array = model->GetArray(op->inputs[1]);
  if (!start_array.has_shape()) return false;
  if (!start_array.has_shape()) return ::tensorflow::Status::OK();
  if (toco::RequiredBufferSizeForShape(start_array.shape()) > 4) {
    // Only 1-4D arrays are supported for now.
    return false;
    return ::tensorflow::Status::OK();
  }

  auto& stop_array = model->GetArray(op->inputs[2]);
  if (!stop_array.has_shape()) return false;
  if (!stop_array.has_shape()) return ::tensorflow::Status::OK();

  auto& stride_array = model->GetArray(op->inputs[3]);
  if (!stride_array.has_shape()) return false;
  if (!stride_array.has_shape()) return ::tensorflow::Status::OK();

  if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
  if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
  if (!IsConstantParameterArray(*model, op->inputs[3])) return false;
  if (!IsConstantParameterArray(*model, op->inputs[1]))
    return ::tensorflow::Status::OK();
  if (!IsConstantParameterArray(*model, op->inputs[2]))
    return ::tensorflow::Status::OK();
  if (!IsConstantParameterArray(*model, op->inputs[3]))
    return ::tensorflow::Status::OK();

  int num_input_axes = input_array.shape().dimensions_count();
  int start_indices_size = start_array.shape().dims(0);
@ -112,6 +119,7 @@ bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) {
  op->stop_indices = stop_array.GetBuffer<ArrayDataType::kInt32>().data;
  op->strides = stride_array.GetBuffer<ArrayDataType::kInt32>().data;

  return true;
  *modified = true;
  return ::tensorflow::Status::OK();
}
} // namespace toco

@ -25,12 +25,15 @@ limitations under the License.

namespace toco {

bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveTensorFlowConcat::Run(Model* model,
                                                  std::size_t op_index,
                                                  bool* modified) {
  *modified = false;
  auto concat_it = model->operators.begin() + op_index;
  const auto* tf_concat_op = concat_it->get();
  if (tf_concat_op->type != OperatorType::kConcat &&
      tf_concat_op->type != OperatorType::kConcatV2) {
    return false;
    return ::tensorflow::Status::OK();
  }

  CHECK_GE(tf_concat_op->inputs.size(), 2);
@ -54,7 +57,7 @@ bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) {
  if (!axis_array.buffer) {
    AddMessageF("Waiting for the axis of %s to be resolved to a constant",
                LogName(*tf_concat_op));
    return false;
    return ::tensorflow::Status::OK();
  }

  CHECK(axis_array.data_type == ArrayDataType::kInt32);
@ -79,7 +82,8 @@ bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) {
  }
  // Remove the TensorFlowConcat op
  model->operators.erase(concat_it);
  return true;
  *modified = true;
  return ::tensorflow::Status::OK();
}

} // namespace toco

@ -55,10 +55,13 @@ TransposeOperator* FindTransposeOpWithInput(const Model& model,

} // namespace

bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveTensorFlowMatMul::Run(Model* model,
                                                  std::size_t op_index,
                                                  bool* modified) {
  *modified = false;
  auto matmul_it = model->operators.begin() + op_index;
  if (matmul_it->get()->type != OperatorType::kMatMul) {
    return false;
    return ::tensorflow::Status::OK();
  }
  const auto* matmul_op =
      static_cast<const TensorFlowMatMulOperator*>(matmul_it->get());
@ -73,7 +76,7 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
        "Not replacing %s by a FullyConnected operator, because it has "
        "the transpose_a attribute",
        LogName(*matmul_op));
    return false;
    return ::tensorflow::Status::OK();
  }

  // Reorder the axes on the second input. TensorFlow uses row-major ordering
@ -198,7 +201,8 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {

  // erase the MatMul operator
  model->operators.erase(matmul_it);
  return true;
  *modified = true;
  return ::tensorflow::Status::OK();
}

} // namespace toco

@ -24,11 +24,14 @@ limitations under the License.

namespace toco {

bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveTensorFlowMerge::Run(Model* model,
                                                 std::size_t op_index,
                                                 bool* modified) {
  *modified = false;
  const auto merge_it = model->operators.begin() + op_index;
  const auto* merge_op = merge_it->get();
  if (merge_op->type != OperatorType::kMerge) {
    return false;
    return ::tensorflow::Status::OK();
  }

  // We need to yield until this Merge node has only 1 input, which will mean
@ -37,7 +40,7 @@ bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) {
  // non-selected inputs, so that at some point there will be only 1 input left.
  if (merge_op->inputs.size() > 1) {
    AddMessageF("Waiting for %s to be resolved", LogName(*merge_op));
    return false;
    return ::tensorflow::Status::OK();
  }

  // Now that the merge node has 1 input exactly, it is the same as an Identity
@ -57,7 +60,8 @@ bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) {
  AddMessageF("Removing already-resolved %s", LogName(*merge_op));
  model->EraseArray(merge_op->outputs[0]);
  model->operators.erase(merge_it);
  return true;
  *modified = true;
  return ::tensorflow::Status::OK();
}

} // namespace toco

@ -24,11 +24,14 @@ limitations under the License.

namespace toco {

bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveTensorFlowSwitch::Run(Model* model,
                                                  std::size_t op_index,
                                                  bool* modified) {
  *modified = false;
  const auto switch_it = model->operators.begin() + op_index;
  const auto* switch_op = switch_it->get();
  if (switch_op->type != OperatorType::kSwitch) {
    return false;
    return ::tensorflow::Status::OK();
  }

  CHECK_EQ(switch_op->inputs.size(), 2);
@ -40,7 +43,7 @@ bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) {
    AddMessageF(
        "Waiting for the boolean predicate of %s to be resolved to a constant",
        LogName(*switch_op));
    return false;
    return ::tensorflow::Status::OK();
  }

  // The predicate should be boolean, and should consist of a single value.
@ -119,7 +122,8 @@ bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) {
  // Remove the switch node itself.
  AddMessageF("Removing already-resolved %s", LogName(*switch_op));
  model->operators.erase(switch_it);
  return true;
  *modified = true;
  return ::tensorflow::Status::OK();
}

} // namespace toco

@ -24,19 +24,24 @@ limitations under the License.

namespace toco {

bool ResolveTransposeAttributes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveTransposeAttributes::Run(Model* model,
                                                     std::size_t op_index,
                                                     bool* modified) {
  *modified = false;
  const auto op_it = model->operators.begin() + op_index;
  if (op_it->get()->type != OperatorType::kTranspose) return false;
  if (op_it->get()->type != OperatorType::kTranspose)
    return ::tensorflow::Status::OK();

  auto* op = static_cast<TransposeOperator*>(op_it->get());
  if (!op->perm.empty()) return false;
  if (!op->perm.empty()) return ::tensorflow::Status::OK();

  CHECK_EQ(op->inputs.size(), 2);
  if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
  if (!IsConstantParameterArray(*model, op->inputs[1]))
    return ::tensorflow::Status::OK();

  // Handling perm.
  const auto& perm_array = model->GetArray(op->inputs[1]);
  if (!perm_array.has_shape()) return false;
  if (!perm_array.has_shape()) return ::tensorflow::Status::OK();

  const std::vector<int>& perm_dims = perm_array.shape().dims();
  CHECK_EQ(perm_dims.size(), 1);
@ -47,7 +52,8 @@ bool ResolveTransposeAttributes::Run(Model* model, std::size_t op_index) {
    op->perm.push_back(perm_buffer[i]);
  }

  return true;
  *modified = true;
  return ::tensorflow::Status::OK();
}

} // namespace toco

@ -24,15 +24,17 @@ limitations under the License.

namespace toco {

bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ShuffleFCWeights::Run(Model* model, std::size_t op_index,
                                           bool* modified) {
  *modified = false;
  Operator* op = model->operators[op_index].get();
  if (op->type != OperatorType::kFullyConnected) {
    return false;
    return ::tensorflow::Status::OK();
  }
  FullyConnectedOperator* fc_op = static_cast<FullyConnectedOperator*>(op);
  // Exit if this FC op already has shuffled weights
  if (fc_op->weights_format != FullyConnectedWeightsFormat::kDefault) {
    return false;
    return ::tensorflow::Status::OK();
  }
  const Array& input_array = model->GetArray(fc_op->inputs[0]);
  const string& weights_name = fc_op->inputs[1];
@ -46,11 +48,11 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
      output_array.data_type != ArrayDataType::kInt16 ||
      !input_array.quantization_params || !weights_array.quantization_params ||
      !output_array.quantization_params) {
    return false;
    return ::tensorflow::Status::OK();
  }
  // Exit if the shapes aren't known
  if (!input_array.has_shape() || !weights_array.has_shape()) {
    return false;
    return ::tensorflow::Status::OK();
  }
  // Exit if, based on the known shapes, this FC op is not a GEMV.
  // The shuffling of FC weights is only useful to enable fast GEMV paths.
@ -64,7 +66,7 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
          "the input shape is not 1D or 2D (possibly with additional inner "
          "dimensions of size 1)",
          LogName(*op));
      return false;
      return ::tensorflow::Status::OK();
    }
  }
  if (input_shape.dims(0) != 1 && input_shape.dims(0) != 4) {
@ -73,7 +75,7 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
        "the input shape's leading dimension, i.e. the 'batch size', is not "
        "equal to 1 or 4",
        LogName(*op));
    return false;
    return ::tensorflow::Status::OK();
  }
  // Exit if the weights shape isn't an integral multiple of the shuffled
  // block shape, 4x16. We don't want to have to write code dealing with
@ -88,7 +90,7 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
  // two.
  const Shape& weights_shape = weights_array.shape();
  if (weights_shape.dimensions_count() != 2) {
    return false;
    return ::tensorflow::Status::OK();
  }
  const int rows = weights_shape.dims(0);
  const int cols = weights_shape.dims(1);
@ -97,11 +99,11 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
        "Not applying experimental shuffling to the weights of %s because its "
        "shape isn't a multiple of the shuffling block shape, 4x16",
        LogName(*op));
    return false;
    return ::tensorflow::Status::OK();
  }
  // Exit if the weights aren't already a constant array.
  if (!weights_array.buffer) {
    return false;
    return ::tensorflow::Status::OK();
  }
  // Exit if the weights are used by more than one op.
  if (CountOpsWithInput(*model, weights_name) != 1) {
@ -109,7 +111,7 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
        "Not applying experimental shuffling to the weights of %s because that "
        "array is consumed by other operators",
        LogName(*op));
    return false;
    return ::tensorflow::Status::OK();
  }
  // Compute the shuffled weights
  auto& weights_data =
@ -152,7 +154,8 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
  shuffled_input_workspace_array.GetOrCreateQuantizationParams() =
      input_array.GetQuantizationParams();

  return true;
  *modified = true;
  return ::tensorflow::Status::OK();
}

} // namespace toco

@ -166,7 +166,10 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis0) {
  GraphTransformationsSet graph_transformation_set;
  graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
  EXPECT_THAT(model.GetArrayMap().size(), 5);
  (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
  bool modified;
  ASSERT_TRUE((*graph_transformation_set.begin())
                  ->Run(&model, /*op_index=*/0, &modified)
                  .ok());
  EXPECT_THAT(model.GetArrayMap().size(), 1);

  auto& concatenated_array = (*model.GetArrayMap().begin()).second;
@ -185,7 +188,10 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis1) {
  GraphTransformationsSet graph_transformation_set;
  graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
  EXPECT_THAT(model.GetArrayMap().size(), 5);
  (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
  bool modified;
  ASSERT_TRUE((*graph_transformation_set.begin())
                  ->Run(&model, /*op_index=*/0, &modified)
                  .ok());
  EXPECT_THAT(model.GetArrayMap().size(), 1);

  auto& concatenated_array = (*model.GetArrayMap().begin()).second;
@ -204,7 +210,10 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis2) {
  GraphTransformationsSet graph_transformation_set;
  graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
  EXPECT_THAT(model.GetArrayMap().size(), 5);
  (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
  bool modified;
  ASSERT_TRUE((*graph_transformation_set.begin())
                  ->Run(&model, /*op_index=*/0, &modified)
                  .ok());
  EXPECT_THAT(model.GetArrayMap().size(), 1);

  auto& concatenated_array = (*model.GetArrayMap().begin()).second;

@ -50,7 +50,8 @@ void RunResolveSum(const std::vector<float>& input,
  sum_op->inputs = {"input0", "input1"};
  sum_op->outputs = {"output"};
  model.operators.push_back(std::move(sum_op));
  ResolveConstantUnaryOperator().Run(&model, 0);
  bool modified;
  ASSERT_TRUE(ResolveConstantUnaryOperator().Run(&model, 0, &modified).ok());
  EXPECT_EQ(model.GetArray("output").GetBuffer<ArrayDataType::kFloat>().data,
            expected_output);
  EXPECT_EQ(model.GetArray("output").shape().dims(), output_shape);

@ -25,13 +25,16 @@ limitations under the License.

namespace toco {

bool UnfuseActivationFunctions::Run(Model* model, std::size_t op_index) {
::tensorflow::Status UnfuseActivationFunctions::Run(Model* model,
                                                    std::size_t op_index,
                                                    bool* modified) {
  *modified = false;
  const auto it = model->operators.begin() + op_index;
  auto* op = it->get();

  // If a conv operation has an im2col array, yield: it should be dropped first.
  if ((op->type == OperatorType::kConv) && (op->outputs.size() == 2)) {
    return false;
    return ::tensorflow::Status::OK();
  }

  Operator* ac_op = nullptr;
@ -46,7 +49,7 @@ bool UnfuseActivationFunctions::Run(Model* model, std::size_t op_index) {
      ac_op = new Relu1Operator;
      break;
    default:
      return false;
      return ::tensorflow::Status::OK();
  }

  // At this point we know that the op has a fused activation function. At the
@ -74,7 +77,8 @@ bool UnfuseActivationFunctions::Run(Model* model, std::size_t op_index) {

  ac_op->inputs = {tmp_array_name};
  op->outputs = {tmp_array_name};
  return true;
  *modified = true;
  return ::tensorflow::Status::OK();
}

} // namespace toco

@ -22,7 +22,10 @@ limitations under the License.

namespace toco {

bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
::tensorflow::Status UnpartitionEmbeddingLookup::Run(Model* model,
                                                     std::size_t op_index,
                                                     bool* modified) {
  *modified = false;
  // Collapses a partitioned tf.nn.embedding_lookup back into a single Gather.
  // https://www.tensorflow.org/api_docs/python/tf/nn/embedding_lookup
  // This transform attempts to identify the len(params) > 1 case and collapse
@ -47,7 +50,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
  // First look for the final DynamicStitch.
  auto op_it = model->operators.begin() + op_index;
  if (op_it->get()->type != OperatorType::kDynamicStitch) {
    return false;
    return ::tensorflow::Status::OK();
  }
  auto* stitch_op = static_cast<DynamicStitchOperator*>(op_it->get());

@ -72,7 +75,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
            "Skipping because indices input %s into "
            "%s is unexpected",
            LogName(*op), LogName(*stitch_op));
        return false;
        return ::tensorflow::Status::OK();
      }
      if (!indices_partition_op) {
        indices_partition_op = static_cast<DynamicPartitionOperator*>(op);
@ -83,7 +86,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
            "Skipping because indices input %s into "
            "%s is from a different source op than others",
            LogName(*op), LogName(*stitch_op));
        return false;
        return ::tensorflow::Status::OK();
      }
    }
  }
@ -92,12 +95,12 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
  // The data for the indices must be a constant range of the array shape.
  if (!IsConstantParameterArray(*model, indices_partition_op->inputs[0])) {
    AddMessageF("Skipping because indices partition data is non-constant");
    return false;
    return ::tensorflow::Status::OK();
  }
  auto& indices_data_array = model->GetArray(indices_partition_op->inputs[0]);
  if (indices_data_array.data_type == ArrayDataType::kNone) {
    // Yield until data types are propagated.
    return false;
    return ::tensorflow::Status::OK();
  }
  CHECK(indices_data_array.data_type == ArrayDataType::kInt32)
      << "Indices partition inputs must be int32";
@ -117,7 +120,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
          "Skipping because data input %s into %s "
          "is unexpected",
          LogName(*op), LogName(*stitch_op));
      return false;
      return ::tensorflow::Status::OK();
    }
    gather_ops.push_back(static_cast<GatherOperator*>(op));
  }
@ -132,7 +135,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
            "Skipping because data input %s into "
            "%s is unexpected",
            LogName(*op), LogName(*gather_op));
        return false;
        return ::tensorflow::Status::OK();
      }
      if (!data_partition_op) {
        data_partition_op = static_cast<DynamicPartitionOperator*>(op);
@ -143,7 +146,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
            "Skipping because data input %s into "
            "%s is from a different source op than others",
            LogName(*op), LogName(*gather_op));
        return false;
        return ::tensorflow::Status::OK();
      }
    }
  }
@ -236,7 +239,8 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
  DeleteOpAndArraysIfUnused(model, indices_partition_op);
  DeleteOpAndArraysIfUnused(model, data_partition_op);
  DeleteOpAndArraysIfUnused(model, stitch_op);
  return true;
  *modified = true;
  return ::tensorflow::Status::OK();
}

} // namespace toco

@ -36,10 +36,12 @@ namespace toco {
//   slice_c = tf.matmul(slice_a, slice_b)
//   result_slices[bat] = slice_c
// result = tf.stack(result_slices)
bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) {
::tensorflow::Status UnrollBatchMatMul::Run(Model* model, std::size_t op_index,
                                            bool* modified) {
  *modified = false;
  auto batch_op_it = model->operators.begin() + op_index;
  if (batch_op_it->get()->type != OperatorType::kBatchMatMul) {
    return false;
    return ::tensorflow::Status::OK();
  }
  const auto* batch_op =
      static_cast<const BatchMatMulOperator*>(batch_op_it->get());
@ -47,7 +49,8 @@ bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) {
  // We must have the shape of at least one input to know our batch size.
  const auto& input_array_a = model->GetArray(batch_op->inputs[0]);
  const auto& input_array_b = model->GetArray(batch_op->inputs[1]);
  if (!input_array_a.has_shape() || !input_array_b.has_shape()) return false;
  if (!input_array_a.has_shape() || !input_array_b.has_shape())
    return ::tensorflow::Status::OK();

  // We only support the rank 3 case. If you are batching on rank > 3 you'll
  // have to figure that out.
@ -66,7 +69,8 @@ bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) {
    batch_op_it = matmul_op_it + 1;
    CHECK_EQ(batch_op_it->get(), batch_op);
    model->operators.erase(batch_op_it);
    return true;
    *modified = true;
    return ::tensorflow::Status::OK();
  }
  CHECK_EQ(input_array_a.shape().dimensions_count(), 3)
      << "Input arrays must have rank 3";
@ -167,7 +171,8 @@ bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) {
  CHECK(batch_op_it != model->operators.end());
  CHECK(batch_op_it->get() == batch_op);
  model->operators.erase(batch_op_it);
  return true;
  *modified = true;
  return ::tensorflow::Status::OK();
}

} // namespace toco
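
Every file touched by this commit applies the same mechanical rewrite, so for reference, here is a minimal sketch of the new transformation contract. This is an illustration, not code from the commit: ExampleTransformation and OperatorType::kExample are hypothetical names standing in for any of the passes above.

::tensorflow::Status ExampleTransformation::Run(Model* model,
                                                std::size_t op_index,
                                                bool* modified) {
  // "Did we change the graph?" is now an out-parameter rather than the
  // return value. Defaulting it to false up front lets every early exit
  // below read as "yield without modifying the graph".
  *modified = false;
  auto it = model->operators.begin() + op_index;
  if (it->get()->type != OperatorType::kExample) {
    // Not the op this pass handles: succeed, reporting no change.
    return ::tensorflow::Status::OK();
  }
  // ... rewrite the graph here ...
  *modified = true;
  return ::tensorflow::Status::OK();
}

As the updated tests above show, callers now check the returned Status and read the flag separately, e.g.: bool modified; ASSERT_TRUE(transformation->Run(&model, op_index, &modified).ok());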