Return ::tensorflow::Status in Toco Graph Transformations.

PiperOrigin-RevId: 216392908
This commit is contained in:
Yu-Cheng Ling 2018-10-09 11:38:15 -07:00 committed by TensorFlower Gardener
parent 931353c5f7
commit 12e164d1e7
94 changed files with 1003 additions and 617 deletions

View File

@ -25,10 +25,13 @@ limitations under the License.
namespace toco {
bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ConvertExpandDimsToReshape::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto expand_it = model->operators.begin() + op_index;
if (expand_it->get()->type != OperatorType::kExpandDims) {
return false;
return ::tensorflow::Status::OK();
}
ExpandDimsOperator* expand_op =
static_cast<ExpandDimsOperator*>(expand_it->get());
@ -38,18 +41,18 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) {
const auto& input_array = model->GetArray(expand_op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
return false;
return ::tensorflow::Status::OK();
}
const auto& axis_array = model->GetArray(expand_op->inputs[1]);
if (!axis_array.has_shape()) {
// Yield until input axis array shape has been resolved.
return false;
return ::tensorflow::Status::OK();
}
CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1);
if (!axis_array.buffer) {
// Yield until the input axis array is constant
return false;
return ::tensorflow::Status::OK();
}
int axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
std::vector<int> reshape_dims(input_array.shape().dims());
@ -90,7 +93,8 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) {
CHECK_EQ(expand_it->get(), expand_op);
model->operators.erase(expand_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -24,29 +24,32 @@ limitations under the License.
namespace toco {
bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ConvertPureConvToDepthwise::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto conv_it = model->operators.begin() + op_index;
if (conv_it->get()->type != OperatorType::kConv) {
return false;
return ::tensorflow::Status::OK();
}
const auto* conv_op = static_cast<ConvOperator*>(conv_it->get());
if (conv_op->stride_width != conv_op->stride_height) {
return false;
return ::tensorflow::Status::OK();
}
if ((conv_op->dilation_width_factor != 1) ||
(conv_op->dilation_height_factor != 1)) {
// Depthwise conv does not support dilation
return false;
return ::tensorflow::Status::OK();
}
auto& input_array = model->GetArray(conv_op->inputs[0]);
if (!input_array.has_shape()) {
// Shapes not propagated yet
return false;
return ::tensorflow::Status::OK();
}
if (input_array.shape().dims(3) != 1) {
// Not a pure convolution: Conv does accumulation across the depth
// dimension.
return false;
return ::tensorflow::Status::OK();
}
const auto& weights_name = conv_op->inputs[1];
@ -56,15 +59,15 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) {
"Not changing %s to DepthwiseConv because the weights is consumed by "
"another op.",
LogName(*conv_op));
return false;
return ::tensorflow::Status::OK();
}
auto& weights_array = model->GetArray(weights_name);
if (!weights_array.buffer) {
// Yield until the weights are resolved as a constant array.
return false;
return ::tensorflow::Status::OK();
}
if (weights_array.data_type != ArrayDataType::kFloat) {
return false;
return ::tensorflow::Status::OK();
}
// At this point we know we have a pure conv. Rewrite it as DepthwiseConv.
AddMessageF(
@ -112,7 +115,8 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) {
}
*weights_array.mutable_shape()->mutable_dims() = {1, width, height, depth};
weights_buffer.data = depthwise_conv_weights_data;
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -86,9 +86,12 @@ TransposeOperator* CreateTransposeFromReorderAxes(
// Converts ReorderAxes into Transpose and Reshape which are compatible with the
// TFLite interpreter.
bool ConvertReorderAxes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ConvertReorderAxes::Run(Model* model, std::size_t op_index,
bool* modified) {
*modified = false;
auto reorder_it = model->operators.begin() + op_index;
if (reorder_it->get()->type != OperatorType::kReorderAxes) return false;
if (reorder_it->get()->type != OperatorType::kReorderAxes)
return ::tensorflow::Status::OK();
auto* reorder_op = static_cast<ReorderAxesOperator*>(reorder_it->get());
CHECK_EQ(reorder_op->inputs.size(), 1);
@ -113,8 +116,9 @@ bool ConvertReorderAxes::Run(Model* model, std::size_t op_index) {
// Yield if input array contains constants or if output array size has not
// been adjusted to reflect the permutations in ReorderAxes. ReorderAxes will
// be merged into a constant array when possible.
if (IsConstantParameterArray(*model, constant_input_array_name)) return false;
if (!output_array.has_shape()) return false;
if (IsConstantParameterArray(*model, constant_input_array_name))
return ::tensorflow::Status::OK();
if (!output_array.has_shape()) return ::tensorflow::Status::OK();
const auto input_axes_order = reorder_op->input_axes_order;
const auto output_axes_order = reorder_op->output_axes_order;
@ -143,7 +147,8 @@ bool ConvertReorderAxes::Run(Model* model, std::size_t op_index) {
CHECK_EQ(reorder_it->get(), reorder_op);
model->operators.erase(reorder_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -30,10 +30,13 @@ namespace toco {
// means that the data layout will never change with this op, just the shape.
// By converting these to reshapes once we have run shape propagation we allow
// standard reshape optimization transforms to do their magic.
bool ConvertSqueezeToReshape::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ConvertSqueezeToReshape::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto squeeze_it = model->operators.begin() + op_index;
if (squeeze_it->get()->type != OperatorType::kSqueeze) {
return false;
return ::tensorflow::Status::OK();
}
auto squeeze_op = static_cast<SqueezeOperator*>(squeeze_it->get());
CHECK_EQ(squeeze_op->inputs.size(), 1);
@ -42,16 +45,16 @@ bool ConvertSqueezeToReshape::Run(Model* model, std::size_t op_index) {
const auto& input_array = model->GetArray(squeeze_op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
return false;
return ::tensorflow::Status::OK();
}
if (input_array.shape().dimensions_count() == 0) {
// Input array cannot be 0-D.
return false;
return ::tensorflow::Status::OK();
}
if (!model->HasArray(squeeze_op->outputs[0]) ||
!model->GetArray(squeeze_op->outputs[0]).has_shape()) {
// Yield until shape propagation has set the output shape for us.
return false;
return ::tensorflow::Status::OK();
}
// We use the output shape that has been calculated by shape propagation.
@ -59,7 +62,7 @@ bool ConvertSqueezeToReshape::Run(Model* model, std::size_t op_index) {
// Empty shapes will not work as empty data arrays.
if (output_shape.dimensions_count() == 0) {
return false;
return ::tensorflow::Status::OK();
}
auto* reshape_op = new TensorFlowReshapeOperator;
@ -79,7 +82,8 @@ bool ConvertSqueezeToReshape::Run(Model* model, std::size_t op_index) {
CHECK_EQ(squeeze_it->get(), squeeze_op);
model->operators.erase(squeeze_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -20,10 +20,13 @@ namespace toco {
// This pass will convert an AddN operator with only 2 inputs into a regular Add
// operator, to which more optimizations may apply.
bool ConvertTrivialAddNToAdd::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ConvertTrivialAddNToAdd::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto addn_it = model->operators.begin() + op_index;
if (addn_it->get()->type != OperatorType::kAddN) {
return false;
return ::tensorflow::Status::OK();
}
AddNOperator* addn_op = static_cast<AddNOperator*>(addn_it->get());
CHECK_GE(addn_op->inputs.size(), 2);
@ -31,7 +34,7 @@ bool ConvertTrivialAddNToAdd::Run(Model* model, std::size_t op_index) {
// We only reduce AddN with N=2 to a regular Add.
if (addn_op->inputs.size() != 2) {
return false;
return ::tensorflow::Status::OK();
}
// Copy inputs & outputs to regular Add.
@ -45,7 +48,8 @@ bool ConvertTrivialAddNToAdd::Run(Model* model, std::size_t op_index) {
addn_it = add_it + 1;
CHECK_EQ(addn_it->get(), addn_op);
model->operators.erase(addn_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -25,27 +25,30 @@ limitations under the License.
namespace toco {
bool ConvertTrivialPackToReshape::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ConvertTrivialPackToReshape::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto pack_it = model->operators.begin() + op_index;
if (pack_it->get()->type != OperatorType::kPack) {
return false;
return ::tensorflow::Status::OK();
}
auto* pack_op = static_cast<PackOperator*>(pack_it->get());
if (pack_op->inputs.size() > 1) {
// Not trivial.
return false;
return ::tensorflow::Status::OK();
}
CHECK_EQ(pack_op->outputs.size(), 1);
const auto& input_array = model->GetArray(pack_op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
return false;
return ::tensorflow::Status::OK();
}
if (input_array.shape().dimensions_count() == 0) {
// Input array cannot be 0-D.
// (Unsure if this is TF behavior, but was required to get a test to pass.)
return false;
return ::tensorflow::Status::OK();
}
AddMessageF("Converting trivial %s to a reshape", LogName(*pack_op));
@ -75,7 +78,8 @@ bool ConvertTrivialPackToReshape::Run(Model* model, std::size_t op_index) {
CHECK_EQ(pack_it->get(), pack_op);
model->operators.erase(pack_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -21,10 +21,13 @@ limitations under the License.
namespace toco {
bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ConvertTrivialTileToConcat::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto tile_it = model->operators.begin() + op_index;
if (tile_it->get()->type != OperatorType::kTile) {
return false;
return ::tensorflow::Status::OK();
}
auto* tile_op = static_cast<TransposeOperator*>(tile_it->get());
@ -34,13 +37,13 @@ bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) {
if (!input_array.has_shape() || !multiples_array.has_shape() ||
!output_array.has_shape()) {
// Yield until PropagateFixedSizes has been run on this op.
return false;
return ::tensorflow::Status::OK();
}
// Note: We can assume we have error checked inputs in PropagateFixedSizes.
if (!multiples_array.buffer) {
// Yield until the multiples is constant.
return false;
return ::tensorflow::Status::OK();
}
std::vector<int32> const& multiples =
multiples_array.GetBuffer<ArrayDataType::kInt32>().data;
@ -59,7 +62,7 @@ bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) {
// The tile is non-trivial. Good luck.
AddMessageF("Tile %s is non-trivial (has more than one multiply dimension)",
LogName(*tile_op));
return false;
return ::tensorflow::Status::OK();
}
// The tile is like a concat.
@ -88,7 +91,8 @@ bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) {
CHECK_EQ(tile_it->get(), tile_op);
model->operators.erase(tile_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -48,10 +48,13 @@ bool TransposeAffectsMemoryOrder(std::vector<int> perm,
} // namespace
bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ConvertTrivialTransposeToReshape::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto transpose_it = model->operators.begin() + op_index;
if (transpose_it->get()->type != OperatorType::kTranspose) {
return false;
return ::tensorflow::Status::OK();
}
TransposeOperator* transpose_op =
static_cast<TransposeOperator*>(transpose_it->get());
@ -60,14 +63,14 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
const auto& output_array = model->GetArray(transpose_op->outputs[0]);
if (!input_array.has_shape() || !output_array.has_shape()) {
// Yield until PropagateFixedSizes has been run on this op.
return false;
return ::tensorflow::Status::OK();
}
// Note: We can assume we have error checked inputs in PropagateFixedSizes.
// Check that the permutation has propogated.
std::vector<int> const& perm = transpose_op->perm;
if (perm.empty()) {
return false;
return ::tensorflow::Status::OK();
}
// This transpose is trivial if non-unitary dimensions remain in the same
@ -76,7 +79,7 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
std::vector<int> const& output_dims = output_array.shape().dims();
if (TransposeAffectsMemoryOrder(perm, input_dims)) {
return false;
return ::tensorflow::Status::OK();
}
// This transpose is trivial. Replace it with a Reshape op.
@ -109,7 +112,8 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
CHECK_EQ(transpose_it->get(), transpose_op);
model->operators.erase(transpose_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -73,18 +73,22 @@ bool ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) {
return true;
}
bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) {
::tensorflow::Status CreateIm2colArrays::Run(Model* model, std::size_t op_index,
bool* modified) {
*modified = false;
auto it = model->operators.begin() + op_index;
auto* op = it->get();
switch (op->type) {
case OperatorType::kConv:
return ProcessConvOperator(model, static_cast<ConvOperator*>(op));
*modified = ProcessConvOperator(model, static_cast<ConvOperator*>(op));
return ::tensorflow::Status::OK();
case OperatorType::kTransposeConv:
return ProcessTransposeConvOperator(
*modified = ProcessTransposeConvOperator(
model, static_cast<TransposeConvOperator*>(op));
return ::tensorflow::Status::OK();
default:
return false;
return ::tensorflow::Status::OK();
}
}

View File

@ -186,24 +186,27 @@ bool DequantizeArray(const string& array_name,
} // namespace
bool Dequantize::Run(Model* model, std::size_t op_index) {
::tensorflow::Status Dequantize::Run(Model* model, std::size_t op_index,
bool* modified) {
*modified = false;
const auto op_it = model->operators.begin() + op_index;
auto* op = op_it->get();
if (op->type == OperatorType::kDequantize) {
auto& input_array = model->GetArray(op->inputs[0]);
if (input_array.data_type == ArrayDataType::kFloat) {
return false;
return ::tensorflow::Status::OK();
}
if (input_array.final_data_type != ArrayDataType::kFloat) {
return false;
return ::tensorflow::Status::OK();
}
input_array.data_type = ArrayDataType::kFloat;
input_array.quantization_params = nullptr;
auto& output_array = model->GetArray(op->outputs[0]);
output_array.data_type = ArrayDataType::kFloat;
output_array.quantization_params = nullptr;
return RemoveTrivialPassthroughOp(this, model, op_index);
*modified = RemoveTrivialPassthroughOp(this, model, op_index);
return ::tensorflow::Status::OK();
}
std::vector<string> arrays;
@ -220,7 +223,8 @@ bool Dequantize::Run(Model* model, std::size_t op_index) {
}
}
return changed;
*modified = changed;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -25,21 +25,23 @@ limitations under the License.
namespace toco {
bool DropFakeQuant::Run(Model* model, std::size_t op_index) {
::tensorflow::Status DropFakeQuant::Run(Model* model, std::size_t op_index,
bool* modified) {
*modified = false;
const auto fakequant_it = model->operators.begin() + op_index;
auto* fakequant_base_op = fakequant_it->get();
if (fakequant_base_op->type != OperatorType::kFakeQuant) {
return false;
return ::tensorflow::Status::OK();
}
auto* fakequant_op = static_cast<FakeQuantOperator*>(fakequant_base_op);
if (!fakequant_op->minmax) {
return false;
return ::tensorflow::Status::OK();
}
const auto& output_array = model->GetArray(fakequant_op->outputs[0]);
if (!output_array.minmax) {
return false;
return ::tensorflow::Status::OK();
}
// Drop min/max inputs
@ -50,7 +52,8 @@ bool DropFakeQuant::Run(Model* model, std::size_t op_index) {
}
fakequant_op->inputs.resize(1);
return RemoveTrivialPassthroughOp(this, model, op_index);
*modified = RemoveTrivialPassthroughOp(this, model, op_index);
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -19,15 +19,17 @@ limitations under the License.
namespace toco {
bool DropIm2colArrays::Run(Model* model, std::size_t op_index) {
::tensorflow::Status DropIm2colArrays::Run(Model* model, std::size_t op_index,
bool* modified) {
*modified = false;
auto conv_it = model->operators.begin() + op_index;
if (conv_it->get()->type != OperatorType::kConv) {
return false;
return ::tensorflow::Status::OK();
}
auto* conv_op = static_cast<ConvOperator*>(conv_it->get());
if (conv_op->outputs.size() < 2) {
// Conv op does not have im2col.
return false;
return ::tensorflow::Status::OK();
}
// Drop the im2col array.
@ -36,7 +38,8 @@ bool DropIm2colArrays::Run(Model* model, std::size_t op_index) {
conv_op->outputs.resize(1);
AddMessageF("Dropped an im2col array for %s", LogName(*conv_op));
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -62,17 +62,20 @@ bool ProcessLinearOperator(Model* model, Operator* op) {
}
} // namespace
bool EnsureBiasVectors::Run(Model* model, std::size_t op_index) {
::tensorflow::Status EnsureBiasVectors::Run(Model* model, std::size_t op_index,
bool* modified) {
*modified = false;
auto* op = model->operators[op_index].get();
if (op->type == OperatorType::kConv ||
op->type == OperatorType::kDepthwiseConv ||
op->type == OperatorType::kFullyConnected) {
if (ProcessLinearOperator(model, op)) {
AddMessageF("Added bias vector to %s as %s", LogName(*op), op->inputs[2]);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
}
return false;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -108,8 +108,9 @@ namespace toco {
// we can foresee these 'fast int8 kernels' to remain important to have into
// the 2020s.
//
bool EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model,
std::size_t op_index) {
::tensorflow::Status EnsureUint8WeightsSafeForFastInt8Kernels::Run(
Model* model, std::size_t op_index, bool* modified) {
*modified = false;
const auto& op = *model->operators[op_index];
int weights_index = 0;
switch (op.type) {
@ -148,16 +149,16 @@ bool EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model,
// That's why at the moment we only handle operators that use a GEMM
// (Conv, fully-connected --- note that LSTM merely wraps a
// fully-connected operator).
return false;
return ::tensorflow::Status::OK();
}
const string& name = op.inputs[weights_index];
auto& array = model->GetArray(name);
if (!array.buffer) {
return false;
return ::tensorflow::Status::OK();
}
if (array.data_type != ArrayDataType::kUint8) {
return false;
return ::tensorflow::Status::OK();
}
auto& buffer_data = array.GetMutableBuffer<ArrayDataType::kUint8>().data;
@ -212,7 +213,8 @@ bool EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model,
AddMessageF("Tweaked weights values for %s", LogName(op));
}
return changed;
*modified = changed;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -25,27 +25,30 @@ limitations under the License.
namespace toco {
bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
::tensorflow::Status FuseActivationFunctions::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto ac_it = model->operators.begin() + op_index;
const auto* ac_op = ac_it->get();
if (ac_op->type != OperatorType::kRelu6 &&
ac_op->type != OperatorType::kRelu1 &&
ac_op->type != OperatorType::kRelu) {
return false;
return ::tensorflow::Status::OK();
}
// Find the op producing the array passed to this activation function
Operator* op = GetOpWithOutput(*model, ac_op->inputs[0]);
if (!op) return false;
if (!op) return ::tensorflow::Status::OK();
if (CountTrueOutputs(*model, *op) > 1) {
AddMessageF(
"Not fusing activation function %s into %s because it has more than "
"one consumed output",
LogName(*ac_op), LogName(*op));
return false;
return ::tensorflow::Status::OK();
}
CHECK_EQ(op->outputs[0], ac_op->inputs[0]);
@ -57,7 +60,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
"Not fusing activation function into %s because it is consumed by more "
"than 1 other operator",
LogName(*ac_op), LogName(*op));
return false;
return ::tensorflow::Status::OK();
}
if (!IsDiscardableArray(*model, op->outputs[0])) {
@ -65,7 +68,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
"Not fusing activation function %s into %s because output %s it is not "
"discardable",
LogName(*ac_op), LogName(*op), op->outputs[0]);
return false;
return ::tensorflow::Status::OK();
}
if (op->fused_activation_function != FusedActivationFunctionType::kNone) {
@ -73,7 +76,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
"Not fusing activation function %s into %s because it already has a "
"fused activation function",
LogName(*ac_op), LogName(*op));
return false;
return ::tensorflow::Status::OK();
}
if (!OperatorSupportsFusedActivation(op->type)) {
@ -81,7 +84,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
"Not fusing activation function %s because the %s op doesn't support "
"it",
LogName(*ac_op), LogName(*op));
return false;
return ::tensorflow::Status::OK();
}
AddMessageF("Fusing activation function %s into the preceding %s",
@ -98,7 +101,8 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
model->EraseArray(ac_op->inputs[0]);
op->outputs[0] = ac_op->outputs[0];
model->operators.erase(ac_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -150,14 +150,17 @@ void FuseMulOrDivParamsIntoFollowingAffine(Model* model, Operator* following_op,
} // namespace
bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
::tensorflow::Status FuseBinaryIntoFollowingAffine::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto binary_it = model->operators.begin() + op_index;
auto* binary_op = binary_it->get();
if (binary_op->type != OperatorType::kAdd &&
binary_op->type != OperatorType::kMul &&
binary_op->type != OperatorType::kSub &&
binary_op->type != OperatorType::kDiv) {
return false;
return ::tensorflow::Status::OK();
}
CHECK_EQ(binary_op->inputs.size(), 2);
@ -175,12 +178,12 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
};
if (!is_input_constant[0] && !is_input_constant[1]) {
// Neither input is constant, so nothing we can fuse into a constant.
return false;
return ::tensorflow::Status::OK();
}
if (is_input_constant[0] && is_input_constant[1]) {
// Both inputs are constants. That's a job for constants
// propagation, not for us to handle here.
return false;
return ::tensorflow::Status::OK();
}
const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
@ -192,7 +195,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
if (index_of_constant_input != 1) {
AddMessageF("Not fusing %s because the denominator is not constant",
LogName(*binary_op));
return false;
return ::tensorflow::Status::OK();
}
}
@ -204,7 +207,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
"Not fusing %s into the following affine op, because we only know "
"how to do so when the constant operand is a scalar",
LogName(*binary_op));
return false;
return ::tensorflow::Status::OK();
}
}
@ -212,7 +215,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
FusedActivationFunctionType::kNone) {
AddMessageF("Not fusing %s because it has a fused activation function",
LogName(*binary_op));
return false;
return ::tensorflow::Status::OK();
}
Operator* following_op = GetOpWithInput(*model, binary_op->outputs[0]);
@ -221,7 +224,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
AddMessageF(
"Not fusing %s because it is not consumed by exactly one other op",
LogName(*binary_op));
return false;
return ::tensorflow::Status::OK();
}
if (following_op->type != OperatorType::kConv &&
@ -231,14 +234,14 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
"Not fusing %s because the following %s is not of one of the supported "
"types",
LogName(*binary_op), LogName(*following_op));
return false;
return ::tensorflow::Status::OK();
}
if (following_op->inputs.size() < 3) {
AddMessageF(
"Not fusing %s because the following %s does not have a bias vector",
LogName(*following_op), LogName(*binary_op));
return false;
return ::tensorflow::Status::OK();
}
const auto& weights = model->GetArray(following_op->inputs[1]);
@ -248,7 +251,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
"Not fusing %s because the following %s has non-constant weights or "
"bias arrays",
LogName(*binary_op), LogName(*following_op));
return false;
return ::tensorflow::Status::OK();
}
// Try to fuse the binary params into the following op's params
@ -260,7 +263,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
AddMessageF(
"Not fusing %s because the following %s does not use VALID padding",
LogName(*binary_op), LogName(*following_op));
return false;
return ::tensorflow::Status::OK();
}
}
if (following_op->type == OperatorType::kDepthwiseConv) {
@ -269,7 +272,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
AddMessageF(
"Not fusing %s because the following %s does not use VALID padding",
LogName(*binary_op), LogName(*following_op));
return false;
return ::tensorflow::Status::OK();
}
}
FuseAddOrSubParamsIntoFollowingAffine(model, following_op, binary_op,
@ -294,7 +297,8 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
model->EraseArray(old_constant_param_name);
}
model->operators.erase(binary_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -188,14 +188,17 @@ void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
}
} // namespace
bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
::tensorflow::Status FuseBinaryIntoPrecedingAffine::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto binary_it = model->operators.begin() + op_index;
const auto* binary_op = binary_it->get();
if (binary_op->type != OperatorType::kAdd &&
binary_op->type != OperatorType::kMul &&
binary_op->type != OperatorType::kSub &&
binary_op->type != OperatorType::kDiv) {
return false;
return ::tensorflow::Status::OK();
}
CHECK_EQ(binary_op->inputs.size(), 2);
@ -213,12 +216,12 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
};
if (!is_input_constant[0] && !is_input_constant[1]) {
// Neither input is constant, so nothing we can fuse into a constant.
return false;
return ::tensorflow::Status::OK();
}
if (is_input_constant[0] && is_input_constant[1]) {
// Both inputs are constants. That's a job for constants
// propagation, not for us to handle here.
return false;
return ::tensorflow::Status::OK();
}
const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
@ -230,7 +233,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
if (index_of_constant_input != 1) {
AddMessageF("Not fusing %s because the denominator is not constant",
LogName(*binary_op));
return false;
return ::tensorflow::Status::OK();
}
}
@ -239,12 +242,12 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
if (!preceding_op) {
AddMessageF("Not fusing %s because it is not the output of another op",
LogName(*binary_op));
return false;
return ::tensorflow::Status::OK();
}
for (const string& output_array : model->flags.output_arrays()) {
if (preceding_op->outputs[0] == output_array) {
return false;
return ::tensorflow::Status::OK();
}
}
@ -255,7 +258,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
"Not fusing %s because the preceding %s is not of one of the supported "
"types",
LogName(*binary_op), LogName(*preceding_op));
return false;
return ::tensorflow::Status::OK();
}
if (preceding_op->fused_activation_function !=
@ -264,14 +267,14 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
"Not fusing %s because the preceding %s has a fused activation "
"function",
LogName(*binary_op), LogName(*preceding_op));
return false;
return ::tensorflow::Status::OK();
}
if (preceding_op->inputs.size() < 3) {
AddMessageF(
"Not fusing %s because the preceding %s does not have a bias vector",
LogName(*binary_op), LogName(*preceding_op));
return false;
return ::tensorflow::Status::OK();
}
const auto& weights_name = preceding_op->inputs[1];
@ -289,14 +292,14 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
"Not fusing %s because the preceding %s has a non-constant bias "
"array",
LogName(*binary_op), LogName(*preceding_op));
return false;
return ::tensorflow::Status::OK();
}
if (count_ops_consuming_bias > 1) {
AddMessageF(
"Not fusing %s because the bias of the preceding %s is consumed by "
"another op",
LogName(*binary_op), LogName(*preceding_op));
return false;
return ::tensorflow::Status::OK();
}
} else {
if (!weights.buffer || !bias.buffer) {
@ -304,14 +307,14 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
"Not fusing %s because the preceding %s has non-constant weights or "
"bias arrays",
LogName(*binary_op), LogName(*preceding_op));
return false;
return ::tensorflow::Status::OK();
}
if (count_ops_consuming_weights > 1 || count_ops_consuming_bias > 1) {
AddMessageF(
"Not fusing %s because the weights or bias of the preceding %s is "
"consumed by another op",
LogName(*binary_op), LogName(*preceding_op));
return false;
return ::tensorflow::Status::OK();
}
}
@ -323,7 +326,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
"Not fusing %s because the output of the preceding %s is consumed by "
"another op",
LogName(*binary_op), LogName(*preceding_op));
return false;
return ::tensorflow::Status::OK();
}
AddMessageF("Fusing %s into the preceding %s", LogName(*binary_op),
@ -352,7 +355,8 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
model->EraseArray(old_constant_param_name);
}
model->operators.erase(binary_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -51,19 +51,22 @@ bool IsBroadcastingOp(const Model& model, Operator* op) {
// Finds an operation that looks like a broadcast (concat of the same sources
// along the last dimension) and drops it by relying on the ability of certain
// binary ops to perform an implicit broadcast.
bool FuseBroadcastIntoFollowingBinary::Run(Model* model, std::size_t op_index) {
::tensorflow::Status FuseBroadcastIntoFollowingBinary::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto binary_it = model->operators.begin() + op_index;
auto* binary_op = binary_it->get();
// Test for binary ops of types that we know how to resolve
if (binary_op->inputs.size() != 2) {
return false;
return ::tensorflow::Status::OK();
}
if (binary_op->type != OperatorType::kAdd &&
binary_op->type != OperatorType::kMul &&
binary_op->type != OperatorType::kSub &&
binary_op->type != OperatorType::kDiv) {
return false;
return ::tensorflow::Status::OK();
}
// NOTE: either of these ops may be nullptr if the input array is constant.
@ -78,14 +81,14 @@ bool FuseBroadcastIntoFollowingBinary::Run(Model* model, std::size_t op_index) {
if (!is_op_0_broadcast && !is_op_1_broadcast) {
// Neither input is a broadcast-looking thing.
AddMessageF("Neither input looks broadcasty");
return false;
return ::tensorflow::Status::OK();
} else if (is_op_0_broadcast && is_op_1_broadcast) {
AddMessageF(
"Unable to fuse broadcast into %s as both inputs (%s, %s) are "
"broadcasts",
LogName(*binary_op), op[0] ? LogName(*op[0]) : "(?)",
op[1] ? LogName(*op[1]) : "(?)");
return false;
return ::tensorflow::Status::OK();
}
int broadcast_index = is_op_0_broadcast ? 0 : 1;
@ -96,7 +99,8 @@ bool FuseBroadcastIntoFollowingBinary::Run(Model* model, std::size_t op_index) {
binary_op->inputs[broadcast_index] = op[broadcast_index]->inputs[0];
// We leave the broadcast op in; it'll get cleaned up if it's not used later.
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -142,7 +142,7 @@ bool GraphTransformationsPass(int increment, Model* model,
for (const auto& transformation : transformations) {
CHECK(!changed_now);
CHECK(transformation->Messages().empty());
changed_now = transformation->Run(model, op_index);
CHECK(transformation->Run(model, op_index, &changed_now).ok());
const char* made_a_change_msg =
changed_now ? "made a change" : "did NOT make a change";
const int log_level =

View File

@ -27,7 +27,8 @@ namespace toco {
class GraphTransformation {
public:
virtual bool Run(Model* model, std::size_t op_index) = 0;
virtual ::tensorflow::Status Run(Model* model, std::size_t op_index,
bool* modified) = 0;
virtual const char* Name() const = 0;
virtual ~GraphTransformation() {}
// Returns the list of messages that this graph transformation
@ -104,11 +105,12 @@ class GraphTransformationsSet {
void RunGraphTransformations(Model* model, const string& message,
const GraphTransformationsSet& transformations);
#define DECLARE_GRAPH_TRANSFORMATION(GTName) \
class GTName : public GraphTransformation { \
public: \
bool Run(Model* model, std::size_t op_index) override; \
const char* Name() const override { return #GTName; } \
#define DECLARE_GRAPH_TRANSFORMATION(GTName) \
class GTName : public GraphTransformation { \
public: \
::tensorflow::Status Run(Model* model, std::size_t op_index, \
bool* modified) override; \
const char* Name() const override { return #GTName; } \
};
// List of all graph transformations
@ -200,7 +202,8 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveGatherAttributes)
class PropagateDefaultMinMax : public GraphTransformation {
public:
bool Run(Model* model, std::size_t op_index) override;
::tensorflow::Status Run(Model* model, std::size_t op_index,
bool* modified) override;
const char* Name() const override { return "PropagateDefaultMinMax"; }
bool has_any_ranges_defined() const { return !type_ranges_.empty(); }
@ -218,7 +221,8 @@ class PropagateDefaultMinMax : public GraphTransformation {
class RemoveTrivialReshape : public GraphTransformation {
public:
bool Run(Model* model, std::size_t op_index) override;
::tensorflow::Status Run(Model* model, std::size_t op_index,
bool* modified) override;
const char* Name() const override { return "RemoveTrivialReshape"; }
bool treat_expand_dims_as_trivial() const {
return treat_expand_dims_as_trivial_;
@ -233,7 +237,8 @@ class RemoveTrivialReshape : public GraphTransformation {
class ResolveConstantFakeQuant : public GraphTransformation {
public:
bool Run(Model* model, std::size_t op_index) override;
::tensorflow::Status Run(Model* model, std::size_t op_index,
bool* modified) override;
const char* Name() const override { return "ResolveConstantFakeQuant"; }
// True if the num_bits should adjust the final data type.
@ -250,7 +255,8 @@ class ResolveConstantFakeQuant : public GraphTransformation {
class EnsureUint8WeightsSafeForFastInt8Kernels : public GraphTransformation {
public:
bool Run(Model* model, std::size_t op_index) override;
::tensorflow::Status Run(Model* model, std::size_t op_index,
bool* modified) override;
const char* Name() const override {
return "EnsureUint8WeightsSafeForFastInt8Kernels";
}
@ -267,7 +273,8 @@ class EnsureUint8WeightsSafeForFastInt8Kernels : public GraphTransformation {
class IdentifyDilatedConv : public GraphTransformation {
public:
bool Run(Model* model, std::size_t op_index) override;
::tensorflow::Status Run(Model* model, std::size_t op_index,
bool* modified) override;
const char* Name() const override { return "IdentifyDilatedConv"; }
bool identify_depthwise_conv() const { return identify_depthwise_conv_; }
void set_identify_depthwise_conv(bool val) { identify_depthwise_conv_ = val; }

View File

@ -372,7 +372,9 @@ bool HardcodeMinMaxForLstmCell(Model* model, Operator* op) {
}
} // namespace
bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
::tensorflow::Status HardcodeMinMax::Run(Model* model, std::size_t op_index,
bool* modified) {
*modified = false;
auto it = model->operators.begin() + op_index;
auto* op = it->get();
bool changed = false;
@ -467,7 +469,8 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
if (changed) {
AddMessageF("Hardcoded min-max through %s", LogName(*op));
}
return changed;
*modified = changed;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -168,7 +168,10 @@ bool ResolveDilatedConv(Model* model, Operator* conv_base_op, Operator* stb_op,
return true;
}
bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
::tensorflow::Status IdentifyDilatedConv::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto it = model->operators.begin() + op_index;
auto* stb_op = it->get();
@ -176,17 +179,17 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
// ***************************************************************************
// SpaceToBatch Op.
if (stb_op->type != OperatorType::kSpaceToBatchND) {
return false;
return ::tensorflow::Status::OK();
}
if (stb_op->inputs.size() != 3) {
return false;
return ::tensorflow::Status::OK();
}
CHECK_EQ(stb_op->outputs.size(), 1);
// Extract the dilation factor from Input[1] of SpaceToBatch
// TODO(mjmatthews): Support 2D dilation factors.
const auto& block_shape_array = model->GetArray(stb_op->inputs[1]);
if (!block_shape_array.buffer) {
return false;
return ::tensorflow::Status::OK();
}
CHECK_EQ(block_shape_array.shape().dimensions_count(), 1);
int dilation_factor =
@ -195,7 +198,7 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
// Expand Op
auto* post_stb_op = GetOpWithInput(*model, stb_op->outputs[0]);
if (!post_stb_op) {
return false;
return ::tensorflow::Status::OK();
}
bool has_expand_op = false;
if (post_stb_op->type == OperatorType::kExpandDims) {
@ -229,7 +232,8 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
}
}
return changed;
*modified = changed;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -39,7 +39,10 @@ std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
}
} // namespace
bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
::tensorflow::Status IdentifyL2Normalization::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto div_it = model->operators.begin() + op_index;
const auto* div_or_mul_op = div_it->get();
OperatorType expected_op_type_producing_div_or_mul_input;
@ -48,7 +51,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
} else if (div_or_mul_op->type == OperatorType::kMul) {
expected_op_type_producing_div_or_mul_input = OperatorType::kRsqrt;
} else {
return false;
return ::tensorflow::Status::OK();
}
CHECK_EQ(div_or_mul_op->inputs.size(), 2);
Operator* op_producing_div_or_mul_input[2] = {
@ -58,14 +61,14 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
if (!op_producing_div_or_mul_input[1] ||
op_producing_div_or_mul_input[1]->type !=
expected_op_type_producing_div_or_mul_input) {
return false;
return ::tensorflow::Status::OK();
}
Operator* sqrt_or_rsqrt_op = op_producing_div_or_mul_input[1];
CHECK_EQ(sqrt_or_rsqrt_op->inputs.size(), 1);
Operator* op_producing_sqrt_or_rsqrt_input =
GetOpWithOutput(*model, sqrt_or_rsqrt_op->inputs[0]);
if (!op_producing_sqrt_or_rsqrt_input) {
return false;
return ::tensorflow::Status::OK();
}
// There may be an Add or a Maximum here, adding or clamping to a "small"
@ -105,7 +108,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
" because the operator producing the input to the square root, %s,"
", does not match the expected pattern",
LogName(*op_producing_sqrt_or_rsqrt_input));
return false;
return ::tensorflow::Status::OK();
}
}
@ -116,7 +119,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
"Giving up trying to identify L2Normalization subgraph: "
"expected Sum op, got %s",
LogName(*sum_op));
return false;
return ::tensorflow::Status::OK();
}
Operator* square_op = GetOpWithOutput(*model, sum_op->inputs[0]);
@ -125,7 +128,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
"Giving up trying to identify L2Normalization subgraph: "
"expected Square op, got %s",
LogName(*square_op));
return false;
return ::tensorflow::Status::OK();
}
CHECK_EQ(square_op->inputs.size(), 1);
@ -135,7 +138,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
"Giving up trying to identify L2Normalization subgraph: %s does not "
"take the same input as the Mul/Div node",
LogName(*square_op));
return false;
return ::tensorflow::Status::OK();
}
// Create and emplace the new L2Normalization
@ -162,7 +165,8 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
model->operators.erase(FindOperator(model, sqrt_or_rsqrt_op));
model->EraseArray(div_or_mul_op->inputs[1]);
model->operators.erase(FindOperator(model, div_or_mul_op));
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -38,11 +38,13 @@ std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
}
} // namespace
bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) {
::tensorflow::Status IdentifyL2Pool::Run(Model* model, std::size_t op_index,
bool* modified) {
*modified = false;
const auto sqrt_it = model->operators.begin() + op_index;
const auto* sqrt_op = sqrt_it->get();
if (sqrt_op->type != OperatorType::kSqrt) {
return false;
return ::tensorflow::Status::OK();
}
CHECK_EQ(sqrt_op->inputs.size(), 1);
@ -56,7 +58,7 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) {
AddMessageF(
"Giving up trying to identify L2Pool subgraph: "
"expected AveragePool op, but Sqrt op has no preceding op");
return false;
return ::tensorflow::Status::OK();
}
if (prev_to_sqrt_op->type != OperatorType::kAveragePool) {
@ -64,7 +66,7 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) {
"Giving up trying to identify L2Pool subgraph: "
"expected AveragePool op, got %s",
LogName(*prev_to_sqrt_op));
return false;
return ::tensorflow::Status::OK();
}
avpool_op = static_cast<const AveragePoolOperator*>(prev_to_sqrt_op);
@ -77,7 +79,7 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) {
"Giving up trying to identify L2Pool subgraph: "
"expected Square op, got %s",
LogName(*square_op));
return false;
return ::tensorflow::Status::OK();
}
// Create and emplace L2Pool node.
@ -107,7 +109,8 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) {
model->operators.erase(FindOperator(model, avpool_op));
model->operators.erase(FindOperator(model, sqrt_op));
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -132,7 +132,9 @@ bool MatchOperatorInputs(const Operator& op, const Model& model,
} // namespace
bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
::tensorflow::Status IdentifyLstmCell::Run(Model* model, std::size_t op_index,
bool* modified) {
*modified = false;
// This LSTM cell identification method is not invariant to commutation of
// commutative operator inputs. For example, if input[0] and input[1] of the
// final output multiplication were swapped, this method would not identify it
@ -143,13 +145,13 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
auto op_it = model->operators.begin() + op_index;
Operator* final_output_mul = op_it->get();
if (final_output_mul->type != OperatorType::kMul) {
return false;
return ::tensorflow::Status::OK();
}
Operator *state_output_tanh, *fc_output_sig;
if (!MatchOperatorInputs(*final_output_mul, *model, OperatorType::kTanh,
&state_output_tanh, OperatorType::kLogistic,
&fc_output_sig)) {
return false;
return ::tensorflow::Status::OK();
}
// State output TanH
@ -158,7 +160,7 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
Operator* state_combine_add;
if (!MatchOperatorInputs(*state_output_tanh, *model, OperatorType::kAdd,
&state_combine_add)) {
return false;
return ::tensorflow::Status::OK();
}
// State forget & remember addition
@ -166,7 +168,7 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
if (!MatchOperatorInputs(*state_combine_add, *model, OperatorType::kMul,
&state_forget_mul, OperatorType::kMul,
&state_remember_mul)) {
return false;
return ::tensorflow::Status::OK();
}
const string prev_state = state_forget_mul->inputs[0];
@ -175,7 +177,7 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
if (!MatchOperatorInputs(*state_forget_mul, *model, OperatorType::kNone,
nullptr, OperatorType::kLogistic,
&state_forget_sig)) {
return false;
return ::tensorflow::Status::OK();
}
// State remember gate
@ -183,40 +185,40 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
if (!MatchOperatorInputs(*state_remember_mul, *model, OperatorType::kLogistic,
&state_remember_sig, OperatorType::kTanh,
&state_info_tanh)) {
return false;
return ::tensorflow::Status::OK();
}
// State remember "information" activation function
Operator* fc_output_split;
if (!MatchOperatorInputs(*state_info_tanh, *model, OperatorType::kSplit,
&fc_output_split)) {
return false;
return ::tensorflow::Status::OK();
}
// State remember gate activation function
Operator* tmp;
if (!MatchOperatorInputs(*state_remember_sig, *model, OperatorType::kSplit,
&tmp) ||
(tmp != fc_output_split)) {
return false;
return ::tensorflow::Status::OK();
}
// State forget gate activation function
if (!MatchOperatorInputs(*state_forget_sig, *model, OperatorType::kSplit,
&tmp) ||
(tmp != fc_output_split)) {
return false;
return ::tensorflow::Status::OK();
}
// Fully connected output activation function
if (!MatchOperatorInputs(*fc_output_sig, *model, OperatorType::kSplit,
&tmp) ||
(tmp != fc_output_split)) {
return false;
return ::tensorflow::Status::OK();
}
// Fully connected output split
Operator* fully_connected;
if (!MatchOperatorInputs(*fc_output_split, *model, OperatorType::kNone,
nullptr, OperatorType::kFullyConnected,
&fully_connected)) {
return false;
return ::tensorflow::Status::OK();
}
// Fully connected op
@ -225,13 +227,13 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
OperatorType::kConcatenation, &concat_inputs,
OperatorType::kNone, nullptr, OperatorType::kNone,
nullptr)) {
return false;
return ::tensorflow::Status::OK();
}
if (static_cast<FullyConnectedOperator*>(fully_connected)->weights_format !=
FullyConnectedWeightsFormat::kDefault) {
// Not yet implemented: experimental shuffled weights in fused LSTM cell.
return false;
return ::tensorflow::Status::OK();
}
// Emplace a new LSTM cell operator
@ -300,7 +302,8 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
model->operators.erase(FindOperator(model, *fully_connected));
DeleteArrayIfUnused(concat_inputs->outputs[0], model);
model->operators.erase(FindOperator(model, *concat_inputs));
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -25,19 +25,22 @@ limitations under the License.
namespace toco {
bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) {
::tensorflow::Status MergeLstmCellInputs::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
// Find lstm cell.
auto op_it = model->operators.begin() + op_index;
auto src_op = op_it->get();
if (src_op->type != OperatorType::kLstmCell) {
return false;
return ::tensorflow::Status::OK();
}
// Already a compact LstmCell. Do not need to merge cell inputs.
const auto* src_lstm_op = static_cast<LstmCellOperator*>(src_op);
if (src_lstm_op->kernel_type != LstmCellOperator::KERNEL_FULL ||
src_lstm_op->inputs.size() != kExtendedLstmInputCount) {
return false;
return ::tensorflow::Status::OK();
}
// Identify prev_activ_input, prev_state_input as required Op inputs,
@ -45,12 +48,12 @@ bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) {
string prev_activ_input;
if (!GetMatchingRnnArray(model, src_op->outputs[kOutputTensor],
&prev_activ_input)) {
return false;
return ::tensorflow::Status::OK();
}
string prev_state_input;
if (!GetMatchingRnnArray(model, src_op->outputs[kCellStateTensor],
&prev_state_input)) {
return false;
return ::tensorflow::Status::OK();
}
// Get LstmCell's cell, input, output size.
@ -184,7 +187,8 @@ bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) {
DeleteArrayIfUnused(src_op->inputs[kOutputGateBiasTensor], model);
model->operators.erase(FindOp(*model, src_op));
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -25,19 +25,22 @@ limitations under the License.
namespace toco {
bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) {
::tensorflow::Status SplitLstmCellInputs::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
// Find lstm cell.
auto op_it = model->operators.begin() + op_index;
auto curr_op = op_it->get();
if (curr_op->type != OperatorType::kLstmCell) {
return false;
return ::tensorflow::Status::OK();
}
const auto* curr_lstm_op = static_cast<LstmCellOperator*>(curr_op);
// Already an extended LstmCell. Do not need to split cell inputs.
if (curr_lstm_op->kernel_type != LstmCellOperator::KERNEL_BASIC ||
curr_lstm_op->inputs.size() != LstmCellOperator::NUM_INPUTS) {
return false;
return ::tensorflow::Status::OK();
}
// Make sure the WEIGHTS_INPUT and BIASES_INPUT are constant arrays,
@ -46,13 +49,13 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) {
*model, curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT]) ||
!IsConstantParameterArray(
*model, curr_op->inputs[LstmCellOperator::BIASES_INPUT])) {
return false;
return ::tensorflow::Status::OK();
}
// Make sure propagate_fixed_sizes has defined the size of the output.
if (!model->GetArray(curr_op->outputs[LstmCellOperator::ACTIV_OUTPUT])
.has_shape()) {
return false;
return ::tensorflow::Status::OK();
}
// Emplace a new LstmCell operator with extended inputs (kernel/lstm.cc).
@ -168,7 +171,8 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) {
DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::BIASES_INPUT], model);
model->operators.erase(FindOp(*model, curr_op));
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -43,13 +43,15 @@ limitations under the License.
namespace toco {
bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
::tensorflow::Status IdentifyPRelu::Run(Model* model, std::size_t op_index,
bool* modified) {
*modified = false;
const auto add_op_it = model->operators.begin() + op_index;
const auto* add_op = add_op_it->get();
if (add_op == nullptr || add_op->type != OperatorType::kAdd ||
add_op->inputs.size() != 2 ||
add_op->fused_activation_function != FusedActivationFunctionType::kNone) {
return false;
return ::tensorflow::Status::OK();
}
const auto* relu_input_op = GetOpWithOutput(*model, add_op->inputs[0]);
@ -57,7 +59,7 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
relu_input_op->inputs.size() != 1 ||
relu_input_op->fused_activation_function !=
FusedActivationFunctionType::kNone) {
return false;
return ::tensorflow::Status::OK();
}
// TODO(ycling): Both Add and Mul are commutative. Support the case where
@ -66,7 +68,7 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
if (mul_op == nullptr || mul_op->type != OperatorType::kMul ||
mul_op->inputs.size() != 2 ||
mul_op->fused_activation_function != FusedActivationFunctionType::kNone) {
return false;
return ::tensorflow::Status::OK();
}
const auto neg_alpha_tensor_name = mul_op->inputs[0];
@ -75,7 +77,7 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
if (relu_neg_input_op == nullptr ||
relu_neg_input_op->inputs.size() != 1) {
return false;
return ::tensorflow::Status::OK();
}
const Operator* final_input_op;
@ -92,13 +94,13 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
relu_neg_input_op->type != OperatorType::kRelu ||
relu_neg_input_op->fused_activation_function !=
FusedActivationFunctionType::kNone) {
return false;
return ::tensorflow::Status::OK();
}
final_input_op = neg_input_op;
}
if (relu_input_op->inputs[0] != final_input_op->inputs[0]) {
return false;
return ::tensorflow::Status::OK();
}
const auto input_tensor_name = relu_input_op->inputs[0];
@ -128,7 +130,8 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
// intermediate tensors aren't used by other ops, those will be removed by
// other graph transformation rules.
model->operators.erase(FindOp(*model, add_op));
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -56,13 +56,15 @@ int GetSingleScalarInputIndexOfBinaryOp(Model* model, const Operator* op,
}
} // namespace
bool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
::tensorflow::Status IdentifyRelu1::Run(Model* model, std::size_t op_index,
bool* modified) {
*modified = false;
// Follow sequences of min+max and max+min. First get the leading op.
const auto op_it = model->operators.begin() + op_index;
const auto* op_0 = op_it->get();
if (op_0->type != OperatorType::kMinimum &&
op_0->type != OperatorType::kMaximum) {
return false;
return ::tensorflow::Status::OK();
}
// Get the paired op and ensure it's the counter to the first.
@ -71,17 +73,17 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
(op_1->type != OperatorType::kMinimum &&
op_1->type != OperatorType::kMaximum) ||
op_0->type == op_1->type) {
return false;
return ::tensorflow::Status::OK();
}
const auto* min_op = op_0->type == OperatorType::kMinimum ? op_0 : op_1;
const auto* max_op = op_0->type == OperatorType::kMaximum ? op_0 : op_1;
if (min_op->inputs.size() != 2 || max_op->inputs.size() != 2) {
return false;
return ::tensorflow::Status::OK();
}
if (min_op->outputs.size() != 1 || max_op->outputs.size() != 1) {
return false;
return ::tensorflow::Status::OK();
}
// Get the original input to the min+max pair.
@ -90,7 +92,7 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
int max_scalar_input_index =
GetSingleScalarInputIndexOfBinaryOp(model, max_op, -1.0f);
if (min_scalar_input_index == -1 || max_scalar_input_index == -1) {
return false;
return ::tensorflow::Status::OK();
}
int op_0_scalar_input_index =
op_0 == min_op ? min_scalar_input_index : max_scalar_input_index;
@ -111,7 +113,8 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
model->operators.erase(FindOperator(model, op_0));
model->operators.erase(FindOperator(model, op_1));
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -97,7 +97,10 @@ bool AddDequantizeOperatorToInput(const string& input_name, const Operator* op,
return true;
}
bool MakeInitialDequantizeOperator::Run(Model* model, std::size_t op_index) {
::tensorflow::Status MakeInitialDequantizeOperator::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
// This is effectively a transformation applied to edges. We iterate over the
// specified node (op) and proceed for input edges.
const auto it = model->operators.begin() + op_index;
@ -114,7 +117,8 @@ bool MakeInitialDequantizeOperator::Run(Model* model, std::size_t op_index) {
}
}
}
return change_made;
*modified = change_made;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -102,18 +102,19 @@ std::vector<int32> ReshapeToTranspose(const Model& model,
// to be merged if the reshape does not affect memory ordering and does not
// affects the number of dimensions. This only occurs when only unary dimensions
// are shifting position.
bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
std::size_t op_index) {
::tensorflow::Status MergeReshapeIntoPrecedingTranspose::Run(
Model* model, std::size_t op_index, bool* modified) {
*modified = false;
auto it = model->operators.begin() + op_index;
auto* reshape_op = ConvertOperator<TensorFlowReshapeOperator*>(
it->get(), OperatorType::kReshape);
if (reshape_op == nullptr) {
return false;
return ::tensorflow::Status::OK();
}
if (!OperatorReady(*model, reshape_op) || reshape_op->shape.empty()) {
return false;
return ::tensorflow::Status::OK();
}
const string intermediate_name = reshape_op->inputs[0];
@ -121,13 +122,13 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
// Guarantee the input is only consume by the reshape.
if (CountOpsWithInput(*model, intermediate_name) != 1) {
return false;
return ::tensorflow::Status::OK();
}
// Check for the parent operator.
const auto& transpose_it = FindOpWithOutput(*model, intermediate_name);
if (transpose_it == model->operators.end()) {
return false;
return ::tensorflow::Status::OK();
}
// Find the parent operator and guarantee it is a transpose.
@ -135,16 +136,16 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
transpose_it->get(), OperatorType::kTranspose);
if (transpose_op == nullptr) {
return false;
return ::tensorflow::Status::OK();
}
if (!OperatorReady(*model, transpose_op) || transpose_op->perm.empty()) {
return false;
return ::tensorflow::Status::OK();
}
if (!ReshapeIsEquivalentToTranspose(*model, reshape_op,
false /*allow_extra_unary_dimensions*/)) {
return false;
return ::tensorflow::Status::OK();
}
// Check that the intermediate is not an output array.
@ -153,7 +154,7 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
"Cannot fuse %s and %s as it would invalidate the transpose "
"output array.",
LogName(*transpose_op), LogName(*reshape_op));
return false;
return ::tensorflow::Status::OK();
}
AddMessageF("Merging operations %s and %s", LogName(*transpose_op),
@ -172,7 +173,7 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
// Remove the reshape as passthrough operation.
if (!RemoveTrivialPassthroughOp(this, model, op_index)) {
return false;
return ::tensorflow::Status::OK();
}
// Update transpose_op's constant buffer to contain the new permutation.
@ -184,7 +185,8 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
// transpose_ops's shape will likely has changed.
model->GetArray(transpose_op->outputs[0]).clear_shape();
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -54,7 +54,10 @@ bool IsTailOfShape(const Shape& tail, const Shape& shape) {
//
// Note we are testing for one particular case of a broader set of possible
// binary-reshape op transformations. This transformation could be generalized.
bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
::tensorflow::Status MoveBinaryOperatorBeforeReshape::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto binary_it = model->operators.begin() + op_index;
Operator* binary_op = binary_it->get();
if (binary_op->type != OperatorType::kAdd &&
@ -69,7 +72,7 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
binary_op->type != OperatorType::kLessEqual &&
binary_op->type != OperatorType::kGreater &&
binary_op->type != OperatorType::kGreaterEqual) {
return false;
return ::tensorflow::Status::OK();
}
// BINARY OP INPUT CHECKS
@ -81,11 +84,11 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
if (!input_is_const[0] && !input_is_const[1]) {
// To limit our scope, we require one constant input. Though there's no
// reason this transformation wouldn't work with all variable inputs.
return false;
return ::tensorflow::Status::OK();
}
if (input_is_const[0] && input_is_const[1]) {
// Both inputs are constants. Leave this for constants propagation.
return false;
return ::tensorflow::Status::OK();
}
const int constant_input_idx = input_is_const[0] ? 0 : 1;
const int variable_input_idx = input_is_const[0] ? 1 : 0;
@ -98,13 +101,13 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
AddMessageF(
"Not moving %s because it's non-constant input shape is not resolved.",
LogName(*binary_op));
return false;
return ::tensorflow::Status::OK();
}
if (!IsTailOfShape(
model->GetArray(binary_op->inputs[constant_input_idx]).shape(),
model->GetArray(binary_op->inputs[variable_input_idx]).shape())) {
// Constant array shape must be the latter part of the variable shape.
return false;
return ::tensorflow::Status::OK();
}
// RESHAPE OP CHECKS
@ -113,13 +116,13 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
if (reshape_it == model->operators.end()) {
AddMessageF("Not moving %s because it's variable input is not connected.",
LogName(*binary_op));
return false;
return ::tensorflow::Status::OK();
}
Operator* reshape_op = reshape_it->get();
if (reshape_op->type != OperatorType::kReshape) {
AddMessageF("Not moving %s because the preceding %s is not a reshape op",
LogName(*binary_op), LogName(*reshape_op));
return false;
return ::tensorflow::Status::OK();
}
const auto& reshape_input_array = model->GetArray(reshape_op->inputs[0]);
if (!reshape_input_array.has_shape()) {
@ -127,14 +130,14 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
"Not moving %s because it's non-constant input shape is not resolved "
"yet",
LogName(*binary_op));
return false;
return ::tensorflow::Status::OK();
}
if (!IsTailOfShape(
model->GetArray(binary_op->inputs[constant_input_idx]).shape(),
model->GetArray(reshape_op->outputs[0]).shape())) {
// Constant array shape must be the latter part of the binary op output
// shape.
return false;
return ::tensorflow::Status::OK();
}
// EXTRA CHECKS ON CONNECTING ARRAY
@ -143,7 +146,7 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
AddMessageF(
"Not moving %s because the output of reshape op %s is an output op.",
LogName(*binary_op), LogName(*reshape_op));
return false;
return ::tensorflow::Status::OK();
}
}
int count_ops_consuming_output =
@ -154,7 +157,7 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
"Not moving %s because the output of reshape op %s is consumed by "
"another op",
LogName(*binary_op), LogName(*reshape_op));
return false;
return ::tensorflow::Status::OK();
}
// SWAP ORDER OF BINARY AND RESHAPE OPS
@ -172,7 +175,8 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
// Clear binary output shape so it will be re-propagated
model->GetArray(binary_op->outputs[0]).clear_shape();
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -26,20 +26,21 @@ limitations under the License.
namespace toco {
bool PropagateActivationFunctionIntoConstants::Run(Model* model,
std::size_t op_index) {
::tensorflow::Status PropagateActivationFunctionIntoConstants::Run(
Model* model, std::size_t op_index, bool* modified) {
*modified = false;
const auto ac_it = model->operators.begin() + op_index;
const auto* ac_op = ac_it->get();
if (ac_op->type != OperatorType::kRelu6 &&
ac_op->type != OperatorType::kRelu1 &&
ac_op->type != OperatorType::kRelu) {
return false;
return ::tensorflow::Status::OK();
}
// Find the op producing the array passed to this activation function.
auto* src_op = GetOpWithOutput(*model, ac_op->inputs[0]);
if (!src_op) {
return false;
return ::tensorflow::Status::OK();
}
// Ensure the src_op is not used without the activation function applied.
@ -57,7 +58,7 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model,
src_op_input = src_op->inputs[0];
break;
default:
return false;
return ::tensorflow::Status::OK();
}
CHECK_EQ(src_op->outputs[0], ac_op->inputs[0]);
@ -69,7 +70,7 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model,
"Not propagating activation function %s into %s:%s because it is not "
"constant",
LogName(*ac_op), LogName(*src_op), src_op_input);
return false;
return ::tensorflow::Status::OK();
}
// Get the array we'll be working with and ensure it's a compatible type.
@ -79,7 +80,7 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model,
"Not propagating activation function %s into %s:%s because it is "
"non-float data",
LogName(*ac_op), LogName(*src_op), src_op_input);
return false;
return ::tensorflow::Status::OK();
}
auto& const_array_data =
const_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
@ -108,14 +109,15 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model,
}
default:
LOG(FATAL) << "Unsupported activation function " << LogName(*ac_op);
return false;
return ::tensorflow::Status::OK();
}
const_array_data[i] = new_value;
}
AddMessageF("Propagated activation function %s into %s:%s", LogName(*ac_op),
LogName(*src_op), src_op_input);
return RemoveTrivialPassthroughOp(this, model, op_index);
*modified = RemoveTrivialPassthroughOp(this, model, op_index);
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -32,7 +32,10 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op,
}
} // namespace
bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status PropagateArrayDataTypes::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto it = model->operators.begin() + op_index;
auto* op = it->get();
@ -40,7 +43,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
for (const auto& input : op->inputs) {
if (!model->IsOptionalArray(input) &&
model->GetArray(input).data_type == ArrayDataType::kNone) {
return false;
return ::tensorflow::Status::OK();
}
}
// Record data types of output before processing, so we can see at the
@ -131,7 +134,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
auto* rand_op = static_cast<RandomUniformOperator*>(op);
// The output type of RandomUniform is specified with an attribute
if (rand_op->dtype == ArrayDataType::kNone) {
return false;
return ::tensorflow::Status::OK();
}
CHECK_EQ(op->outputs.size(), 1);
SetDataTypeForAllOutputs(model, op, rand_op->dtype);
@ -153,7 +156,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
// This can make unsupported_op->output_data_types have more elements than
// op->outputs.
if (unsupported_op->output_data_types.size() < op->outputs.size()) {
return false;
return ::tensorflow::Status::OK();
}
for (int i = 0; i < op->outputs.size(); ++i) {
const string& output = op->outputs[i];
@ -164,7 +167,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
}
case OperatorType::kExpandDims: {
// Yield on ExpandDim until it is converted to Reshape
return false;
return ::tensorflow::Status::OK();
}
case OperatorType::kSelect: {
// Select produces outputs with the same type as their 2nd input
@ -248,10 +251,11 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
// Return true if any output data type changed, false if none changed.
for (const auto& output : op->outputs) {
if (old_output_data_types[output] != model->GetArray(output).data_type) {
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
}
return false;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -39,7 +39,10 @@ bool SupportsMinMax(const Array& array) {
// When provided a set of min/max values for uint8 arrays this will rescale
// the values for other data types as required and preserving the floating point
// range within the new type.
bool PropagateDefaultMinMax::Run(Model* model, std::size_t op_index) {
::tensorflow::Status PropagateDefaultMinMax::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto it = model->operators.begin() + op_index;
const auto* op = it->get();
@ -61,7 +64,8 @@ bool PropagateDefaultMinMax::Run(Model* model, std::size_t op_index) {
}
}
return did_change;
*modified = did_change;
return ::tensorflow::Status::OK();
}
// Sets the min/max on the given array, adjusting the reference_minmax for the

View File

@ -277,11 +277,14 @@ bool RecursivelyForwardPropagateDataType(GraphTransformation* transformation,
// nice logging and integration with the graphviz video dumping mode.
// In general you should not copy this style of transformation and stick to
// local-only changes as seen in the other transformations.
bool PropagateFakeQuantNumBits::Run(Model* model, std::size_t op_index) {
::tensorflow::Status PropagateFakeQuantNumBits::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto it = model->operators.begin() + op_index;
auto* op = it->get();
if (op->type != OperatorType::kFakeQuant) {
return false;
return ::tensorflow::Status::OK();
}
auto* fakequant_op = static_cast<FakeQuantOperator*>(op);
@ -290,7 +293,7 @@ bool PropagateFakeQuantNumBits::Run(Model* model, std::size_t op_index) {
&quantized_data_type)) {
AddMessageF("FakeQuant op %s num_bits=%d is out of range, ignoring",
LogName(*op), fakequant_op->num_bits);
return false;
return ::tensorflow::Status::OK();
}
const auto& final_minmax = *fakequant_op->minmax;
@ -311,7 +314,8 @@ bool PropagateFakeQuantNumBits::Run(Model* model, std::size_t op_index) {
did_change |=
RecursivelyForwardPropagateDataType(this, model, op, quantized_data_type);
return did_change;
*modified = did_change;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -1622,7 +1622,10 @@ void ProcessUnpackOperator(Model* model, UnpackOperator* op) {
} // namespace
bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status PropagateFixedSizes::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto it = model->operators.begin() + op_index;
auto* op = it->get();
std::unordered_map<string, std::vector<int>> old_output_dims;
@ -1836,7 +1839,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
static_cast<TensorFlowUnsupportedOperator*>(op);
// Attribute can be not specified, ignore it.
if (unsupported_op->output_shapes.size() < op->outputs.size()) {
return false;
return ::tensorflow::Status::OK();
}
for (int i = 0; i < op->outputs.size(); ++i) {
const string& output = op->outputs[i];
@ -1886,10 +1889,11 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
(old_output_dims[output] != model->GetArray(output).shape().dims())) {
AddMessageF("Set shape of %s to [%s]", output,
absl::StrJoin(model->GetArray(output).shape().dims(), ","));
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
}
return false;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -439,7 +439,9 @@ void FixMinMaxPostQuantization(GraphTransformation* transformation,
} // namespace
bool Quantize::Run(Model* model, std::size_t op_index) {
::tensorflow::Status Quantize::Run(Model* model, std::size_t op_index,
bool* modified) {
*modified = false;
// Our general "quantization" graph transformation consists in replacing
// QuantizedInputArrays[] ->
// DequantizeOperators[] ->
@ -460,7 +462,7 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
auto& op = *model->operators[op_index];
if (op.type == OperatorType::kDequantize ||
op.type == OperatorType::kFakeQuant) {
return false;
return ::tensorflow::Status::OK();
}
// Our assumption here is that the input arrays are already quantized -
@ -497,7 +499,7 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
if (!array.minmax && !array.buffer) {
LOG(ERROR) << "Can't quantize input array " << input
<< " because it lacks min/max info";
return false;
return ::tensorflow::Status::OK();
}
const auto* other_op = GetOpWithOutput(*model, input);
if (other_op && other_op->type != OperatorType::kDequantize) {
@ -507,7 +509,7 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
"which means that we should yield and let other ops "
"get quantized first",
LogName(op), input);
return false;
return ::tensorflow::Status::OK();
}
}
}
@ -672,7 +674,8 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
}
}
return changed;
*modified = changed;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -51,18 +51,19 @@ bool ApplyAttrsToArray(GraphTransformation* transformation, Model* model,
} // end namespace
bool ReadArrayMinmaxAndNarrowRangeFromFakeQuant::Run(Model* model,
std::size_t op_index) {
::tensorflow::Status ReadArrayMinmaxAndNarrowRangeFromFakeQuant::Run(
Model* model, std::size_t op_index, bool* modified) {
*modified = false;
const auto fakequant_it = model->operators.begin() + op_index;
auto* fakequant_base_op = fakequant_it->get();
if (fakequant_base_op->type != OperatorType::kFakeQuant) {
return false;
return ::tensorflow::Status::OK();
}
auto* fq_op = static_cast<FakeQuantOperator*>(fakequant_base_op);
if (!fq_op->minmax) {
// Need to be resolved first by ResolveFakeQuantArgsFromVars.
return false;
return ::tensorflow::Status::OK();
}
// At this point, this FakeQuantOperator should have a MinMax
@ -74,7 +75,8 @@ bool ReadArrayMinmaxAndNarrowRangeFromFakeQuant::Run(Model* model,
bool changed = false;
changed |= ApplyAttrsToArray(this, model, *fq_op, fq_op->inputs[0]);
changed |= ApplyAttrsToArray(this, model, *fq_op, fq_op->outputs[0]);
return changed;
*modified = changed;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -25,11 +25,14 @@ limitations under the License.
namespace toco {
bool RemoveFinalDequantizeOp::Run(Model* model, std::size_t op_index) {
::tensorflow::Status RemoveFinalDequantizeOp::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto dequantize_it = model->operators.begin() + op_index;
const auto* dequantize_op = dequantize_it->get();
if (dequantize_op->type != OperatorType::kDequantize) {
return false;
return ::tensorflow::Status::OK();
}
const auto& output = dequantize_op->outputs[0];
// We can remove any dequantize op whose output is not consumed by
@ -38,7 +41,7 @@ bool RemoveFinalDequantizeOp::Run(Model* model, std::size_t op_index) {
// in the middle of the graph might be designated as an output
// array.
if (CountOpsWithInput(*model, output)) {
return false;
return ::tensorflow::Status::OK();
}
// If one of the model's output arrays was actually the Dequantize op's
@ -53,7 +56,8 @@ bool RemoveFinalDequantizeOp::Run(Model* model, std::size_t op_index) {
AddMessageF("Removed final %s", LogName(*dequantize_op));
model->EraseArray(output);
model->operators.erase(dequantize_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -23,11 +23,14 @@ limitations under the License.
namespace toco {
bool RemoveTensorFlowAssert::Run(Model* model, std::size_t op_index) {
::tensorflow::Status RemoveTensorFlowAssert::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto assert_it = model->operators.begin() + op_index;
const auto* assert_op = assert_it->get();
if (assert_op->type != OperatorType::kAssert) {
return false;
return ::tensorflow::Status::OK();
}
bool changed = false;
@ -54,7 +57,8 @@ bool RemoveTensorFlowAssert::Run(Model* model, std::size_t op_index) {
// That's it. We can stop here, no need to duplicate the work that
// RemoveUnusedOp will do removing this now-unused node.
return changed;
*modified = changed;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -25,14 +25,18 @@ limitations under the License.
namespace toco {
bool RemoveTensorFlowIdentity::Run(Model* model, std::size_t op_index) {
::tensorflow::Status RemoveTensorFlowIdentity::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto passthru_it = model->operators.begin() + op_index;
const auto* passthru_op = passthru_it->get();
if (passthru_op->type != OperatorType::kIdentity) {
return false;
return ::tensorflow::Status::OK();
}
return RemoveTrivialPassthroughOp(this, model, op_index);
*modified = RemoveTrivialPassthroughOp(this, model, op_index);
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -46,14 +46,17 @@ bool AreAllBufferElementsEqualTo(const std::vector<Scalar>& buffer_data,
// For example, an Add operator is trivial if
// one of its operands is constant 0, a Mul operator is trivial
// if one of its operands is constant 1, etc.
bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
::tensorflow::Status RemoveTrivialBinaryOperator::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto binary_it = model->operators.begin() + op_index;
auto* binary_op = binary_it->get();
if (binary_op->type != OperatorType::kAdd &&
binary_op->type != OperatorType::kMul &&
binary_op->type != OperatorType::kSub &&
binary_op->type != OperatorType::kDiv) {
return false;
return ::tensorflow::Status::OK();
}
CHECK_EQ(binary_op->inputs.size(), 2);
@ -66,12 +69,12 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
};
if (!is_input_constant[0] && !is_input_constant[1]) {
// Neither input is constant, so nothing we can resolve here.
return false;
return ::tensorflow::Status::OK();
}
if (is_input_constant[0] && is_input_constant[1]) {
// Both inputs are constants. That's a job for constants
// propagation, not for us to handle here.
return false;
return ::tensorflow::Status::OK();
}
const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
@ -84,7 +87,7 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
const auto& input_array_1 = model->GetArray(binary_op->inputs[1]);
if (!input_array_0.has_shape() || !input_array_1.has_shape()) {
// Both input shapes must be known.
return false;
return ::tensorflow::Status::OK();
}
if (input_array_0.shape().dimensions_count() ==
input_array_1.shape().dimensions_count() &&
@ -94,7 +97,7 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
"(lhs %s, rhs %s)",
LogName(*binary_op), ShapeToString(input_array_0.shape()),
ShapeToString(input_array_1.shape()));
return false;
return ::tensorflow::Status::OK();
}
// Now check if the constant operand makes this binary
@ -103,7 +106,7 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
model->GetArray(binary_op->inputs[index_of_constant_input]);
// For now, we only handle floats here.
if (constant_input_array.data_type != ArrayDataType::kFloat) {
return false;
return ::tensorflow::Status::OK();
}
const auto& constant_input_float_data =
constant_input_array.GetBuffer<ArrayDataType::kFloat>().data;
@ -121,12 +124,13 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
}
if (!is_trivial) {
return false;
return ::tensorflow::Status::OK();
}
// Now we know that this node is trivial, so we can remove it.
AddMessageF("Removing trivial %s", LogName(*binary_op));
return RemoveTrivialPassthroughOp(this, model, op_index);
*modified = RemoveTrivialPassthroughOp(this, model, op_index);
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -25,16 +25,20 @@ limitations under the License.
namespace toco {
bool RemoveTrivialConcatenation::Run(Model* model, std::size_t op_index) {
::tensorflow::Status RemoveTrivialConcatenation::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto concat_it = model->operators.begin() + op_index;
auto* concat_op = concat_it->get();
if (concat_op->type != OperatorType::kConcatenation) {
return false;
return ::tensorflow::Status::OK();
}
if (concat_op->inputs.size() != 1) {
return false;
return ::tensorflow::Status::OK();
}
return RemoveTrivialPassthroughOp(this, model, op_index);
*modified = RemoveTrivialPassthroughOp(this, model, op_index);
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -25,7 +25,10 @@ limitations under the License.
namespace toco {
bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) {
::tensorflow::Status RemoveTrivialConcatenationInput::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
// TensorFlow allows Concatenation nodes to have 0-D inputs,
// and they are then treated as empty i.e. omitted from concatenation,
// in violation of the notion that 0-D is equivalent to 1x1x1x1.
@ -36,7 +39,7 @@ bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) {
const auto concat_it = model->operators.begin() + op_index;
auto* concat_op = concat_it->get();
if (concat_op->type != OperatorType::kConcatenation) {
return false;
return ::tensorflow::Status::OK();
}
std::vector<string> trivial_inputs;
std::vector<string> nontrivial_inputs;
@ -52,7 +55,7 @@ bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) {
}
if (trivial_inputs.empty()) {
return false;
return ::tensorflow::Status::OK();
}
// Drop trivial inputs.
@ -63,7 +66,8 @@ bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) {
}
}
concat_op->inputs = nontrivial_inputs;
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -64,23 +64,27 @@ bool IsFakeQuantTrivial(GraphTransformation* transformation, const Model& model,
} // namespace
// Removes FakeQuant ops that are trivial (have no effect, are redundant, etc).
bool RemoveTrivialFakeQuant::Run(Model* model, std::size_t op_index) {
::tensorflow::Status RemoveTrivialFakeQuant::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto op_it = model->operators.begin() + op_index;
auto* op = op_it->get();
if (op->type != OperatorType::kFakeQuant) {
return false;
return ::tensorflow::Status::OK();
}
auto* fakequant_op = static_cast<FakeQuantOperator*>(op);
if (!IsFakeQuantTrivial(this, *model, *fakequant_op)) {
AddMessageF("%s is not trivial", LogName(*fakequant_op));
return false;
return ::tensorflow::Status::OK();
}
AddMessageF("Removing trivial %s", LogName(*fakequant_op));
CHECK_EQ(fakequant_op->inputs.size(), 1);
return RemoveTrivialPassthroughOp(this, model, op_index);
*modified = RemoveTrivialPassthroughOp(this, model, op_index);
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -94,12 +94,13 @@ bool IsTrivialFusedActivationFunc(
// Attempts to remove both fused and unfused activation functions if the
// quantization params indicate that the representable values fall inside the
// activation range.
bool RemoveTrivialQuantizedActivationFunc::Run(Model* model,
std::size_t op_index) {
::tensorflow::Status RemoveTrivialQuantizedActivationFunc::Run(
Model* model, std::size_t op_index, bool* modified) {
*modified = false;
const auto it = model->operators.begin() + op_index;
auto* op = it->get();
if (op->inputs.empty()) {
return false;
return ::tensorflow::Status::OK();
}
if (IsTrivialUnfusedActivationFunc(this, *model, op->type, op->inputs[0])) {
@ -107,7 +108,8 @@ bool RemoveTrivialQuantizedActivationFunc::Run(Model* model,
"Removing trivial unfused activation function %s because the input "
"minmax imply at least as tight a clamp anyway.",
LogName(*op));
return RemoveTrivialPassthroughOp(this, model, op_index);
*modified = RemoveTrivialPassthroughOp(this, model, op_index);
return ::tensorflow::Status::OK();
}
if (IsTrivialFusedActivationFunc(this, *model, op->fused_activation_function,
op->outputs[0])) {
@ -117,9 +119,10 @@ bool RemoveTrivialQuantizedActivationFunc::Run(Model* model,
"because the output quantization parameters imply at least as tight "
"a clamp anyway.",
LogName(*op));
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
return false;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -69,22 +69,26 @@ bool IsTrivialMinMax(GraphTransformation* transformation, const Model& model,
// Attempts to remove min/max functions if the quantization params indicate that
// the representable values fall inside the clip range.
bool RemoveTrivialQuantizedMinMax::Run(Model* model, std::size_t op_index) {
::tensorflow::Status RemoveTrivialQuantizedMinMax::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto it = model->operators.begin() + op_index;
auto* op = it->get();
if ((op->type != OperatorType::kMinimum &&
op->type != OperatorType::kMaximum) ||
op->inputs.size() != 2) {
return false;
return ::tensorflow::Status::OK();
}
if (IsTrivialMinMax(this, *model, op->type, op->inputs[0], op->inputs[1])) {
AddMessageF(
"Removing trivial min/max %s because the quantization parameters imply "
"at least as tight a clamp anyway.",
LogName(*op));
return RemoveTrivialPassthroughOp(this, model, op_index);
*modified = RemoveTrivialPassthroughOp(this, model, op_index);
return ::tensorflow::Status::OK();
}
return false;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -81,22 +81,26 @@ bool IsReshapeTrivial(const Model& model, const Operator& op,
} // namespace
bool RemoveTrivialReshape::Run(Model* model, std::size_t op_index) {
::tensorflow::Status RemoveTrivialReshape::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto reshape_it = model->operators.begin() + op_index;
auto* reshape_op = reshape_it->get();
if (reshape_op->type != OperatorType::kReshape) {
return false;
return ::tensorflow::Status::OK();
}
if (!IsReshapeTrivial(*model, *reshape_op, this)) {
AddMessageF("%s is not trivial", LogName(*reshape_op));
return false;
return ::tensorflow::Status::OK();
}
AddMessageF("Removing trivial %s", LogName(*reshape_op));
CHECK_EQ(reshape_op->inputs.size(), 2);
return RemoveTrivialPassthroughOp(this, model, op_index);
*modified = RemoveTrivialPassthroughOp(this, model, op_index);
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -49,21 +49,24 @@ bool IsSliceTrivial(const Model& model, const Operator& op,
} // namespace
bool RemoveTrivialSlice::Run(Model* model, std::size_t op_index) {
::tensorflow::Status RemoveTrivialSlice::Run(Model* model, std::size_t op_index,
bool* modified) {
*modified = false;
const auto reshape_it = model->operators.begin() + op_index;
auto* slice_op = reshape_it->get();
if (slice_op->type != OperatorType::kSlice) {
return false;
return ::tensorflow::Status::OK();
}
if (!IsSliceTrivial(*model, *slice_op, this)) {
return false;
return ::tensorflow::Status::OK();
}
AddMessageF("Removing trivial %s", LogName(*slice_op));
CHECK_EQ(slice_op->inputs.size(), 3);
return RemoveTrivialPassthroughOp(this, model, op_index);
*modified = RemoveTrivialPassthroughOp(this, model, op_index);
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -25,7 +25,9 @@ limitations under the License.
namespace toco {
bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
::tensorflow::Status RemoveUnusedOp::Run(Model* model, std::size_t op_index,
bool* modified) {
*modified = false;
const auto it = model->operators.begin() + op_index;
const auto* op = it->get();
@ -58,7 +60,7 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
}
for (const string& output_array : model->flags.output_arrays()) {
if (output == output_array) {
return false;
return ::tensorflow::Status::OK();
}
}
for (const auto& rnn_state : model->flags.rnn_states()) {
@ -67,19 +69,19 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
if (!IsDiscardableArray(*model, rnn_state.back_edge_source_array()) ||
!IsDiscardableArray(*model, rnn_state.state_array()) ||
CountOpsWithInput(*model, rnn_state.state_array())) {
return false;
return ::tensorflow::Status::OK();
}
}
}
if (CountOpsWithInput(*model, output)) {
return false;
return ::tensorflow::Status::OK();
}
}
if (op->unresolved_outputs) {
AddMessageF("Not discarding %s because it has unresolved outputs.",
LogName(*op));
return false;
return ::tensorflow::Status::OK();
}
AddMessageF("Discarding %s because none of its outputs is used.",
@ -105,7 +107,8 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
}
}
model->operators.erase(it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -63,29 +63,32 @@ bool IsMoveOperator(OperatorType optype) {
// Swap elementwise operators such that all value operators occur before all
// element move operators, e.g. negation then transpose.
bool ReorderElementwiseUnary::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ReorderElementwiseUnary::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto element_op_it = model->operators.begin() + op_index;
std::unique_ptr<Operator>& element_op = *element_op_it;
if (!IsElementwiseOperator(element_op->type)) {
return false;
return ::tensorflow::Status::OK();
}
const string intermediate_name = element_op->inputs[0];
auto it = FindOpWithOutput(*model, intermediate_name);
if (it == model->operators.end()) {
AddMessageF("No preceding operator");
return false;
return ::tensorflow::Status::OK();
}
std::unique_ptr<Operator>& move_op = *it;
if (!IsMoveOperator(move_op->type)) {
AddMessageF("Preceding operator is not a move operator");
return false;
return ::tensorflow::Status::OK();
}
if (CountOpsWithInput(*model, intermediate_name) != 1) {
AddMessageF("Input %s used elsewhere", intermediate_name);
return false;
return ::tensorflow::Status::OK();
}
// Check that the intermediate is discardable.
@ -94,7 +97,7 @@ bool ReorderElementwiseUnary::Run(Model* model, std::size_t op_index) {
"Cannot swap elementwise as it would invalidate %s which is "
"an output array.",
intermediate_name);
return false;
return ::tensorflow::Status::OK();
}
// op->inputs may change so we need to keep a value by copy.
@ -147,7 +150,8 @@ bool ReorderElementwiseUnary::Run(Model* model, std::size_t op_index) {
// Swap the order of the operators.
element_op.swap(move_op);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -101,37 +101,40 @@ std::vector<int> ComputeNewPerm(std::vector<int> input_dims,
// Swaps reshape-transpose to transpose-reshape whenever possible. This is
// possible when the reshape does not affect memory ordering.
bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ReorderReshapeTranspose::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto transpose_it = model->operators.begin() + op_index;
TransposeOperator* transpose_op = ConvertOperator<TransposeOperator*>(
transpose_it->get(), OperatorType::kTranspose);
if (transpose_op == nullptr) {
return false;
return ::tensorflow::Status::OK();
}
if (!OperatorReady(*model, transpose_op) || transpose_op->perm.empty()) {
// Wait for values to propagate.
return false;
return ::tensorflow::Status::OK();
}
// Find the operator that produces the transpose op.
auto reshape_it = FindOpWithOutput(*model, transpose_op->inputs[0]);
if (reshape_it == model->operators.end()) {
return false;
return ::tensorflow::Status::OK();
}
TensorFlowReshapeOperator* reshape_op =
ConvertOperator<TensorFlowReshapeOperator*>(reshape_it->get(),
OperatorType::kReshape);
if (reshape_op == nullptr) {
return false;
return ::tensorflow::Status::OK();
}
// Ignore if the reshape is uninitialized.
if (!OperatorReady(*model, reshape_op) || reshape_op->shape.empty()) {
return false;
return ::tensorflow::Status::OK();
}
// Need to copy to keep static if permutated.
@ -142,7 +145,7 @@ bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) {
// Intermediate should not be consumed by any other operators.
if (CountOpsWithInput(*model, intermediate_name) != 1) {
AddMessageF("Input %s used elsewhere", intermediate_name);
return false;
return ::tensorflow::Status::OK();
}
// Check that the intermediate is not an output array.
@ -151,7 +154,7 @@ bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) {
"Cannot reorder reshape-transpose as it would invalidate %s which is "
"an output array.",
intermediate_name);
return false;
return ::tensorflow::Status::OK();
}
// Get the arrays.
@ -173,7 +176,7 @@ bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) {
// dimensions then it can be moved between the transpose.
if (!ReshapeIsEquivalentToTranspose(*model, reshape_op,
true /*allow_extra_unary_dims*/)) {
return false;
return ::tensorflow::Status::OK();
}
if (!IsDiscardableArray(*model, output_name)) {
@ -242,7 +245,8 @@ bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) {
// Swap the order of the operators.
transpose_it->swap(*reshape_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -25,10 +25,13 @@ limitations under the License.
namespace toco {
bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveBatchNormalization::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto bn_it = model->operators.begin() + op_index;
if (bn_it->get()->type != OperatorType::kBatchNormalization) {
return false;
return ::tensorflow::Status::OK();
}
const auto* bn_op =
static_cast<const BatchNormalizationOperator*>(bn_it->get());
@ -53,7 +56,7 @@ bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) {
// so we need to exit early if these buffers don't exist (i.e. if the params
// haven't yet been resolved as constants).
if (!mean_array.buffer || !multiplier_array.buffer || !offset_array.buffer) {
return false;
return ::tensorflow::Status::OK();
}
// Create the new Mul, Add operators
@ -142,7 +145,8 @@ bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) {
DCHECK_EQ(bn_it->get(), bn_op);
model->operators.erase(bn_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -24,31 +24,35 @@ limitations under the License.
namespace toco {
bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveBatchToSpaceNDAttributes::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto op_it = model->operators.begin() + op_index;
if (op_it->get()->type != OperatorType::kBatchToSpaceND) return false;
if (op_it->get()->type != OperatorType::kBatchToSpaceND)
return ::tensorflow::Status::OK();
auto* op = static_cast<BatchToSpaceNDOperator*>(op_it->get());
// The attributes are resolved only when the 3 attributes (block_shape,
// before_crops, after_crops) are all constant.
if (!op->block_shape.empty()) {
return false;
return ::tensorflow::Status::OK();
}
CHECK_EQ(op->inputs.size(), 3);
if (!IsConstantParameterArray(*model, op->inputs[1]) ||
!IsConstantParameterArray(*model, op->inputs[2]))
return false;
return ::tensorflow::Status::OK();
// Handle crops
const auto& crops_array = model->GetArray(op->inputs[2]);
if (!crops_array.has_shape()) return false;
if (!crops_array.has_shape()) return ::tensorflow::Status::OK();
const std::vector<int>& crops_dims = crops_array.shape().dims();
if (crops_dims.size() != 2) {
// Code only handles crops of 2 dimensions. Perhaps another transformation
// will delete this op.
return false;
return ::tensorflow::Status::OK();
}
const std::vector<int>& crops_buffer =
crops_array.GetBuffer<ArrayDataType::kInt32>().data;
@ -59,7 +63,7 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) {
// Handle block_shape
const auto& block_shape_array = model->GetArray(op->inputs[1]);
if (!block_shape_array.has_shape()) return false;
if (!block_shape_array.has_shape()) return ::tensorflow::Status::OK();
const std::vector<int>& block_shape_dims = block_shape_array.shape().dims();
CHECK_EQ(block_shape_dims.size(), 1);
const std::vector<int>& block_shape_buffer =
@ -68,7 +72,8 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) {
op->block_shape.push_back(block_shape_buffer[i]);
}
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -188,7 +188,10 @@ void EvaluateBinaryOperatorOnConstantInputs(Model* model,
}
} // namespace
bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantBinaryOperator::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto binary_it = model->operators.begin() + op_index;
const auto* binary_op = binary_it->get();
// Test for binary ops of types that we know how to resolve
@ -204,7 +207,7 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
binary_op->type != OperatorType::kLessEqual &&
binary_op->type != OperatorType::kGreater &&
binary_op->type != OperatorType::kGreaterEqual) {
return false;
return ::tensorflow::Status::OK();
}
CHECK_EQ(binary_op->inputs.size(), 2);
@ -212,13 +215,13 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
const auto& input1_array = model->GetArray(binary_op->inputs[1]);
// Check if both inputs are constant parameters.
if (!input0_array.buffer || !input1_array.buffer) {
return false;
return ::tensorflow::Status::OK();
}
auto& output_array = model->GetArray(binary_op->outputs[0]);
// Yield until the output array dims have been resolved.
if (!output_array.has_shape()) {
return false;
return ::tensorflow::Status::OK();
}
// At the moment we don't want to care about fused activation functions.
@ -229,7 +232,7 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
AddMessageF(
"Not resolving constant %s because it has a fused activation function",
LogName(*binary_op));
return false;
return ::tensorflow::Status::OK();
}
// Check that input data types agree.
@ -253,7 +256,8 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
AddMessageF("Resolved constant %s to the equivalent constant array",
LogName(*binary_op));
model->operators.erase(binary_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -135,11 +135,14 @@ void SetMinMaxForConcatenedArray(GraphTransformation* transformation,
} // namespace
// Resolves the concatenation operator if all its inputs are constant arrays.
bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantConcatenation::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto concat_it = model->operators.begin() + op_index;
const auto* concat_base_op = concat_it->get();
if (concat_base_op->type != OperatorType::kConcatenation) {
return false;
return ::tensorflow::Status::OK();
}
const auto* concat_op =
static_cast<const ConcatenationOperator*>(concat_base_op);
@ -149,11 +152,15 @@ bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) {
// We also make sure the shapes of the input arrays are known and they are
// all discardable.
const Operator* input_op = GetOpWithOutput(*model, input_name);
if (input_op) return false;
if (!IsConstantParameterArray(*model, input_name)) return false;
if (!model->GetArray(input_name).has_shape()) return false;
if (model->GetArray(input_name).quantization_params) return false;
if (!IsDiscardableArray(*model, input_name)) return false;
if (input_op) return ::tensorflow::Status::OK();
if (!IsConstantParameterArray(*model, input_name))
return ::tensorflow::Status::OK();
if (!model->GetArray(input_name).has_shape())
return ::tensorflow::Status::OK();
if (model->GetArray(input_name).quantization_params)
return ::tensorflow::Status::OK();
if (!IsDiscardableArray(*model, input_name))
return ::tensorflow::Status::OK();
}
const int concatenation_axis = concat_op->axis;
@ -205,7 +212,8 @@ bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) {
// Remove concatenate operator.
model->operators.erase(concat_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -59,11 +59,14 @@ void GetBoundsForQuantizedDataType(ArrayDataType quantized_data_type,
}
}
bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantFakeQuant::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto fakequant_it = model->operators.begin() + op_index;
const auto* fakequant_base_op = fakequant_it->get();
if (fakequant_base_op->type != OperatorType::kFakeQuant) {
return false;
return ::tensorflow::Status::OK();
}
const auto* fakequant_op =
@ -71,12 +74,12 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
// Yield until the fakequant MinMax has been resolved.
if (!fakequant_op->minmax) {
return false;
return ::tensorflow::Status::OK();
}
// This transformation only applies when the input array is constant.
if (!IsConstantParameterArray(*model, fakequant_op->inputs[0])) {
return false;
return ::tensorflow::Status::OK();
}
const auto& input_array = model->GetArray(fakequant_op->inputs[0]);
@ -87,7 +90,7 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
if (!InferQuantizedDataTypeFromFakeQuant(*fakequant_op,
&quantized_data_type)) {
AddMessageF("Unsupported FakeQuant num_bits=%d", fakequant_op->num_bits);
return false;
return ::tensorflow::Status::OK();
}
AddMessageF("Resolving constant %s", LogName(*fakequant_op));
@ -136,7 +139,8 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
}
model->operators.erase(fakequant_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -41,11 +41,14 @@ bool ComputeFillArray(Model* model, FillOperator* op) {
return true;
}
bool ResolveConstantFill::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantFill::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto fill_it = model->operators.begin() + op_index;
auto* base_op = fill_it->get();
if (base_op->type != OperatorType::kFill) {
return false;
return ::tensorflow::Status::OK();
}
auto* op = static_cast<FillOperator*>(base_op);
@ -55,44 +58,44 @@ bool ResolveConstantFill::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
return false;
return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes
return false;
return ::tensorflow::Status::OK();
}
const auto& val_array = model->GetArray(op->inputs[1]);
if (!val_array.has_shape()) {
// Yield until the value shape has been resolved.
return false;
return ::tensorflow::Status::OK();
}
if (!IsConstantParameterArray(*model, op->inputs[1])) {
// Yield until the value is constant.
return false;
return ::tensorflow::Status::OK();
}
CHECK_EQ(RequiredBufferSizeForShape(val_array.shape()), 1);
switch (output_array.data_type) {
case ArrayDataType::kFloat:
if (!ComputeFillArray<ArrayDataType::kFloat>(model, op)) {
return false;
return ::tensorflow::Status::OK();
}
break;
case ArrayDataType::kUint8:
if (!ComputeFillArray<ArrayDataType::kUint8>(model, op)) {
return false;
return ::tensorflow::Status::OK();
}
break;
case ArrayDataType::kInt32:
if (!ComputeFillArray<ArrayDataType::kInt32>(model, op)) {
return false;
return ::tensorflow::Status::OK();
}
break;
case ArrayDataType::kInt64:
if (!ComputeFillArray<ArrayDataType::kInt64>(model, op)) {
return false;
return ::tensorflow::Status::OK();
}
break;
default:
@ -114,7 +117,8 @@ bool ResolveConstantFill::Run(Model* model, std::size_t op_index) {
// Erase the operator
model->operators.erase(fill_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -61,11 +61,14 @@ inline void Gather(const Array& input_array, int input_rank,
// Resolves a constant Gather operation.
// This simply performs the gather and produces the output array with the
// appropriate values.
bool ResolveConstantGather::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantGather::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kGather) {
return false;
return ::tensorflow::Status::OK();
}
const auto* op = static_cast<const GatherOperator*>(base_op);
@ -74,28 +77,28 @@ bool ResolveConstantGather::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes.
return false;
return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes.
return false;
return ::tensorflow::Status::OK();
}
if (!op->axis) {
// Yield until axis has been set by ResolveGatherAttributes.
return false;
return ::tensorflow::Status::OK();
}
if (op->axis.value() != 0) {
// Only handling axis=0 for now.
AddMessageF("%s has axis %d; only axis=0 is supported", LogName(*op),
op->axis.value());
return false;
return ::tensorflow::Status::OK();
}
// We require constant inputs.
if (!IsConstantParameterArray(*model, op->inputs[0]) ||
!IsConstantParameterArray(*model, op->inputs[1])) {
return false;
return ::tensorflow::Status::OK();
}
const Array& input_array = model->GetArray(op->inputs[0]);
const Array& coords_array = model->GetArray(op->inputs[1]);
@ -142,7 +145,8 @@ bool ResolveConstantGather::Run(Model* model, std::size_t op_index) {
// Erase the operator.
model->operators.erase(it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -49,11 +49,14 @@ void Pack(Model* model, PackOperator const& op) {
} // namespace
bool ResolveConstantPack::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantPack::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kPack) {
return false;
return ::tensorflow::Status::OK();
}
const auto* op = static_cast<const PackOperator*>(base_op);
@ -62,18 +65,18 @@ bool ResolveConstantPack::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
return false;
return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes
return false;
return ::tensorflow::Status::OK();
}
for (const auto& input : op->inputs) {
if (!IsConstantParameterArray(*model, input)) {
// Yield if any input is mutable
return false;
return ::tensorflow::Status::OK();
}
}
@ -111,7 +114,8 @@ bool ResolveConstantPack::Run(Model* model, std::size_t op_index) {
// Erase the operator
model->operators.erase(it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -59,11 +59,14 @@ bool ComputeRandomUniformArray(Model* model, RandomUniformOperator* op) {
return true;
}
bool ResolveConstantRandomUniform::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantRandomUniform::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto it = model->operators.begin() + op_index;
auto* base_op = it->get();
if (base_op->type != OperatorType::kRandomUniform) {
return false;
return ::tensorflow::Status::OK();
}
auto* op = static_cast<RandomUniformOperator*>(base_op);
@ -73,12 +76,12 @@ bool ResolveConstantRandomUniform::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
return false;
return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes
return false;
return ::tensorflow::Status::OK();
}
if ((op->seed == 0) && (op->seed2 == 0)) {
@ -86,13 +89,13 @@ bool ResolveConstantRandomUniform::Run(Model* model, std::size_t op_index) {
<< "\" is truly random (using /dev/random system entropy). "
"Therefore, cannot resolve as constant. Set \"seed\" or "
"\"seed2\" attr non-zero to fix this";
return false;
return ::tensorflow::Status::OK();
}
switch (output_array.data_type) {
case ArrayDataType::kFloat:
if (!ComputeRandomUniformArray<ArrayDataType::kFloat>(model, op)) {
return false;
return ::tensorflow::Status::OK();
}
break;
// For future support of double or half.
@ -110,7 +113,8 @@ bool ResolveConstantRandomUniform::Run(Model* model, std::size_t op_index) {
// Erase the operator
model->operators.erase(it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -19,11 +19,14 @@ limitations under the License.
namespace toco {
bool ResolveConstantRange::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantRange::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto it = model->operators.begin() + op_index;
auto* base_op = it->get();
if (base_op->type != OperatorType::kRange) {
return false;
return ::tensorflow::Status::OK();
}
auto* op = static_cast<RangeOperator*>(base_op);
@ -31,23 +34,23 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) {
const auto& start_array = model->GetArray(op->inputs[0]);
if (!start_array.has_shape()) {
// Yield until all input dims have been resolved.
return false;
return ::tensorflow::Status::OK();
}
const auto& limit_array = model->GetArray(op->inputs[1]);
if (!limit_array.has_shape()) {
// Yield until all input dims have been resolved.
return false;
return ::tensorflow::Status::OK();
}
const auto& delta_array = model->GetArray(op->inputs[2]);
if (!delta_array.has_shape()) {
// Yield until all input dims have been resolved.
return false;
return ::tensorflow::Status::OK();
}
for (const auto& input : op->inputs) {
if (!IsConstantParameterArray(*model, input)) {
// yield if any input is mutable
return false;
return ::tensorflow::Status::OK();
}
}
@ -55,7 +58,7 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
return false;
return ::tensorflow::Status::OK();
}
CHECK_EQ(RequiredBufferSizeForShape(start_array.shape()), 1)
@ -101,7 +104,8 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) {
// Delete the operator
model->operators.erase(it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -22,11 +22,14 @@ limitations under the License.
namespace toco {
// Resolves a constant reshape operation by copying the buffer.
bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantReshape::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kReshape) {
return false;
return ::tensorflow::Status::OK();
}
const auto* op = static_cast<const TensorFlowReshapeOperator*>(base_op);
@ -36,17 +39,17 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {
// We require constant inputs.
if (!IsConstantParameterArray(*model, op->inputs[0]) ||
!IsConstantParameterArray(*model, op->inputs[1])) {
return false;
return ::tensorflow::Status::OK();
}
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes.
return false;
return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes.
return false;
return ::tensorflow::Status::OK();
}
const Array& input_array = model->GetArray(op->inputs[0]);
@ -54,7 +57,7 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {
AddMessageF("Constant reshape is non-trivial (%s -> %s)",
ShapeToString(input_array.shape()),
ShapeToString(output_array.shape()));
return false;
return ::tensorflow::Status::OK();
}
CHECK(!output_array.buffer);
@ -95,7 +98,7 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {
default:
LOG(FATAL) << "Unsupported data type: "
<< ArrayDataTypeName(input_array.data_type);
return false;
return ::tensorflow::Status::OK();
}
AddMessageF("Resolving constant reshape of %s", LogName(*op));
@ -112,7 +115,8 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {
// Erase the operator.
model->operators.erase(it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -27,11 +27,14 @@ namespace toco {
// This implementation is looking strictly for all-or-nothing on the select
// condition. It's possible to enhance this by looking per-element and possibly
// producing a Mul op.
bool ResolveConstantSelect::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantSelect::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kSelect) {
return false;
return ::tensorflow::Status::OK();
}
const auto* op = static_cast<const SelectOperator*>(base_op);
@ -40,23 +43,23 @@ bool ResolveConstantSelect::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes.
return false;
return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes.
return false;
return ::tensorflow::Status::OK();
}
// We require the cond input to be constant.
if (!IsConstantParameterArray(*model, op->inputs[0])) {
return false;
return ::tensorflow::Status::OK();
}
const Array& cond_array = model->GetArray(op->inputs[0]);
CHECK(cond_array.data_type == ArrayDataType::kBool)
<< "Only bool conditions are supported";
const auto& cond_data = cond_array.GetBuffer<ArrayDataType::kBool>().data;
if (cond_data.empty()) {
return false;
return ::tensorflow::Status::OK();
}
// Check if the condition is the same for all elements.
@ -67,12 +70,14 @@ bool ResolveConstantSelect::Run(Model* model, std::size_t op_index) {
"Cannot resolve %s as constant; cond_array has differing "
"per-element values",
LogName(*op));
return false;
return ::tensorflow::Status::OK();
}
}
// Pass-through the selected input.
return RemoveTrivialPassthroughOp(this, model, op_index, cond_value ? 1 : 2);
*modified =
RemoveTrivialPassthroughOp(this, model, op_index, cond_value ? 1 : 2);
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -19,29 +19,32 @@ limitations under the License.
namespace toco {
bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantShapeOrRank::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto it = model->operators.begin() + op_index;
const auto* op = it->get();
if (!(op->type == OperatorType::kShape || op->type == OperatorType::kRank)) {
return false;
return ::tensorflow::Status::OK();
}
CHECK_EQ(op->outputs.size(), 1);
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been resolved
return false;
return ::tensorflow::Status::OK();
}
const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until the input array's shape has been resolved.
return false;
return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been resolved.
return false;
return ::tensorflow::Status::OK();
}
// Compute the output
@ -65,7 +68,8 @@ bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) {
}
model->operators.erase(it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -86,11 +86,14 @@ bool Slice(SliceOperator const& op, Array const& input_array,
} // namespace
bool ResolveConstantSlice::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantSlice::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kSlice) {
return false;
return ::tensorflow::Status::OK();
}
const SliceOperator* op = static_cast<const SliceOperator*>(base_op);
@ -99,49 +102,49 @@ bool ResolveConstantSlice::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes.
return false;
return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes.
return false;
return ::tensorflow::Status::OK();
}
if (op->begin.empty() || op->size.empty()) {
// Attributes have not resolved yet.
return false;
return ::tensorflow::Status::OK();
}
const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until the value shape has been resolved.
return false;
return ::tensorflow::Status::OK();
}
if (!IsConstantParameterArray(*model, op->inputs[0])) {
// Yield until the value is constant.
return false;
return ::tensorflow::Status::OK();
}
CHECK(!output_array.buffer);
switch (output_array.data_type) {
case ArrayDataType::kFloat:
if (!Slice<ArrayDataType::kFloat>(*op, input_array, &output_array)) {
return false;
return ::tensorflow::Status::OK();
}
break;
case ArrayDataType::kUint8:
if (!Slice<ArrayDataType::kUint8>(*op, input_array, &output_array)) {
return false;
return ::tensorflow::Status::OK();
}
break;
case ArrayDataType::kInt32:
if (!Slice<ArrayDataType::kInt32>(*op, input_array, &output_array)) {
return false;
return ::tensorflow::Status::OK();
}
break;
case ArrayDataType::kInt64:
if (!Slice<ArrayDataType::kInt64>(*op, input_array, &output_array)) {
return false;
return ::tensorflow::Status::OK();
}
break;
default:
@ -159,7 +162,8 @@ bool ResolveConstantSlice::Run(Model* model, std::size_t op_index) {
// Erase the operator
model->operators.erase(it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -103,11 +103,14 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
} // anonymous namespace
bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantStridedSlice::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kStridedSlice) {
return false;
return ::tensorflow::Status::OK();
}
const StridedSliceOperator* op =
@ -117,28 +120,28 @@ bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
return false;
return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes
return false;
return ::tensorflow::Status::OK();
}
if (op->start_indices.empty() || op->stop_indices.empty() ||
op->strides.empty()) {
// Attributes have not resolved yet.
return false;
return ::tensorflow::Status::OK();
}
const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until the value shape has been resolved.
return false;
return ::tensorflow::Status::OK();
}
if (!IsConstantParameterArray(*model, op->inputs[0])) {
// Yield until the value is constant.
return false;
return ::tensorflow::Status::OK();
}
CHECK(!output_array.buffer);
@ -164,7 +167,8 @@ bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) {
DeleteOpAndArraysIfUnused(model, it->get());
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -97,11 +97,14 @@ inline void Tile(const Array& input_array, const Array& multiples_array,
} // namespace
// Resolves a constant Tile operation.
bool ResolveConstantTile::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantTile::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kTile) {
return false;
return ::tensorflow::Status::OK();
}
const auto* op = static_cast<const TensorFlowTileOperator*>(base_op);
@ -110,17 +113,17 @@ bool ResolveConstantTile::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes.
return false;
return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes.
return false;
return ::tensorflow::Status::OK();
}
// We require constant inputs.
if (!IsConstantParameterArray(*model, op->inputs[0]) ||
!IsConstantParameterArray(*model, op->inputs[1])) {
return false;
return ::tensorflow::Status::OK();
}
const Array& input_array = model->GetArray(op->inputs[0]);
const Array& multiples_array = model->GetArray(op->inputs[1]);
@ -159,7 +162,8 @@ bool ResolveConstantTile::Run(Model* model, std::size_t op_index) {
// Erase the operator.
model->operators.erase(it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -101,11 +101,14 @@ void Transpose(Model* model, const Array& input_array,
} // namespace
bool ResolveConstantTranspose::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantTranspose::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kTranspose) {
return false;
return ::tensorflow::Status::OK();
}
const auto* op = static_cast<const TransposeOperator*>(base_op);
@ -114,17 +117,17 @@ bool ResolveConstantTranspose::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes.
return false;
return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes.
return false;
return ::tensorflow::Status::OK();
}
// We require constant inputs.
if (!IsConstantParameterArray(*model, op->inputs[0]) ||
!IsConstantParameterArray(*model, op->inputs[1])) {
return false;
return ::tensorflow::Status::OK();
}
const Array& input_array = model->GetArray(op->inputs[0]);
@ -132,7 +135,7 @@ bool ResolveConstantTranspose::Run(Model* model, std::size_t op_index) {
if (op->perm.empty()) {
// Yield until perm has been populated by ResolveTransposeAttributes.
return false;
return ::tensorflow::Status::OK();
}
// We currently only support 1-4 dimensions.
@ -174,7 +177,8 @@ bool ResolveConstantTranspose::Run(Model* model, std::size_t op_index) {
// Erase the operator.
model->operators.erase(it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -112,7 +112,10 @@ bool CopyMinMaxFromFirstInput(const Operator& op, Model* model) {
return true;
}
bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveConstantUnaryOperator::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto unary_it = model->operators.begin() + op_index;
const auto* unary_op = unary_it->get();
// Test for unary ops of types that we know how to resolve.
@ -133,28 +136,28 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
case OperatorType::kRelu:
break;
default:
return false;
return ::tensorflow::Status::OK();
}
// Check if the input is a constant parameter.
if (!IsConstantParameterArray(*model, unary_op->inputs[0])) {
return false;
return ::tensorflow::Status::OK();
}
// if the unary op involves a tensor required by a rnn state, ignore it
for (const auto& rnn_state : model->flags.rnn_states()) {
if (unary_op->inputs[0] == rnn_state.back_edge_source_array()) {
return false;
return ::tensorflow::Status::OK();
}
if (unary_op->inputs[0] == rnn_state.state_array()) {
return false;
return ::tensorflow::Status::OK();
}
}
auto& output_array = model->GetArray(unary_op->outputs[0]);
if (!output_array.has_shape()) {
// Yield until the output array dims have been resolved.
return false;
return ::tensorflow::Status::OK();
}
// At the moment we don't want to care about fused activation functions.
@ -166,7 +169,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
"Not resolving constant %s "
" because it has a fused activation function",
LogName(*unary_op));
return false;
return ::tensorflow::Status::OK();
}
// The min-max is only copied for ops that copy data without arithmetic.
@ -187,7 +190,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
"Not resolving constant %s because we currently only support casting "
"to float",
LogName(*unary_op));
return false;
return ::tensorflow::Status::OK();
}
if (cast_op->src_data_type != input_array.buffer->type) {
AddMessageF(
@ -197,7 +200,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
}
} else {
if (input_array.buffer->type != ArrayDataType::kFloat) {
return false;
return ::tensorflow::Status::OK();
}
input_float_data = &(input_array.GetBuffer<ArrayDataType::kFloat>().data);
}
@ -239,7 +242,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
CHECK_EQ(unary_op->inputs.size(), 2) << "Sum needs 2 inputs";
if (!IsConstantParameterArray(*model, unary_op->inputs[1])) {
AddMessageF("Axis input is non-constant");
return false;
return ::tensorflow::Status::OK();
}
auto& axis_array = model->GetArray(unary_op->inputs[1]);
CHECK(axis_array.data_type == ArrayDataType::kInt32);
@ -336,7 +339,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
default:
LOG(FATAL) << "Unsupported activation function "
<< LogName(*unary_op);
return false;
return ::tensorflow::Status::OK();
}
output_float_data[i] = new_value;
}
@ -351,7 +354,8 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
AddMessageF("Resolved constant %s to the equivalent constant array",
LogName(*unary_op));
model->operators.erase(unary_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -25,17 +25,20 @@ limitations under the License.
namespace toco {
bool ResolveFakeQuantArgsFromVars::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveFakeQuantArgsFromVars::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto fakequant_it = model->operators.begin() + op_index;
auto* fakequant_base_op = fakequant_it->get();
if (fakequant_base_op->type != OperatorType::kFakeQuant) {
return false;
return ::tensorflow::Status::OK();
}
auto* fakequant_op = static_cast<FakeQuantOperator*>(fakequant_base_op);
if (fakequant_op->minmax) {
// Already resolved.
return false;
return ::tensorflow::Status::OK();
}
CHECK_EQ(fakequant_op->inputs.size(), 3);
@ -43,7 +46,7 @@ bool ResolveFakeQuantArgsFromVars::Run(Model* model, std::size_t op_index) {
// resolved to constant arrays.
for (int i = 1; i <= 2; i++) {
if (!IsConstantParameterArray(*model, fakequant_op->inputs[i])) {
return false;
return ::tensorflow::Status::OK();
}
}
@ -74,7 +77,8 @@ bool ResolveFakeQuantArgsFromVars::Run(Model* model, std::size_t op_index) {
DeleteArrayIfUsedOnce(fakequant_op->inputs[i], model);
}
fakequant_op->inputs.resize(1);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -24,20 +24,25 @@ limitations under the License.
namespace toco {
bool ResolveGatherAttributes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveGatherAttributes::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto* gather_op = model->operators[op_index].get();
if (gather_op->type != OperatorType::kGather) return false;
if (gather_op->type != OperatorType::kGather)
return ::tensorflow::Status::OK();
auto* op = static_cast<GatherOperator*>(gather_op);
if (op->axis) {
// Attributes already resolved
return false;
return ::tensorflow::Status::OK();
}
if (op->inputs.size() != 3) return false;
if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
if (op->inputs.size() != 3) return ::tensorflow::Status::OK();
if (!IsConstantParameterArray(*model, op->inputs[2]))
return ::tensorflow::Status::OK();
const auto& indices_array = model->GetArray(op->inputs[2]);
if (!indices_array.has_shape()) return false;
if (!indices_array.has_shape()) return ::tensorflow::Status::OK();
const auto& axis_data = indices_array.GetBuffer<ArrayDataType::kInt32>().data;
CHECK_EQ(axis_data.size(), 1)
<< "Multidimensional gather not supported on " << LogName(*op);
@ -47,7 +52,8 @@ bool ResolveGatherAttributes::Run(Model* model, std::size_t op_index) {
DeleteArrayIfUsedOnce(op->inputs[2], model);
op->inputs.resize(2);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -51,27 +51,30 @@ void FillArrayWithZeros(Array* array) {
// Removes a multiplication by array of constant zeros by making the output
// array an array of constant zeros and removing the input arrays if they are no
// longer needed.
bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveMultiplyByZero::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto mul_it = model->operators.begin() + op_index;
auto* mul_op = mul_it->get();
if (mul_op->type != OperatorType::kMul) {
return false;
return ::tensorflow::Status::OK();
}
const auto& output_array_name = mul_op->outputs[0];
auto& output_array = model->GetArray(output_array_name);
if (!IsDiscardableArray(*model, output_array_name)) {
return false;
return ::tensorflow::Status::OK();
}
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
return false;
return ::tensorflow::Status::OK();
}
// Yield if the output shape is not known yet.
if (!output_array.has_shape()) {
return false;
return ::tensorflow::Status::OK();
}
// This transformation only handles the case where one operand is all 0's and
@ -83,12 +86,12 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
};
if (!is_input_constant[0] && !is_input_constant[1]) {
// Neither input is constant, so nothing we can resolve here.
return false;
return ::tensorflow::Status::OK();
}
if (is_input_constant[0] && is_input_constant[1]) {
// Both inputs are constants. That's a job for constants propagation, not
// for us to handle here.
return false;
return ::tensorflow::Status::OK();
}
const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
@ -105,7 +108,7 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
constant_input_array.GetBuffer<ArrayDataType::kFloat>().data;
if (!AreAllBufferElementsZero<DataType<ArrayDataType::kFloat>>(
constant_input_data)) {
return false;
return ::tensorflow::Status::OK();
}
FillArrayWithZeros<ArrayDataType::kFloat>(&output_array);
} break;
@ -114,7 +117,7 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
constant_input_array.GetBuffer<ArrayDataType::kUint8>().data;
if (!AreAllBufferElementsZero<DataType<ArrayDataType::kUint8>>(
constant_input_data)) {
return false;
return ::tensorflow::Status::OK();
}
FillArrayWithZeros<ArrayDataType::kUint8>(&output_array);
} break;
@ -123,7 +126,7 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
constant_input_array.GetBuffer<ArrayDataType::kInt32>().data;
if (!AreAllBufferElementsZero<DataType<ArrayDataType::kInt32>>(
constant_input_data)) {
return false;
return ::tensorflow::Status::OK();
}
FillArrayWithZeros<ArrayDataType::kInt32>(&output_array);
} break;
@ -132,14 +135,14 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
constant_input_array.GetBuffer<ArrayDataType::kInt64>().data;
if (!AreAllBufferElementsZero<DataType<ArrayDataType::kInt64>>(
constant_input_data)) {
return false;
return ::tensorflow::Status::OK();
}
FillArrayWithZeros<ArrayDataType::kInt64>(&output_array);
} break;
default:
AddMessageF(
"Cannot resolve multiply by 0 because of unsupported data type\n");
return false;
return ::tensorflow::Status::OK();
}
// Erase input arrays to the multiply if no longer used
@ -149,7 +152,8 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
// Erase the multiply operator.
model->operators.erase(mul_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -24,19 +24,23 @@ limitations under the License.
namespace toco {
bool ResolvePadAttributes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolvePadAttributes::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto pad_it = model->operators.begin() + op_index;
auto* pad_op = pad_it->get();
if (pad_op->type != OperatorType::kPad) return false;
if (pad_op->type != OperatorType::kPad) return ::tensorflow::Status::OK();
auto* op = static_cast<PadOperator*>(pad_op);
if (!op->left_padding.empty()) return false;
if (!op->left_padding.empty()) return ::tensorflow::Status::OK();
CHECK_EQ(op->inputs.size(), 2);
if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
if (!IsConstantParameterArray(*model, op->inputs[1]))
return ::tensorflow::Status::OK();
const auto& array = model->GetArray(op->inputs[1]);
if (!array.has_shape()) return false;
if (!array.has_shape()) return ::tensorflow::Status::OK();
const std::vector<int>& dims = array.shape().dims();
CHECK_EQ(dims.size(), 2);
@ -50,6 +54,7 @@ bool ResolvePadAttributes::Run(Model* model, std::size_t op_index) {
// TODO(dkalenichenko): Delete the extra input?
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -24,19 +24,23 @@ limitations under the License.
namespace toco {
bool ResolvePadV2Attributes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolvePadV2Attributes::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto pad_it = model->operators.begin() + op_index;
auto* pad_op = pad_it->get();
if (pad_op->type != OperatorType::kPadV2) return false;
if (pad_op->type != OperatorType::kPadV2) return ::tensorflow::Status::OK();
auto* op = static_cast<PadV2Operator*>(pad_op);
if (!op->left_padding.empty()) return false;
if (!op->left_padding.empty()) return ::tensorflow::Status::OK();
CHECK_EQ(op->inputs.size(), 3);
if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
if (!IsConstantParameterArray(*model, op->inputs[1]))
return ::tensorflow::Status::OK();
const auto& array = model->GetArray(op->inputs[1]);
if (!array.has_shape()) return false;
if (!array.has_shape()) return ::tensorflow::Status::OK();
const std::vector<int>& dims = array.shape().dims();
CHECK_EQ(dims.size(), 2);
@ -50,6 +54,7 @@ bool ResolvePadV2Attributes::Run(Model* model, std::size_t op_index) {
// TODO(dkalenichenko): Delete the extra input?
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -39,23 +39,37 @@ bool ResolveAttributes(Model* model, T* op) {
return true;
}
bool ResolveReduceAttributes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveReduceAttributes::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
Operator* op = model->operators[op_index].get();
switch (op->type) {
case OperatorType::kMean:
return ResolveAttributes(model, static_cast<MeanOperator*>(op));
*modified = ResolveAttributes(model, static_cast<MeanOperator*>(op));
return ::tensorflow::Status::OK();
case OperatorType::kSum:
return ResolveAttributes(model, static_cast<TensorFlowSumOperator*>(op));
*modified =
ResolveAttributes(model, static_cast<TensorFlowSumOperator*>(op));
return ::tensorflow::Status::OK();
case OperatorType::kReduceProd:
return ResolveAttributes(model, static_cast<TensorFlowProdOperator*>(op));
*modified =
ResolveAttributes(model, static_cast<TensorFlowProdOperator*>(op));
return ::tensorflow::Status::OK();
case OperatorType::kReduceMin:
return ResolveAttributes(model, static_cast<TensorFlowMinOperator*>(op));
*modified =
ResolveAttributes(model, static_cast<TensorFlowMinOperator*>(op));
return ::tensorflow::Status::OK();
case OperatorType::kReduceMax:
return ResolveAttributes(model, static_cast<TensorFlowMaxOperator*>(op));
*modified =
ResolveAttributes(model, static_cast<TensorFlowMaxOperator*>(op));
return ::tensorflow::Status::OK();
case OperatorType::kAny:
return ResolveAttributes(model, static_cast<TensorFlowMaxOperator*>(op));
*modified =
ResolveAttributes(model, static_cast<TensorFlowMaxOperator*>(op));
return ::tensorflow::Status::OK();
default:
return false;
return ::tensorflow::Status::OK();
}
}

View File

@ -78,11 +78,13 @@ void ReorderAxes(AxesOrder input_axes_order, AxesOrder output_axes_order,
}
}
bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveReorderAxes::Run(Model* model, std::size_t op_index,
bool* modified) {
*modified = false;
auto it = model->operators.begin() + op_index;
auto* op = it->get();
if (op->type != OperatorType::kReorderAxes) {
return false;
return ::tensorflow::Status::OK();
}
auto* reorder_op = static_cast<ReorderAxesOperator*>(op);
@ -93,11 +95,11 @@ bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
auto& input_array = model->GetArray(input_array_name);
auto& output_array = model->GetArray(output_array_name);
if (!input_array.buffer) {
return false;
return ::tensorflow::Status::OK();
}
// Yield until output dims have been resolved.
if (!output_array.has_shape()) {
return false;
return ::tensorflow::Status::OK();
}
// Reorder the input array dims and buffer data
if (input_array.buffer->type == ArrayDataType::kFloat) {
@ -120,7 +122,8 @@ bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
DeleteOpAndArraysIfUnused(model, op);
RenameArray(model, output_array_name, input_array_name);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -25,25 +25,29 @@ limitations under the License.
namespace toco {
bool ResolveReshapeAttributes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveReshapeAttributes::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto reshape_it = model->operators.begin() + op_index;
auto* reshape_op = reshape_it->get();
if (reshape_op->type != OperatorType::kReshape) {
return false;
return ::tensorflow::Status::OK();
}
auto* op = static_cast<TensorFlowReshapeOperator*>(reshape_op);
if (!op->shape.empty()) return false;
if (!op->shape.empty()) return ::tensorflow::Status::OK();
if (IsConstantParameterArray(*model, reshape_op->inputs[1])) {
const auto& constant_input_array = model->GetArray(reshape_op->inputs[1]);
op->shape = constant_input_array.GetBuffer<ArrayDataType::kInt32>().data;
}
if (op->shape.empty()) return false;
if (op->shape.empty()) return ::tensorflow::Status::OK();
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -24,29 +24,35 @@ limitations under the License.
namespace toco {
bool ResolveSliceAttributes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveSliceAttributes::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto slice_it = model->operators.begin() + op_index;
auto* slice_op = slice_it->get();
if (slice_op->type != OperatorType::kSlice) return false;
if (slice_op->type != OperatorType::kSlice) return ::tensorflow::Status::OK();
auto* op = static_cast<SliceOperator*>(slice_op);
if (!op->begin.empty()) return false;
if (!op->begin.empty()) return ::tensorflow::Status::OK();
CHECK_EQ(op->inputs.size(), 3);
if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
if (!IsConstantParameterArray(*model, op->inputs[1]))
return ::tensorflow::Status::OK();
if (!IsConstantParameterArray(*model, op->inputs[2]))
return ::tensorflow::Status::OK();
const auto& begin_array = model->GetArray(op->inputs[1]);
if (!begin_array.has_shape()) return false;
if (!begin_array.has_shape()) return ::tensorflow::Status::OK();
const auto& size_array = model->GetArray(op->inputs[2]);
if (!size_array.has_shape()) return false;
if (!size_array.has_shape()) return ::tensorflow::Status::OK();
op->begin = begin_array.GetBuffer<ArrayDataType::kInt32>().data;
op->size = size_array.GetBuffer<ArrayDataType::kInt32>().data;
// TODO(dkalenichenko): Delete the extra inputs?
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -24,16 +24,20 @@ limitations under the License.
namespace toco {
bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveSpaceToBatchNDAttributes::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto op_it = model->operators.begin() + op_index;
if (op_it->get()->type != OperatorType::kSpaceToBatchND) return false;
if (op_it->get()->type != OperatorType::kSpaceToBatchND)
return ::tensorflow::Status::OK();
auto* op = static_cast<SpaceToBatchNDOperator*>(op_it->get());
// The attributes are resolved only when the 3 attributes (block_shape,
// before_paddings, after_paddings) are all constant.
if (!op->block_shape.empty()) {
return false;
return ::tensorflow::Status::OK();
}
const int block_shape_index = 1;
@ -42,16 +46,16 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) {
CHECK_EQ(op->inputs.size(), 3);
if (!IsConstantParameterArray(*model, op->inputs[block_shape_index]) ||
!IsConstantParameterArray(*model, op->inputs[paddings_index]))
return false;
return ::tensorflow::Status::OK();
// Handle paddings.
const auto& paddings_array = model->GetArray(op->inputs[paddings_index]);
if (!paddings_array.has_shape()) return false;
if (!paddings_array.has_shape()) return ::tensorflow::Status::OK();
const std::vector<int>& paddings_dims = paddings_array.shape().dims();
if (paddings_dims.size() != 2) {
// Code only handles padding of 2 dimensions. Perhaps another transformation
// will delete this op.
return false;
return ::tensorflow::Status::OK();
}
const std::vector<int>& paddings_buffer =
paddings_array.GetBuffer<ArrayDataType::kInt32>().data;
@ -63,7 +67,7 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) {
// Handle block_shape.
const auto& block_shape_array =
model->GetArray(op->inputs[block_shape_index]);
if (!block_shape_array.has_shape()) return false;
if (!block_shape_array.has_shape()) return ::tensorflow::Status::OK();
const std::vector<int>& block_shape_dims = block_shape_array.shape().dims();
CHECK_EQ(block_shape_dims.size(), 1);
const std::vector<int>& block_shape_buffer =
@ -72,7 +76,8 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) {
op->block_shape.push_back(block_shape_buffer[i]);
}
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -25,10 +25,13 @@ limitations under the License.
namespace toco {
bool ResolveSqueezeAttributes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveSqueezeAttributes::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto* squeeze_op = model->operators[op_index].get();
if (squeeze_op->type != OperatorType::kSqueeze) {
return false;
return ::tensorflow::Status::OK();
}
DCHECK_EQ(squeeze_op->inputs.size(), 1);
DCHECK_EQ(squeeze_op->outputs.size(), 1);
@ -42,10 +45,11 @@ bool ResolveSqueezeAttributes::Run(Model* model, std::size_t op_index) {
"Reshape op",
LogName(*squeeze_op));
return RemoveTrivialPassthroughOp(this, model, op_index);
*modified = RemoveTrivialPassthroughOp(this, model, op_index);
return ::tensorflow::Status::OK();
}
}
return false;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -37,40 +37,47 @@ int PadAttributeArray(Array* attribute_array, std::vector<int> pad_values,
return mask;
}
bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveStridedSliceAttributes::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto slice_it = model->operators.begin() + op_index;
auto* slice_op = slice_it->get();
if (slice_op->type != OperatorType::kStridedSlice) return false;
if (slice_op->type != OperatorType::kStridedSlice)
return ::tensorflow::Status::OK();
auto* op = static_cast<StridedSliceOperator*>(slice_op);
if (!op->start_indices.empty()) {
// We have already resolved these attributes
return false;
return ::tensorflow::Status::OK();
}
CHECK_EQ(op->inputs.size(), 4);
const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// We require the dimensionality of the input to pad the indices
return false;
return ::tensorflow::Status::OK();
}
auto& start_array = model->GetArray(op->inputs[1]);
if (!start_array.has_shape()) return false;
if (!start_array.has_shape()) return ::tensorflow::Status::OK();
if (toco::RequiredBufferSizeForShape(start_array.shape()) > 4) {
// Only 1-4D arrays are supported for now.
return false;
return ::tensorflow::Status::OK();
}
auto& stop_array = model->GetArray(op->inputs[2]);
if (!stop_array.has_shape()) return false;
if (!stop_array.has_shape()) return ::tensorflow::Status::OK();
auto& stride_array = model->GetArray(op->inputs[3]);
if (!stride_array.has_shape()) return false;
if (!stride_array.has_shape()) return ::tensorflow::Status::OK();
if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
if (!IsConstantParameterArray(*model, op->inputs[3])) return false;
if (!IsConstantParameterArray(*model, op->inputs[1]))
return ::tensorflow::Status::OK();
if (!IsConstantParameterArray(*model, op->inputs[2]))
return ::tensorflow::Status::OK();
if (!IsConstantParameterArray(*model, op->inputs[3]))
return ::tensorflow::Status::OK();
int num_input_axes = input_array.shape().dimensions_count();
int start_indices_size = start_array.shape().dims(0);
@ -112,6 +119,7 @@ bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) {
op->stop_indices = stop_array.GetBuffer<ArrayDataType::kInt32>().data;
op->strides = stride_array.GetBuffer<ArrayDataType::kInt32>().data;
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -25,12 +25,15 @@ limitations under the License.
namespace toco {
bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveTensorFlowConcat::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto concat_it = model->operators.begin() + op_index;
const auto* tf_concat_op = concat_it->get();
if (tf_concat_op->type != OperatorType::kConcat &&
tf_concat_op->type != OperatorType::kConcatV2) {
return false;
return ::tensorflow::Status::OK();
}
CHECK_GE(tf_concat_op->inputs.size(), 2);
@ -54,7 +57,7 @@ bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) {
if (!axis_array.buffer) {
AddMessageF("Waiting for the axis of %s to be resolved to a constant",
LogName(*tf_concat_op));
return false;
return ::tensorflow::Status::OK();
}
CHECK(axis_array.data_type == ArrayDataType::kInt32);
@ -79,7 +82,8 @@ bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) {
}
// Remove the TensorFlowConcat op
model->operators.erase(concat_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -55,10 +55,13 @@ TransposeOperator* FindTransposeOpWithInput(const Model& model,
} // namespace
bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveTensorFlowMatMul::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto matmul_it = model->operators.begin() + op_index;
if (matmul_it->get()->type != OperatorType::kMatMul) {
return false;
return ::tensorflow::Status::OK();
}
const auto* matmul_op =
static_cast<const TensorFlowMatMulOperator*>(matmul_it->get());
@ -73,7 +76,7 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
"Not replacing %s by a FullyConnected operator, because it has "
"the transpose_a attribute",
LogName(*matmul_op));
return false;
return ::tensorflow::Status::OK();
}
// Reorder the axes on the second input. TensorFlow uses row-major ordering
@ -198,7 +201,8 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
// erase the MatMul operator
model->operators.erase(matmul_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -24,11 +24,14 @@ limitations under the License.
namespace toco {
bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveTensorFlowMerge::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto merge_it = model->operators.begin() + op_index;
const auto* merge_op = merge_it->get();
if (merge_op->type != OperatorType::kMerge) {
return false;
return ::tensorflow::Status::OK();
}
// We need to yield until this Merge node has only 1 input, which will mean
@ -37,7 +40,7 @@ bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) {
// non-selected inputs, so that at some point there will be only 1 input left.
if (merge_op->inputs.size() > 1) {
AddMessageF("Waiting for %s to be resolved", LogName(*merge_op));
return false;
return ::tensorflow::Status::OK();
}
// Now that the merge node has 1 input exactly, it is the same as an Identity
@ -57,7 +60,8 @@ bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) {
AddMessageF("Removing already-resolved %s", LogName(*merge_op));
model->EraseArray(merge_op->outputs[0]);
model->operators.erase(merge_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -24,11 +24,14 @@ limitations under the License.
namespace toco {
bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveTensorFlowSwitch::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto switch_it = model->operators.begin() + op_index;
const auto* switch_op = switch_it->get();
if (switch_op->type != OperatorType::kSwitch) {
return false;
return ::tensorflow::Status::OK();
}
CHECK_EQ(switch_op->inputs.size(), 2);
@ -40,7 +43,7 @@ bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) {
AddMessageF(
"Waiting for the boolean predicate of %s to be resolved to a constant",
LogName(*switch_op));
return false;
return ::tensorflow::Status::OK();
}
// The predicate should be boolean, and should consist of a single value.
@ -119,7 +122,8 @@ bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) {
// Remove the switch node itself.
AddMessageF("Removing already-resolved %s", LogName(*switch_op));
model->operators.erase(switch_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -24,19 +24,24 @@ limitations under the License.
namespace toco {
bool ResolveTransposeAttributes::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ResolveTransposeAttributes::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto op_it = model->operators.begin() + op_index;
if (op_it->get()->type != OperatorType::kTranspose) return false;
if (op_it->get()->type != OperatorType::kTranspose)
return ::tensorflow::Status::OK();
auto* op = static_cast<TransposeOperator*>(op_it->get());
if (!op->perm.empty()) return false;
if (!op->perm.empty()) return ::tensorflow::Status::OK();
CHECK_EQ(op->inputs.size(), 2);
if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
if (!IsConstantParameterArray(*model, op->inputs[1]))
return ::tensorflow::Status::OK();
// Handling perm.
const auto& perm_array = model->GetArray(op->inputs[1]);
if (!perm_array.has_shape()) return false;
if (!perm_array.has_shape()) return ::tensorflow::Status::OK();
const std::vector<int>& perm_dims = perm_array.shape().dims();
CHECK_EQ(perm_dims.size(), 1);
@ -47,7 +52,8 @@ bool ResolveTransposeAttributes::Run(Model* model, std::size_t op_index) {
op->perm.push_back(perm_buffer[i]);
}
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -24,15 +24,17 @@ limitations under the License.
namespace toco {
bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
::tensorflow::Status ShuffleFCWeights::Run(Model* model, std::size_t op_index,
bool* modified) {
*modified = false;
Operator* op = model->operators[op_index].get();
if (op->type != OperatorType::kFullyConnected) {
return false;
return ::tensorflow::Status::OK();
}
FullyConnectedOperator* fc_op = static_cast<FullyConnectedOperator*>(op);
// Exit if this FC op already has shuffled weights
if (fc_op->weights_format != FullyConnectedWeightsFormat::kDefault) {
return false;
return ::tensorflow::Status::OK();
}
const Array& input_array = model->GetArray(fc_op->inputs[0]);
const string& weights_name = fc_op->inputs[1];
@ -46,11 +48,11 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
output_array.data_type != ArrayDataType::kInt16 ||
!input_array.quantization_params || !weights_array.quantization_params ||
!output_array.quantization_params) {
return false;
return ::tensorflow::Status::OK();
}
// Exit if the shapes aren't known
if (!input_array.has_shape() || !weights_array.has_shape()) {
return false;
return ::tensorflow::Status::OK();
}
// Exit if, based on the known shapes, this FC op is not a GEMV.
// The shuffling of FC weights is only useful to enable fast GEMV paths.
@ -64,7 +66,7 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
"the input shape is not 1D or 2D (possibly with additional inner "
"dimensions of size 1)",
LogName(*op));
return false;
return ::tensorflow::Status::OK();
}
}
if (input_shape.dims(0) != 1 && input_shape.dims(0) != 4) {
@ -73,7 +75,7 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
"the input shape's leading dimension, i.e. the 'batch size', is not "
"equal to 1 or 4",
LogName(*op));
return false;
return ::tensorflow::Status::OK();
}
// Exit if the weights shape isn't an integral multiple of the shuffled
// block shape, 4x16. We don't want to have to write code dealing with
@ -88,7 +90,7 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
// two.
const Shape& weights_shape = weights_array.shape();
if (weights_shape.dimensions_count() != 2) {
return false;
return ::tensorflow::Status::OK();
}
const int rows = weights_shape.dims(0);
const int cols = weights_shape.dims(1);
@ -97,11 +99,11 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
"Not applying experimental shuffling to the weights of %s because its "
"shape isn't a multiple of the shuffling block shape, 4x16",
LogName(*op));
return false;
return ::tensorflow::Status::OK();
}
// Exit if the weights aren't already a constant array.
if (!weights_array.buffer) {
return false;
return ::tensorflow::Status::OK();
}
// Exit if the weights are used by more than one op.
if (CountOpsWithInput(*model, weights_name) != 1) {
@ -109,7 +111,7 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
"Not applying experimental shuffling to the weights of %s because that "
"array is consumed by other operators",
LogName(*op));
return false;
return ::tensorflow::Status::OK();
}
// Compute the shuffled weights
auto& weights_data =
@ -152,7 +154,8 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
shuffled_input_workspace_array.GetOrCreateQuantizationParams() =
input_array.GetQuantizationParams();
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -166,7 +166,10 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis0) {
GraphTransformationsSet graph_transformation_set;
graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
EXPECT_THAT(model.GetArrayMap().size(), 5);
(*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
bool modified;
ASSERT_TRUE((*graph_transformation_set.begin())
->Run(&model, /*op_index=*/0, &modified)
.ok());
EXPECT_THAT(model.GetArrayMap().size(), 1);
auto& concatenated_array = (*model.GetArrayMap().begin()).second;
@ -185,7 +188,10 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis1) {
GraphTransformationsSet graph_transformation_set;
graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
EXPECT_THAT(model.GetArrayMap().size(), 5);
(*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
bool modified;
ASSERT_TRUE((*graph_transformation_set.begin())
->Run(&model, /*op_index=*/0, &modified)
.ok());
EXPECT_THAT(model.GetArrayMap().size(), 1);
auto& concatenated_array = (*model.GetArrayMap().begin()).second;
@ -204,7 +210,10 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis2) {
GraphTransformationsSet graph_transformation_set;
graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
EXPECT_THAT(model.GetArrayMap().size(), 5);
(*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
bool modified;
ASSERT_TRUE((*graph_transformation_set.begin())
->Run(&model, /*op_index=*/0, &modified)
.ok());
EXPECT_THAT(model.GetArrayMap().size(), 1);
auto& concatenated_array = (*model.GetArrayMap().begin()).second;

View File

@ -50,7 +50,8 @@ void RunResolveSum(const std::vector<float>& input,
sum_op->inputs = {"input0", "input1"};
sum_op->outputs = {"output"};
model.operators.push_back(std::move(sum_op));
ResolveConstantUnaryOperator().Run(&model, 0);
bool modified;
ASSERT_TRUE(ResolveConstantUnaryOperator().Run(&model, 0, &modified).ok());
EXPECT_EQ(model.GetArray("output").GetBuffer<ArrayDataType::kFloat>().data,
expected_output);
EXPECT_EQ(model.GetArray("output").shape().dims(), output_shape);

View File

@ -25,13 +25,16 @@ limitations under the License.
namespace toco {
bool UnfuseActivationFunctions::Run(Model* model, std::size_t op_index) {
::tensorflow::Status UnfuseActivationFunctions::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto it = model->operators.begin() + op_index;
auto* op = it->get();
// If a conv operation has an im2col array, yield: it should be dropped first.
if ((op->type == OperatorType::kConv) && (op->outputs.size() == 2)) {
return false;
return ::tensorflow::Status::OK();
}
Operator* ac_op = nullptr;
@ -46,7 +49,7 @@ bool UnfuseActivationFunctions::Run(Model* model, std::size_t op_index) {
ac_op = new Relu1Operator;
break;
default:
return false;
return ::tensorflow::Status::OK();
}
// At this point we know that the op has a fused activation function. At the
@ -74,7 +77,8 @@ bool UnfuseActivationFunctions::Run(Model* model, std::size_t op_index) {
ac_op->inputs = {tmp_array_name};
op->outputs = {tmp_array_name};
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -22,7 +22,10 @@ limitations under the License.
namespace toco {
bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
::tensorflow::Status UnpartitionEmbeddingLookup::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
// Collapses a partitioned tf.nn.embedding_lookup back into a single Gather.
// https://www.tensorflow.org/api_docs/python/tf/nn/embedding_lookup
// This transform attempts to identify the len(params) > 1 case and collapse
@ -47,7 +50,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
// First look for the final DynamicStitch.
auto op_it = model->operators.begin() + op_index;
if (op_it->get()->type != OperatorType::kDynamicStitch) {
return false;
return ::tensorflow::Status::OK();
}
auto* stitch_op = static_cast<DynamicStitchOperator*>(op_it->get());
@ -72,7 +75,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
"Skipping because indices input %s into "
"%s is unexpected",
LogName(*op), LogName(*stitch_op));
return false;
return ::tensorflow::Status::OK();
}
if (!indices_partition_op) {
indices_partition_op = static_cast<DynamicPartitionOperator*>(op);
@ -83,7 +86,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
"Skipping because indices input %s into "
"%s is from a different source op than others",
LogName(*op), LogName(*stitch_op));
return false;
return ::tensorflow::Status::OK();
}
}
}
@ -92,12 +95,12 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
// The data for the indices must be a constant range of the array shape.
if (!IsConstantParameterArray(*model, indices_partition_op->inputs[0])) {
AddMessageF("Skipping because indices partition data is non-constant");
return false;
return ::tensorflow::Status::OK();
}
auto& indices_data_array = model->GetArray(indices_partition_op->inputs[0]);
if (indices_data_array.data_type == ArrayDataType::kNone) {
// Yield until data types are propagated.
return false;
return ::tensorflow::Status::OK();
}
CHECK(indices_data_array.data_type == ArrayDataType::kInt32)
<< "Indices partition inputs must be int32";
@ -117,7 +120,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
"Skipping because data input %s into %s "
"is unexpected",
LogName(*op), LogName(*stitch_op));
return false;
return ::tensorflow::Status::OK();
}
gather_ops.push_back(static_cast<GatherOperator*>(op));
}
@ -132,7 +135,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
"Skipping because data input %s into "
"%s is unexpected",
LogName(*op), LogName(*gather_op));
return false;
return ::tensorflow::Status::OK();
}
if (!data_partition_op) {
data_partition_op = static_cast<DynamicPartitionOperator*>(op);
@ -143,7 +146,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
"Skipping because data input %s into "
"%s is from a different source op than others",
LogName(*op), LogName(*gather_op));
return false;
return ::tensorflow::Status::OK();
}
}
}
@ -236,7 +239,8 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
DeleteOpAndArraysIfUnused(model, indices_partition_op);
DeleteOpAndArraysIfUnused(model, data_partition_op);
DeleteOpAndArraysIfUnused(model, stitch_op);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco

View File

@ -36,10 +36,12 @@ namespace toco {
// slice_c = tf.matmul(slice_a, slice_b)
// result_slices[bat] = slice_c
// result = tf.stack(result_slices)
bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) {
::tensorflow::Status UnrollBatchMatMul::Run(Model* model, std::size_t op_index,
bool* modified) {
*modified = false;
auto batch_op_it = model->operators.begin() + op_index;
if (batch_op_it->get()->type != OperatorType::kBatchMatMul) {
return false;
return ::tensorflow::Status::OK();
}
const auto* batch_op =
static_cast<const BatchMatMulOperator*>(batch_op_it->get());
@ -47,7 +49,8 @@ bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) {
// We must have the shape of at least one input to know our batch size.
const auto& input_array_a = model->GetArray(batch_op->inputs[0]);
const auto& input_array_b = model->GetArray(batch_op->inputs[1]);
if (!input_array_a.has_shape() || !input_array_b.has_shape()) return false;
if (!input_array_a.has_shape() || !input_array_b.has_shape())
return ::tensorflow::Status::OK();
// We only support the rank 3 case. If you are batching on rank > 3 you'll
// have to figure that out.
@ -66,7 +69,8 @@ bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) {
batch_op_it = matmul_op_it + 1;
CHECK_EQ(batch_op_it->get(), batch_op);
model->operators.erase(batch_op_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
CHECK_EQ(input_array_a.shape().dimensions_count(), 3)
<< "Input arrays must have rank 3";
@ -167,7 +171,8 @@ bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) {
CHECK(batch_op_it != model->operators.end());
CHECK(batch_op_it->get() == batch_op);
model->operators.erase(batch_op_it);
return true;
*modified = true;
return ::tensorflow::Status::OK();
}
} // namespace toco