Merge pull request #26326 from tensorflow/2.0-ff
2.0 ff: Move r2.0 branch ahead to pick up test and build fixes.
commit bdecee4c43
@ -39,14 +39,19 @@ filegroup(
"python_api.h",
"*test*",
],
),
) + [
"//tensorflow/cc:srcs",
"//tensorflow/core/distributed_runtime:server_lib.h",
],
visibility = ["//visibility:public"],
)

tf_cuda_library(
name = "c_api_internal",
srcs = ["c_api.h"],
hdrs = ["c_api_internal.h"],
hdrs = [
"c_api.h",
"c_api_internal.h",
],
visibility = [
"//tensorflow:internal",
"//tensorflow/c:__subpackages__",
@ -68,7 +73,9 @@ tf_cuda_library(

tf_cuda_library(
name = "c_api",
hdrs = ["c_api.h"],
hdrs = [
"c_api.h",
],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
@ -89,9 +96,7 @@ tf_cuda_library(
"c_api.cc",
"c_api_function.cc",
],
hdrs = [
"c_api.h",
],
hdrs = ["c_api.h"],
copts = tf_copts(),
visibility = ["//tensorflow/c:__subpackages__"],
deps = [":c_api_internal"] + select({

@ -8,6 +8,19 @@ package(

licenses(["notice"]) # Apache 2.0

filegroup(
name = "srcs",
srcs = [
"framework/gradients.h",
"framework/ops.h",
"framework/scope.h",
"framework/scope_internal.h",
"ops/array_ops.h",
"ops/while_loop.h",
"//tensorflow/cc/saved_model:loader.h",
],
)

load(
"//tensorflow:tensorflow.bzl",
"cc_library_with_android_deps",
@ -606,16 +619,13 @@ tf_gen_op_wrappers_cc(
visibility = ["//tensorflow:internal"],
)

cc_library_with_android_deps(
cc_library(
name = "cc_op_gen_main",
srcs = [
"framework/cc_op_gen.cc",
"framework/cc_op_gen.h",
"framework/cc_op_gen_main.cc",
],
android_deps = [
"//tensorflow/core:android_tensorflow_lib",
],
copts = tf_copts(),
data = [
"//tensorflow/core/api_def:base_api_def",

@ -1432,6 +1432,30 @@ Status CheckInputsWeights(
return Status::OK();
}

Status AllowDataTypes(const OpConverterParams& params,
const std::set<DataType>& allowed_dtypes) {
const auto& node_def = params.node_def;
TFAttrs attrs(params.node_def);
if (attrs.count("T")) {
const auto op_dtype = attrs.get<DataType>("T");
if (!allowed_dtypes.count(op_dtype)) {
// Build string list of allowed types.
std::ostringstream ss;
for (auto it = allowed_dtypes.begin(); it != allowed_dtypes.end(); ++it) {
if (it != allowed_dtypes.begin()) ss << ", ";
ss << DataTypeString(*it);
}
return errors::Unimplemented("Data type ", DataTypeString(op_dtype),
" is not supported for ", node_def.op(),
", must be one of [", ss.str(), "], at ",
node_def.name());
}
}
// If there is no T attribute, we can't determine the type of the op. We will
// allow it to convert for now.
return Status::OK();
}

TRT_ShapedWeights ConvertFP32ToFP16(TrtWeightStore* store,
const TRT_ShapedWeights& weights_src) {
auto dtype_new = DataType::DT_HALF;
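Note: the new AllowDataTypes helper above is invoked by essentially every converter touched in the hunks that follow. As a rough sketch of the calling pattern (a hypothetical converter shown only for illustration; the helper names and the {name, is_weight} input spec follow the diff, everything else is assumed):

Status ConvertMyOp(OpConverterParams* params) {
  // Expect one runtime tensor and one constant weight, mirroring the
  // CheckInputsWeights(...) calls added in this commit.
  TF_RETURN_IF_ERROR(
      CheckInputsWeights(*params, {{"input", false}, {"axis", true}}));
  // Reject unsupported dtypes up front; on failure this returns the
  // errors::Unimplemented message built by AllowDataTypes above.
  TF_RETURN_IF_ERROR(
      AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
  if (params->validation_only) return Status::OK();
  // ... actual TensorRT network construction would go here ...
  return Status::OK();
}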
@ -1734,6 +1758,8 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
|
||||
CheckInputsWeights(*params, {{"input", false}, {"filter", true}}));
|
||||
tensor = inputs.at(0).tensor();
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
|
||||
TRT_ShapedWeights weights_rsck = inputs.at(1).weights();
|
||||
if (weights_rsck.shape_.nbDims != 4) {
|
||||
return errors::InvalidArgument("Conv2D expects kernel of dimension 4, at " +
|
||||
@ -1996,6 +2022,8 @@ Status ConvertTranspose(OpConverterParams* params) {
|
||||
const auto& inputs = params->inputs;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CheckInputsWeights(*params, {{"x", false}, {"perm", true}}));
|
||||
TF_RETURN_IF_ERROR(AllowDataTypes(
|
||||
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
|
||||
// Get the permutation from weights.
|
||||
TRT_ShapedWeights weights = inputs.at(1).weights();
|
||||
const int* weights_ptr =
|
||||
@ -2030,6 +2058,8 @@ Status ConvertReshape(OpConverterParams* params) {
|
||||
const auto& node_def = params->node_def;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CheckInputsWeights(*params, {{"tensor", false}, {"shape", true}}));
|
||||
TF_RETURN_IF_ERROR(AllowDataTypes(
|
||||
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
|
||||
TRT_TensorOrWeights input_tensor = inputs.at(0);
|
||||
TRT_ShapedWeights weights = inputs.at(1).weights();
|
||||
if (weights.count() == 0) {
|
||||
@ -2127,6 +2157,8 @@ Status ConvertExpandDims(OpConverterParams* params) {
|
||||
const auto& node_def = params->node_def;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CheckInputsWeights(*params, {{"input", false}, {"axis", true}}));
|
||||
TF_RETURN_IF_ERROR(AllowDataTypes(
|
||||
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
|
||||
// Get input shape as vector.
|
||||
TRT_TensorOrWeights input_tensor = inputs.at(0);
|
||||
const nvinfer1::Dims dims = input_tensor.GetTrtDims();
|
||||
@ -2177,6 +2209,8 @@ Status ConvertSqueeze(OpConverterParams* params) {
|
||||
const auto& inputs = params->inputs;
|
||||
const auto& node_def = params->node_def;
|
||||
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
|
||||
TF_RETURN_IF_ERROR(AllowDataTypes(
|
||||
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
|
||||
// Get input shape.
|
||||
TRT_TensorOrWeights input_tensor = inputs.at(0);
|
||||
const nvinfer1::Dims dims = input_tensor.GetTrtDims();
|
||||
@ -2439,6 +2473,8 @@ Status ConvertSlice(OpConverterParams* params) {
|
||||
const auto& node_def = params->node_def;
|
||||
TF_RETURN_IF_ERROR(CheckInputsWeights(
|
||||
*params, {{"input", false}, {"begin", true}, {"size", true}}));
|
||||
TF_RETURN_IF_ERROR(AllowDataTypes(
|
||||
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
|
||||
std::vector<int> begin = inputs.at(1).weights().ToVector<int>();
|
||||
std::vector<int> size = inputs.at(2).weights().ToVector<int>();
|
||||
// Get input dims.
|
||||
@ -2483,6 +2519,8 @@ Status ConvertStridedSlice(OpConverterParams* params) {
|
||||
TF_RETURN_IF_ERROR(CheckInputsWeights(
|
||||
*params,
|
||||
{{"input", false}, {"begin", true}, {"end", true}, {"strides", true}}));
|
||||
TF_RETURN_IF_ERROR(AllowDataTypes(
|
||||
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
|
||||
// Get input dims.
|
||||
nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
|
||||
std::vector<int> input_dims(dims.d, dims.d + dims.nbDims);
|
||||
@ -2578,6 +2616,8 @@ Status ConvertPool(OpConverterParams* params) {
|
||||
const auto& inputs = params->inputs;
|
||||
const auto& node_def = params->node_def;
|
||||
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
|
||||
TF_RETURN_IF_ERROR(
|
||||
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
|
||||
nvinfer1::PoolingType type;
|
||||
if (node_def.op() == "MaxPool") {
|
||||
type = nvinfer1::PoolingType::kMAX;
|
||||
@ -2669,15 +2709,10 @@ Status ConvertPool(OpConverterParams* params) {
|
||||
Status ConvertLeakyRelu(OpConverterParams* params) {
|
||||
const auto& inputs = params->inputs;
|
||||
const auto& node_def = params->node_def;
|
||||
if (inputs.size() != 1) {
|
||||
return errors::InvalidArgument(node_def.op(), " expects one input, at ",
|
||||
node_def.name());
|
||||
}
|
||||
if (!inputs.at(0).is_tensor()) {
|
||||
return errors::Unimplemented(node_def.op(),
|
||||
" is only implemented for tensors, at ",
|
||||
node_def.name());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
|
||||
TF_RETURN_IF_ERROR(
|
||||
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
|
||||
|
||||
TFAttrs attrs(node_def);
|
||||
const float alpha = attrs.get<float>("alpha");
|
||||
if (alpha < 0.0f || alpha > 1.0f) {
|
||||
@ -2719,6 +2754,8 @@ Status ConvertActivation(OpConverterParams* params) {
|
||||
const auto& inputs = params->inputs;
|
||||
const auto& node_def = params->node_def;
|
||||
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
|
||||
TF_RETURN_IF_ERROR(
|
||||
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
|
||||
static const std::unordered_map<string, nvinfer1::ActivationType> ops{
|
||||
{"Relu", nvinfer1::ActivationType::kRELU},
|
||||
{"Sigmoid", nvinfer1::ActivationType::kSIGMOID},
|
||||
@ -2815,6 +2852,8 @@ Status ConvertRelu6(OpConverterParams* params) {
|
||||
const auto& inputs = params->inputs;
|
||||
const auto& node_def = params->node_def;
|
||||
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
|
||||
TF_RETURN_IF_ERROR(
|
||||
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
|
||||
if (params->validation_only) return Status::OK();
|
||||
// ***************************************************************************
|
||||
// TensorRT does not implement Relu6 natively. This function converts Relu6 op
|
||||
@ -2864,18 +2903,14 @@ Status ConvertBiasAdd(OpConverterParams* params) {
|
||||
const auto& node_def = params->node_def;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CheckInputsWeights(*params, {{"value", false}, {"bias", true}}));
|
||||
TFAttrs attrs(node_def);
|
||||
DataType tf_dtype = attrs.get<DataType>("T");
|
||||
if (tf_dtype != DataType::DT_FLOAT && tf_dtype != DataType::DT_HALF) {
|
||||
return errors::Unimplemented("Data type is not supported, for node ",
|
||||
node_def.name(), " got ",
|
||||
DataTypeString(tf_dtype));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
|
||||
if (params->validation_only) return Status::OK();
|
||||
|
||||
nvinfer1::ITensor* tensor =
|
||||
const_cast<nvinfer1::ITensor*>(inputs.at(0).tensor());
|
||||
const nvinfer1::Dims original_dims = tensor->getDimensions();
|
||||
TFAttrs attrs(node_def);
|
||||
const string data_format = attrs.get<string>("data_format");
|
||||
const int channel_index =
|
||||
(data_format == "NHWC" ? original_dims.nbDims - 1 : 0);
|
||||
@ -3092,6 +3127,8 @@ Status ConvertBinary(OpConverterParams* params) {
|
||||
" inputs but expected 2, at ",
|
||||
node_def.name());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
|
||||
|
||||
// Constant folding should have been done by TensorFlow
|
||||
if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) {
|
||||
@ -3133,6 +3170,8 @@ Status ConvertRsqrt(OpConverterParams* params) {
|
||||
const auto& inputs = params->inputs;
|
||||
const auto& node_def = params->node_def;
|
||||
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}}));
|
||||
TF_RETURN_IF_ERROR(
|
||||
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
|
||||
if (params->validation_only) return Status::OK();
|
||||
|
||||
// TODO(tmorris): params->converter is null during validation. Allow
|
||||
@ -3199,6 +3238,8 @@ Status ConvertUnary(OpConverterParams* params) {
|
||||
const auto& inputs = params->inputs;
|
||||
const auto& node_def = params->node_def;
|
||||
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}}));
|
||||
TF_RETURN_IF_ERROR(
|
||||
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
|
||||
auto op_pair = UnaryOperationMap()->find(node_def.op());
|
||||
if (op_pair == UnaryOperationMap()->end()) {
|
||||
return errors::Unimplemented("Unary op: ", node_def.op(),
|
||||
@ -3236,27 +3277,21 @@ Status ConvertSquare(OpConverterParams* params) {
|
||||
const auto& inputs = params->inputs;
|
||||
const auto& node_def = params->node_def;
|
||||
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}}));
|
||||
TF_RETURN_IF_ERROR(
|
||||
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
|
||||
if (params->validation_only) return Status::OK();
|
||||
|
||||
// Constant 2 with same rank as input
|
||||
nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
|
||||
for (int i = 0; i < dims.nbDims; i++) {
|
||||
dims.d[i] = 1;
|
||||
}
|
||||
TRT_ShapedWeights weights =
|
||||
params->weight_store->GetTempWeights(DataType::DT_FLOAT, dims);
|
||||
auto weights_ptr =
|
||||
static_cast<float*>(const_cast<void*>(weights.GetValues()));
|
||||
weights_ptr[0] = 2.f;
|
||||
nvinfer1::ITensor* const2_tensor =
|
||||
params->converter->CreateConstantLayer(weights, dims);
|
||||
TFTRT_RETURN_ERROR_IF_NULLPTR(const2_tensor, node_def.name());
|
||||
const nvinfer1::ITensor* const2_tensor = nullptr;
|
||||
TF_RETURN_IF_ERROR(CreateBroadcastableScalarConstant(
|
||||
params, 2.0f, inputs.at(0).GetTrtDims(), &const2_tensor));
|
||||
|
||||
// ElementWise Pow Operation
|
||||
nvinfer1::IElementWiseLayer* layer =
|
||||
params->converter->network()->addElementWise(
|
||||
*const_cast<nvinfer1::ITensor*>(inputs.at(0).tensor()),
|
||||
*const2_tensor, nvinfer1::ElementWiseOperation::kPOW);
|
||||
*const_cast<nvinfer1::ITensor*>(const2_tensor),
|
||||
nvinfer1::ElementWiseOperation::kPOW);
|
||||
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
|
||||
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
|
||||
|
||||
@ -3269,6 +3304,8 @@ Status ConvertReduce(OpConverterParams* params) {
|
||||
const auto& node_def = params->node_def;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CheckInputsWeights(*params, {{"input", false}, {"axis", true}}));
|
||||
TF_RETURN_IF_ERROR(
|
||||
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
|
||||
|
||||
const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
|
||||
TRT_ShapedWeights index_list = inputs.at(1).weights();
|
||||
@ -3330,6 +3367,8 @@ Status ConvertPad(OpConverterParams* params) {
|
||||
const auto& node_def = params->node_def;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CheckInputsWeights(*params, {{"tensor", false}, {"paddings", true}}));
|
||||
TF_RETURN_IF_ERROR(
|
||||
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
|
||||
|
||||
// Implement tensor binaryOp weight [channel wise] for now;
|
||||
const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
|
||||
@ -3432,6 +3471,10 @@ Status ConvertPad(OpConverterParams* params) {
|
||||
Status ConvertConcat(OpConverterParams* params) {
|
||||
const auto& inputs = params->inputs;
|
||||
const auto& node_def = params->node_def;
|
||||
// TODO(tmorris): There is a bug with Concat and INT32 in TRT - it is supposed
|
||||
// to be supported.
|
||||
TF_RETURN_IF_ERROR(
|
||||
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
|
||||
// not including the last input (axis) here
|
||||
int input_size = static_cast<int>(inputs.size()) - 1;
|
||||
|
||||
@ -3514,6 +3557,8 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) {
|
||||
{"offset", true},
|
||||
{"mean", true},
|
||||
{"variance", true}}));
|
||||
TF_RETURN_IF_ERROR(
|
||||
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
|
||||
TFAttrs attrs(node_def);
|
||||
float epsilon = attrs.get<float>("epsilon");
|
||||
auto data_format = attrs.get<string>("data_format");
|
||||
@ -3645,6 +3690,8 @@ Status ConvertGather(OpConverterParams* params) {
|
||||
const auto& node_def = params->node_def;
|
||||
TF_RETURN_IF_ERROR(CheckInputsWeights(
|
||||
*params, {{"params", false}, {"indices", false}, {"axis", true}}));
|
||||
TF_RETURN_IF_ERROR(AllowDataTypes(
|
||||
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
|
||||
absl::Span<const int> axis = inputs.at(2).weights().GetSpan<int>();
|
||||
if (axis.size() != 1) {
|
||||
return errors::InvalidArgument("Axis for GatherV2 must be a scalar, at ",
|
||||
@ -3714,14 +3761,10 @@ Status ConvertMatMul(OpConverterParams* params) {
|
||||
const auto& inputs = params->inputs;
|
||||
const auto& node_def = params->node_def;
|
||||
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"a", false}, {"b", true}}));
|
||||
TF_RETURN_IF_ERROR(
|
||||
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
|
||||
|
||||
TFAttrs attrs(node_def);
|
||||
DataType tf_dtype = attrs.get<DataType>("T");
|
||||
if (tf_dtype != DataType::DT_FLOAT && tf_dtype != DataType::DT_HALF) {
|
||||
return errors::Unimplemented("Data type is not supported, for node ",
|
||||
node_def.name(), " got ",
|
||||
DataTypeString(tf_dtype));
|
||||
}
|
||||
bool transpose_a = attrs.get<bool>("transpose_a");
|
||||
bool transpose_b = attrs.get<bool>("transpose_b");
|
||||
|
||||
@ -3742,6 +3785,8 @@ Status ConvertBatchMatMul(OpConverterParams* params) {
|
||||
// TODO(tmorris): Enable once false is updated to mean either tensor or weight
|
||||
// TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y",
|
||||
// false}}));
|
||||
TF_RETURN_IF_ERROR(
|
||||
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
|
||||
if (inputs.size() != 2) {
|
||||
return errors::InvalidArgument(node_def.op(), " got ", inputs.size(),
|
||||
" inputs but expected 2, at ",
|
||||
@ -3752,14 +3797,6 @@ Status ConvertBatchMatMul(OpConverterParams* params) {
|
||||
"All inputs are weights, but Grappler is expected to fold them.");
|
||||
}
|
||||
TFAttrs attrs(node_def);
|
||||
|
||||
const DataType tf_dtype = attrs.get<DataType>("T");
|
||||
if (tf_dtype != DataType::DT_FLOAT && tf_dtype != DataType::DT_HALF) {
|
||||
return errors::Unimplemented("data type is not supported, for node ",
|
||||
node_def.name(),
|
||||
" got " + DataTypeString(tf_dtype));
|
||||
}
|
||||
|
||||
const bool transpose_a = attrs.get<bool>("adj_x");
|
||||
const bool transpose_b = attrs.get<bool>("adj_y");
|
||||
const auto dims = inputs.at(0).GetTrtDims();
|
||||
@ -3815,6 +3852,8 @@ Status ConvertSoftmax(OpConverterParams* params) {
|
||||
const auto& inputs = params->inputs;
|
||||
const auto& node_def = params->node_def;
|
||||
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"logits", false}}));
|
||||
TF_RETURN_IF_ERROR(
|
||||
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
|
||||
const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
|
||||
|
||||
int nbDims = tensor->getDimensions().nbDims;
|
||||
@ -3840,15 +3879,11 @@ Status ConvertSoftmax(OpConverterParams* params) {
|
||||
|
||||
Status ConvertTopK(OpConverterParams* params) {
|
||||
const auto& inputs = params->inputs;
|
||||
if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
|
||||
!inputs.at(1).is_weights()) {
|
||||
return errors::InvalidArgument("Input expects tensor and weights, at ",
|
||||
params->node_def.name());
|
||||
}
|
||||
|
||||
const auto& node_def = params->node_def;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CheckInputsWeights(*params, {{"input", false}, {"k", true}}));
|
||||
TF_RETURN_IF_ERROR(AllowDataTypes(
|
||||
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
|
||||
const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
|
||||
const int num_dims = tensor->getDimensions().nbDims;
|
||||
if (num_dims == 0) {
|
||||
|
@ -1571,9 +1571,9 @@ TEST_F(OpConverterTest, ConvertMatMul) {
|
||||
NodeDef node_def = get_matmul_nodedef(DT_INT32, false, false);
|
||||
AddTestTensor("input", {2}, /*batch_size=*/1, nvinfer1::DataType::kINT32);
|
||||
AddTestWeights<int32>("weights", {2, 1}, {3, 5});
|
||||
RunValidationAndConversion(
|
||||
node_def, error::UNIMPLEMENTED,
|
||||
"Data type is not supported, for node my_matmul got int32");
|
||||
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
|
||||
"Data type int32 is not supported for MatMul, "
|
||||
"must be one of [float, half], at my_matmul");
|
||||
}
|
||||
// transpose_a is set.
|
||||
for (bool transpose_b : {false, true}) {
|
||||
@ -3328,8 +3328,9 @@ TEST_F(OpConverterTest, ConvertTopK) {
|
||||
{
|
||||
// Input list is empty, should fail.
|
||||
NodeDef node_def = MakeNodeDef("my_topk", "TopKV2", {});
|
||||
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
|
||||
"Input expects tensor and weights, at my_topk");
|
||||
RunValidationAndConversion(
|
||||
node_def, error::INVALID_ARGUMENT,
|
||||
"TopKV2 got 0 inputs but expected 2, at my_topk");
|
||||
}
|
||||
|
||||
for (const auto dtype : {DT_FLOAT, DT_INT32}) {
|
||||
@ -3346,8 +3347,8 @@ TEST_F(OpConverterTest, ConvertTopK) {
|
||||
/*trt_dtype=*/TfDataTypeToTrt(dtype));
|
||||
AddTestTensor("weights", {2});
|
||||
RunValidationAndConversion(
|
||||
node_def, error::INVALID_ARGUMENT,
|
||||
"Input expects tensor and weights, at my_topk");
|
||||
node_def, error::UNIMPLEMENTED,
|
||||
"The input \"k\" for TopKV2 must be a constant, at my_topk");
|
||||
}
|
||||
{
|
||||
// Ok.
|
||||
|
@ -79,7 +79,10 @@ static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
|
||||
XLA_MAKE_BINARY(DivNoNan,
|
||||
DivNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper));
|
||||
|
||||
// Implementation of FloorDiv. Pseudo-code:
|
||||
// Implementation of FloorDiv.
|
||||
//
|
||||
// For floating-point values, simply returns floor(x / y). For integers, does:
|
||||
//
|
||||
// if ((x < 0) != (y < 0)) {
|
||||
// T abs_x = std::abs(x);
|
||||
// T abs_y = std::abs(y);
|
||||
@ -90,6 +93,9 @@ XLA_MAKE_BINARY(DivNoNan,
|
||||
static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
|
||||
xla::XlaOp y, const BCast& broadcast_helper) {
|
||||
std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
|
||||
if (DataTypeIsFloating(dtype)) {
|
||||
return xla::Floor(xla::Div(x, y));
|
||||
}
|
||||
if (DataTypeIsUnsigned(dtype)) {
|
||||
return xla::Div(x, y);
|
||||
}
|
||||
@ -99,11 +105,7 @@ static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
|
||||
auto abs_x = xla::Abs(x);
|
||||
auto abs_y = xla::Abs(y);
|
||||
auto t = xla::Neg(xla::Sub(xla::Add(abs_x, abs_y), one));
|
||||
auto result = xla::Select(different_sign, xla::Div(t, abs_y), xla::Div(x, y));
|
||||
if (DataTypeIsFloating(dtype)) {
|
||||
result = xla::Floor(result);
|
||||
}
|
||||
return result;
|
||||
return xla::Select(different_sign, xla::Div(t, abs_y), xla::Div(x, y));
|
||||
}
|
||||
XLA_MAKE_BINARY(FloorDiv,
|
||||
FloorDivImpl(b, input_type(0), lhs, rhs, broadcast_helper));
|
||||
|
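As a quick sanity check on the integer branch of FloorDivImpl above, the same trick in plain C++ (illustration only, not part of the change): when the operand signs differ, t = -(|x| + |y| - 1) followed by truncating division by |y| lands exactly on floor(x / y).

#include <cassert>
#include <cstdlib>

// Mirrors FloorDivImpl's integer path with ordinary ints.
int floor_div(int x, int y) {
  if ((x < 0) != (y < 0)) {
    int t = -(std::abs(x) + std::abs(y) - 1);
    return t / std::abs(y);  // truncating division now rounds toward -infinity
  }
  return x / y;  // same sign: truncation already equals the floor
}

int main() {
  assert(floor_div(-7, 2) == -4);  // floor(-3.5)
  assert(floor_div(7, 2) == 3);
  assert(floor_div(5, -2) == -3);  // floor(-2.5)
  assert(floor_div(-6, 3) == -2);
  return 0;
}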
@ -43,7 +43,9 @@ namespace {
|
||||
// A * H or H * A zeros out trailing part of some row or column of A.
|
||||
//
|
||||
// [x0, ..., x_{k-1}, xk, x_{k+1}, ..., x_{n-1}] * H
|
||||
// = [x0, ..., x_{k-1}, vnorm, 0, ..., 0]
|
||||
// = [x0, ..., x_{k-1}, xnorm, 0, ..., 0]
|
||||
//
|
||||
// Here xnorm = norm([x_k, x_{k+1}, ..., x_{n - 1}])
|
||||
struct HouseHolderResult {
|
||||
XlaOp v;
|
||||
XlaOp beta;
|
||||
@ -82,7 +84,7 @@ struct FrobeniusNorms {
|
||||
//
|
||||
// H = I - beta * [1, v]' * [1, v]
|
||||
//
|
||||
// H * x = [..., sigma, 0, ..., 0]
|
||||
// H * x = [..., xnorm, 0, ..., 0]
|
||||
// ..., j, j + 1, ..., n
|
||||
//
|
||||
// def house(x, j, eps):
|
||||
@ -161,21 +163,10 @@ StatusOr<HouseHolderResult> HouseRow(XlaOp a, XlaOp i, XlaOp j, XlaOp eps,
|
||||
HouseHolderResult result;
|
||||
result.v = v;
|
||||
result.beta = beta;
|
||||
a = Sub(a, Mul(beta, BatchDot(BatchDot(a, TransposeInMinorDims(v), precision),
|
||||
result.a =
|
||||
Sub(a, Mul(beta, BatchDot(BatchDot(a, TransposeInMinorDims(v), precision),
|
||||
v, precision)));
|
||||
|
||||
auto xnorm =
|
||||
Sqrt(Reduce(Square(Select(Ge(idx, j), x, zeros)), ScalarLike(x, 0.0),
|
||||
CreateScalarAddComputation(x_shape.element_type(), builder),
|
||||
{num_dims - 1}));
|
||||
|
||||
xnorm = BroadcastInDim(xnorm, x_shape.dimensions(), broadcast_dims);
|
||||
|
||||
x = Select(Lt(idx, j), x, zeros);
|
||||
x = Select(Eq(idx, j), xnorm, x);
|
||||
|
||||
result.a = DynamicUpdateSliceInMinorDims(a, x, {i, zero});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -184,7 +175,7 @@ StatusOr<HouseHolderResult> HouseRow(XlaOp a, XlaOp i, XlaOp j, XlaOp eps,
|
||||
//
|
||||
// H = I - beta * [1; v] * [1; v]', then,
|
||||
//
|
||||
// H * A[i:, j] = [sigma, 0, 0, ..., 0]
|
||||
// H * A[i:, j] = [xnorm, 0, 0, ..., 0]
|
||||
//
|
||||
StatusOr<HouseHolderResult> HouseCol(XlaOp a, XlaOp i, XlaOp j, XlaOp eps,
|
||||
PrecisionConfig::Precision precision) {
|
||||
@ -239,21 +230,9 @@ StatusOr<HouseHolderResult> HouseCol(XlaOp a, XlaOp i, XlaOp j, XlaOp eps,
|
||||
HouseHolderResult result;
|
||||
result.v = v;
|
||||
result.beta = beta;
|
||||
a = Sub(a,
|
||||
Mul(beta, BatchDot(v, BatchDot(TransposeInMinorDims(v), a, precision),
|
||||
precision)));
|
||||
|
||||
auto xnorm =
|
||||
Sqrt(Reduce(Square(Select(Ge(idx, i), x, zeros)), ScalarLike(x, 0.0),
|
||||
CreateScalarAddComputation(x_shape.element_type(), builder),
|
||||
{num_dims - 2}));
|
||||
|
||||
xnorm = BroadcastInDim(xnorm, x_shape.dimensions(), broadcast_dims);
|
||||
|
||||
x = Select(Lt(idx, i), x, zeros);
|
||||
x = Select(Eq(idx, i), xnorm, x);
|
||||
|
||||
result.a = DynamicUpdateSliceInMinorDims(a, x, {zero, j});
|
||||
result.a = Sub(
|
||||
a, Mul(beta, BatchDot(v, BatchDot(TransposeInMinorDims(v), a, precision),
|
||||
precision)));
|
||||
|
||||
return result;
|
||||
}
|
||||
@ -774,7 +753,7 @@ StatusOr<SVDResult> SortBySingularValuesAndPostProcessing(SVDResult result) {
|
||||
auto zero = Zero(builder, S32);
|
||||
|
||||
// As m >= n, only first m columns vectors are needed to be permuted, and the
|
||||
// rest of n - m vectors are appended after the sorting is done.
|
||||
// rest of m - n vectors are appended after the sorting is done.
|
||||
XlaOp sort_u_result =
|
||||
Sort({-d, DynamicSliceInMinorDims(result.u, {zero, zero}, {m, n})},
|
||||
CreateScalarLtComputation(
|
||||
@ -799,7 +778,7 @@ StatusOr<SVDResult> SortBySingularValuesAndPostProcessing(SVDResult result) {
|
||||
{num_dims - 2})),
|
||||
broadcast_dims);
|
||||
|
||||
// Append the rest of n - m vectors.
|
||||
// Append the rest of m - n vectors.
|
||||
result.u =
|
||||
ConcatInDim(builder,
|
||||
{GetTupleElement(sort_u_result, 1),
|
||||
|
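For reference, the identity behind the HouseRow/HouseCol comments edited above (standard Householder reflection, restated here; not part of the diff): with beta and v chosen as in house(x, j, eps),

H = I - \beta \begin{bmatrix} 1 \\ v \end{bmatrix} \begin{bmatrix} 1 & v^\top \end{bmatrix},
\qquad
x^\top H = [x_0, \ldots, x_{j-1}, \; \mathrm{xnorm}, \; 0, \ldots, 0],
\qquad
\mathrm{xnorm} = \lVert [x_j, \ldots, x_{n-1}] \rVert_2,

so the reflection zeroes the trailing entries of the targeted row (or column, for the H * A form) while preserving its norm, which is why the comments now name the surviving entry xnorm rather than vnorm or sigma.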
@ -131,6 +131,24 @@ class SVDTest : public ClientLibraryTestBase {
|
||||
Array3D<float> batch_3d_4x5_;
|
||||
};
|
||||
|
||||
XLA_TEST_F(SVDTest, Simple2D) {
|
||||
XlaBuilder builder(TestName());
|
||||
|
||||
Array2D<float> simple_2d_4x4_ = Array2D<float>{
|
||||
{4, 6, 8, 10},
|
||||
{6, 45, 54, 63},
|
||||
{8, 54, 146, 166},
|
||||
{10, 63, 166, 310},
|
||||
};
|
||||
XlaOp a;
|
||||
auto a_data = CreateR2Parameter<float>(simple_2d_4x4_, 0, "a", &builder, &a);
|
||||
auto result = SVD(a, 100, 1e-6);
|
||||
ComputeMatmulUDVT(result, &builder);
|
||||
|
||||
ComputeAndCompareR2<float>(&builder, simple_2d_4x4_, {a_data.get()},
|
||||
ErrorSpec(1e-3, 1e-3));
|
||||
}
|
||||
|
||||
XLA_TEST_F(SVDTest, Test_VWVt_EQ_A_2x4x5) {
|
||||
XlaBuilder builder(TestName());
|
||||
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_PROTOBUF_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_PROTOBUF_UTIL_H_

#include "google/protobuf/duration.pb.h"
#include "absl/time/time.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@ -44,6 +45,20 @@ Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
// dirpath along as-is.
void RegisterDirectoryExpander(const std::function<string(string)>& expander);

// Converts an absl::Duration to a google::protobuf::Duration.
inline google::protobuf::Duration ToDurationProto(absl::Duration duration) {
google::protobuf::Duration proto;
proto.set_seconds(absl::IDivDuration(duration, absl::Seconds(1), &duration));
proto.set_nanos(
absl::IDivDuration(duration, absl::Nanoseconds(1), &duration));
return proto;
}

// Converts a google::protobuf::Duration to an absl::Duration.
inline absl::Duration FromDurationProto(google::protobuf::Duration proto) {
return absl::Seconds(proto.seconds()) + absl::Nanoseconds(proto.nanos());
}

} // namespace protobuf_util
} // namespace xla
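A small usage sketch for the two helpers just added to protobuf_util.h (hypothetical call site, shown only to illustrate the split into seconds and nanos):

// absl::Milliseconds(1500) -> {seconds: 1, nanos: 500000000} and back.
absl::Duration d = absl::Milliseconds(1500);
google::protobuf::Duration proto = xla::protobuf_util::ToDurationProto(d);
// proto.seconds() == 1, proto.nanos() == 500000000
absl::Duration round_tripped = xla::protobuf_util::FromDurationProto(proto);
// round_tripped == d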
File diff suppressed because it is too large
@ -440,14 +440,15 @@ cc_library(
|
||||
srcs = ["cudnn_conv_algorithm_picker.cc"],
|
||||
hdrs = ["cudnn_conv_algorithm_picker.h"],
|
||||
deps = [
|
||||
":autotuning_proto",
|
||||
":backend_configs",
|
||||
":buffer_comparator",
|
||||
":cudnn_conv_runner",
|
||||
":gpu_autotuning_proto",
|
||||
":gpu_executable",
|
||||
":ir_emission_utils",
|
||||
":scratch_allocator",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:protobuf_util",
|
||||
"//tensorflow/compiler/xla/service:compiler",
|
||||
"//tensorflow/compiler/xla/service:device_memory_allocator",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
@ -455,9 +456,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:logger",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/core/util/proto:proto_utils",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/time",
|
||||
@ -777,7 +776,6 @@ cc_library(
|
||||
hdrs = ["gpu_transfer_manager.h"],
|
||||
deps = [
|
||||
":gpu_compiler",
|
||||
":infeed_manager",
|
||||
":outfeed_manager",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
@ -790,6 +788,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/service:generic_transfer_manager",
|
||||
"//tensorflow/compiler/xla/service:transfer_manager",
|
||||
"//tensorflow/compiler/xla/service/gpu:infeed_manager",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"@com_google_absl//absl/memory",
|
||||
@ -1138,8 +1137,8 @@ tf_cc_test(
|
||||
srcs = ["cudnn_fused_conv_rewriter_test.cc"],
|
||||
tags = tf_cuda_tests_tags(),
|
||||
deps = [
|
||||
":ir_emission_utils",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
|
||||
"//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/core:test",
|
||||
@ -1184,11 +1183,10 @@ tf_cc_test(
|
||||
)
|
||||
|
||||
xla_proto_library(
|
||||
name = "gpu_autotuning_proto",
|
||||
srcs = ["gpu_autotuning.proto"],
|
||||
name = "autotuning_proto",
|
||||
srcs = ["autotuning.proto"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/service:hlo_proto",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
@ -1,14 +1,15 @@
|
||||
// This file defines protos that store the results of autotuning various
|
||||
// This file defines protos that store the results of autotuning XLA:GPU
|
||||
// operations.
|
||||
//
|
||||
// They are in proto format because we want to log them structured. They offer
|
||||
// tremendous statistical, testing, and debugging value.
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow;
|
||||
package xla.gpu;
|
||||
|
||||
import "google/protobuf/any.proto";
|
||||
import "google/protobuf/duration.proto";
|
||||
import "tensorflow/compiler/xla/xla_data.proto";
|
||||
import "tensorflow/compiler/xla/service/hlo.proto";
|
||||
|
||||
message CudnnVersion {
|
||||
int32 major = 1;
|
||||
@ -62,12 +63,19 @@ message AutotuneResult {
|
||||
}
|
||||
}
|
||||
|
||||
message AutotuningLog {
|
||||
google.protobuf.Any instr = 1;
|
||||
message AutotuneLog {
|
||||
message Instruction {
|
||||
xla.HloInstructionProto instruction = 1;
|
||||
repeated xla.ShapeProto operand_shapes = 2;
|
||||
}
|
||||
|
||||
oneof instr_oneof {
|
||||
Instruction instr = 1;
|
||||
}
|
||||
|
||||
// Records all auto-tuning results per algorithm.
|
||||
repeated AutotuneResult results = 2;
|
||||
repeated AutotuneResult results = 3;
|
||||
|
||||
CudnnVersion cudnn_version = 3;
|
||||
ComputeCapability compute_capability = 4;
|
||||
CudnnVersion cudnn_version = 4;
|
||||
ComputeCapability compute_capability = 5;
|
||||
}
|
@ -14,23 +14,21 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h"
|
||||
#include "google/protobuf/any.pb.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/time/time.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/protobuf_util.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_autotuning.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/scratch_allocator.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/platform/logger.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/util/proto/proto_utils.h"
|
||||
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
@ -39,7 +37,6 @@ namespace {
|
||||
using absl::optional;
|
||||
using se::DeviceMemoryBase;
|
||||
using se::dnn::AlgorithmDesc;
|
||||
using tensorflow::AutotuneResult;
|
||||
|
||||
std::vector<AlgorithmDesc> GetAlgorithms(CudnnConvKind kind,
|
||||
se::StreamExecutor* stream_exec) {
|
||||
@ -97,8 +94,8 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) {
|
||||
return tensorflow::mutex_lock{it->second};
|
||||
}
|
||||
|
||||
tensorflow::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) {
|
||||
tensorflow::CudnnVersion cudnn_version;
|
||||
xla::gpu::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) {
|
||||
xla::gpu::CudnnVersion cudnn_version;
|
||||
if (auto* dnn = stream_executor->AsDnn()) {
|
||||
StatusOr<se::dnn::VersionInfo> version_or = dnn->GetVersion();
|
||||
if (version_or.ok()) {
|
||||
@ -111,9 +108,9 @@ tensorflow::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) {
|
||||
return cudnn_version;
|
||||
}
|
||||
|
||||
tensorflow::ComputeCapability GetComputeCapability(
|
||||
xla::gpu::ComputeCapability GetComputeCapability(
|
||||
se::StreamExecutor* stream_executor) {
|
||||
tensorflow::ComputeCapability cc;
|
||||
xla::gpu::ComputeCapability cc;
|
||||
int cc_major, cc_minor;
|
||||
stream_executor->GetDeviceDescription().cuda_compute_capability(&cc_major,
|
||||
&cc_minor);
|
||||
@ -246,23 +243,25 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithm(
|
||||
RunCudnnConv(instr, absl::MakeSpan(operand_buffers), result_buffer,
|
||||
&scratch_allocator, &stream, options);
|
||||
|
||||
if (!launch_status.ok()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!profile_result.is_valid()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
profile_results.emplace_back();
|
||||
AutotuneResult& result = profile_results.back();
|
||||
result.mutable_conv()->set_algorithm(alg.algo_id());
|
||||
result.mutable_conv()->set_tensor_ops_enabled(alg.tensor_ops_enabled());
|
||||
|
||||
if (!launch_status.ok()) {
|
||||
result.set_error_string(launch_status.error_message());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!profile_result.is_valid()) {
|
||||
result.set_error_string("Invalid profile result");
|
||||
continue;
|
||||
}
|
||||
|
||||
int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes();
|
||||
result.mutable_success()->set_scratch_bytes(scratch_bytes_used);
|
||||
*result.mutable_success()->mutable_run_time() =
|
||||
tensorflow::proto_utils::ToDurationProto(
|
||||
protobuf_util::ToDurationProto(
|
||||
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
|
||||
|
||||
const bool crash_on_checking_failure =
|
||||
@ -309,14 +308,10 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithm(
|
||||
|
||||
// Log the autotuning result.
|
||||
{
|
||||
tensorflow::AutotuningLog log;
|
||||
{
|
||||
ConvInstructionLog instr_log;
|
||||
*instr_log.mutable_instruction() = instr->ToProto();
|
||||
for (const auto* op : instr->operands()) {
|
||||
*instr_log.add_operand_shapes() = op->shape().ToProto();
|
||||
}
|
||||
log.mutable_instr()->PackFrom(instr_log);
|
||||
AutotuneLog log;
|
||||
*log.mutable_instr()->mutable_instruction() = instr->ToProto();
|
||||
for (const auto* op : instr->operands()) {
|
||||
*log.mutable_instr()->add_operand_shapes() = op->shape().ToProto();
|
||||
}
|
||||
for (const auto& profile : profile_results) {
|
||||
*log.add_results() = profile;
|
||||
@ -335,14 +330,13 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithm(
|
||||
// The successful one should have a smaller key, since we are doing
|
||||
// min_element. If they are both unsuccessful, keep the earlier one in
|
||||
// the vector by comparing pointers.
|
||||
return std::make_tuple(!lhs.has_success(),
|
||||
tensorflow::proto_utils::FromDurationProto(
|
||||
lhs.success().run_time()),
|
||||
&lhs) <
|
||||
std::make_tuple(!rhs.has_success(),
|
||||
tensorflow::proto_utils::FromDurationProto(
|
||||
rhs.success().run_time()),
|
||||
&rhs);
|
||||
return std::make_tuple(
|
||||
!lhs.has_success(),
|
||||
protobuf_util::FromDurationProto(lhs.success().run_time()),
|
||||
&lhs) < std::make_tuple(!rhs.has_success(),
|
||||
protobuf_util::FromDurationProto(
|
||||
rhs.success().run_time()),
|
||||
&rhs);
|
||||
});
|
||||
|
||||
if (best_result != profile_results_end && best_result->has_success()) {
|
||||
|
@ -20,12 +20,12 @@ limitations under the License.
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/autotuning.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
#include "tensorflow/core/protobuf/autotuning.pb.h"
|
||||
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
@ -50,7 +50,7 @@ class CudnnConvAlgorithmPicker : public HloModulePass {
|
||||
private:
|
||||
StatusOr<bool> RunOnComputation(HloComputation* computation);
|
||||
StatusOr<bool> RunOnInstruction(HloInstruction* instr);
|
||||
StatusOr<tensorflow::AutotuneResult> PickBestAlgorithm(
|
||||
StatusOr<AutotuneResult> PickBestAlgorithm(
|
||||
const HloCustomCallInstruction* instr);
|
||||
|
||||
se::StreamExecutor* stream_exec_; // never null
|
||||
|
@ -1,13 +0,0 @@
|
||||
// This is used for convolution logging. Also see
|
||||
// tensorflow/core/protobuf/autotuing.h
|
||||
syntax = "proto3";
|
||||
|
||||
package xla.gpu;
|
||||
|
||||
import "tensorflow/compiler/xla/service/hlo.proto";
|
||||
import "tensorflow/compiler/xla/xla_data.proto";
|
||||
|
||||
message ConvInstructionLog {
|
||||
xla.HloInstructionProto instruction = 1;
|
||||
repeated xla.ShapeProto operand_shapes = 2;
|
||||
}
|
@ -1742,7 +1742,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
||||
ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
|
||||
dnums.DebugString());
|
||||
}
|
||||
|
||||
if (kernel_output_features % feature_group_count > 0) {
|
||||
// A depthwise/grouped filter has the shape
|
||||
// [space0, .. spaceN, GROUP_SIZE, NUM_OUTPUT_FEATURES]. When
|
||||
// [space0, .. spaceN, GROUP_SIZE] is convolved with the input, a shape
|
||||
// [space0, .. spaceN, feature_group_count] is formed. Therefore, the output
|
||||
// feature count (which is equal to kernel output features) has to be a
|
||||
// multiple of feature_group_count.
|
||||
return InvalidArgument(
|
||||
"Expected output feature dimension (value %d) to be divisible by "
|
||||
"feature_group_count (value %d); "
|
||||
|
@ -629,6 +629,9 @@ BENCHMARK_NAME := $(BINDIR)benchmark
|
||||
|
||||
CORE_CC_ALL_SRCS := \
|
||||
$(ABSL_CC_SRCS) \
|
||||
tensorflow/c/c_api.cc \
|
||||
tensorflow/c/kernels.cc \
|
||||
tensorflow/c/tf_status_helper.cc \
|
||||
$(wildcard tensorflow/core/*.cc) \
|
||||
$(wildcard tensorflow/core/common_runtime/*.cc) \
|
||||
$(wildcard tensorflow/core/framework/*.cc) \
|
||||
|
@ -13,11 +13,9 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for contrib.seq2seq.python.ops.attention_wrapper."""
|
||||
# pylint: disable=unused-import,g-bad-import-order
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
# pylint: enable=unused-import
|
||||
|
||||
import collections
|
||||
import functools
|
||||
@ -30,6 +28,7 @@ from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
|
||||
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.layers import core as layers_core
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
@ -66,6 +65,7 @@ def get_result_summary(x):
|
||||
return x
|
||||
|
||||
|
||||
@test_util.run_v1_only
|
||||
class AttentionWrapperTest(test.TestCase):
|
||||
|
||||
def assertAllCloseOrEqual(self, x, y, **kwargs):
|
||||
|
@ -30,7 +30,6 @@ from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras import initializers
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import nest
|
||||
@ -305,7 +304,10 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
attention_layer_size = attention_layer_size[0]
|
||||
if attention_layer is not None:
|
||||
attention_layer = attention_layer[0]
|
||||
cell = rnn_cell.LSTMCell(cell_depth, initializer="ones")
|
||||
cell = keras.layers.LSTMCell(cell_depth,
|
||||
recurrent_activation="sigmoid",
|
||||
kernel_initializer="ones",
|
||||
recurrent_initializer="ones")
|
||||
cell = wrapper.AttentionWrapper(
|
||||
cell,
|
||||
attention_mechanisms if is_multi else attention_mechanisms[0],
|
||||
@ -321,7 +323,7 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
|
||||
sampler = sampler_py.TrainingSampler()
|
||||
my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler)
|
||||
initial_state = cell.zero_state(
|
||||
initial_state = cell.get_initial_state(
|
||||
dtype=dtypes.float32, batch_size=batch_size)
|
||||
final_outputs, final_state, _ = my_decoder(
|
||||
decoder_inputs,
|
||||
@ -330,7 +332,6 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
|
||||
self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
|
||||
self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
|
||||
self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple)
|
||||
|
||||
expected_time = (
|
||||
expected_final_state.time if context.executing_eagerly() else None)
|
||||
@ -342,9 +343,9 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
self.assertEqual((batch_size, attention_depth),
|
||||
tuple(final_state.attention.get_shape().as_list()))
|
||||
self.assertEqual((batch_size, cell_depth),
|
||||
tuple(final_state.cell_state.c.get_shape().as_list()))
|
||||
tuple(final_state.cell_state[0].get_shape().as_list()))
|
||||
self.assertEqual((batch_size, cell_depth),
|
||||
tuple(final_state.cell_state.h.get_shape().as_list()))
|
||||
tuple(final_state.cell_state[1].get_shape().as_list()))
|
||||
|
||||
if alignment_history:
|
||||
if is_multi:
|
||||
@ -395,8 +396,9 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
expected_final_alignment_history,
|
||||
final_alignment_history_info)
|
||||
|
||||
@parameterized.parameters([np.float16, np.float32, np.float64])
|
||||
def _testBahdanauNormalizedDType(self, dtype):
|
||||
# TODO(b/126893309): reenable np.float16 once the bug is fixed.
|
||||
@parameterized.parameters([np.float32, np.float64])
|
||||
def testBahdanauNormalizedDType(self, dtype):
|
||||
encoder_outputs = self.encoder_outputs.astype(dtype)
|
||||
decoder_inputs = self.decoder_inputs.astype(dtype)
|
||||
attention_mechanism = wrapper.BahdanauAttentionV2(
|
||||
@ -405,7 +407,7 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
memory_sequence_length=self.encoder_sequence_length,
|
||||
normalize=True,
|
||||
dtype=dtype)
|
||||
cell = rnn_cell.LSTMCell(self.units)
|
||||
cell = keras.layers.LSTMCell(self.units, recurrent_activation="sigmoid")
|
||||
cell = wrapper.AttentionWrapper(cell, attention_mechanism)
|
||||
|
||||
sampler = sampler_py.TrainingSampler()
|
||||
@ -418,9 +420,9 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
|
||||
self.assertEqual(final_outputs.rnn_output.dtype, dtype)
|
||||
self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
|
||||
self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple)
|
||||
|
||||
@parameterized.parameters([np.float16, np.float32, np.float64])
|
||||
# TODO(b/126893309): reenable np.float16 once the bug is fixed.
|
||||
@parameterized.parameters([np.float32, np.float64])
|
||||
def testLuongScaledDType(self, dtype):
|
||||
# Test case for GitHub issue 18099
|
||||
encoder_outputs = self.encoder_outputs.astype(dtype)
|
||||
@ -432,7 +434,7 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
scale=True,
|
||||
dtype=dtype,
|
||||
)
|
||||
cell = rnn_cell.LSTMCell(self.units)
|
||||
cell = keras.layers.LSTMCell(self.units, recurrent_activation="sigmoid")
|
||||
cell = wrapper.AttentionWrapper(cell, attention_mechanism)
|
||||
|
||||
sampler = sampler_py.TrainingSampler()
|
||||
@ -445,7 +447,6 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
|
||||
self.assertEqual(final_outputs.rnn_output.dtype, dtype)
|
||||
self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
|
||||
self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple)
|
||||
|
||||
def testBahdanauNotNormalized(self):
|
||||
create_attention_mechanism = wrapper.BahdanauAttentionV2
|
||||
@ -455,11 +456,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
shape=(5, 3, 6), dtype=np.dtype(np.float32), mean=4.8290324),
|
||||
sample_id=ResultSummary(shape=(5, 3), dtype=np.dtype(np.int32), mean=0))
|
||||
expected_final_state = wrapper.AttentionWrapperState(
|
||||
cell_state=rnn_cell.LSTMStateTuple(
|
||||
c=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype(np.float32), mean=1.6432636),
|
||||
h=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype(np.float32), mean=0.75866824)),
|
||||
cell_state=[
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype(np.float32), mean=0.75866824),
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype(np.float32), mean=1.6432636)],
|
||||
attention=ResultSummary(
|
||||
shape=(5, 6), dtype=np.dtype(np.float32), mean=6.7445569),
|
||||
time=3,
|
||||
@ -490,11 +491,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
sample_id=ResultSummary(
|
||||
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
|
||||
expected_final_state = wrapper.AttentionWrapperState(
|
||||
cell_state=rnn_cell.LSTMStateTuple(
|
||||
c=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=1.4652209),
|
||||
h=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.70997983)),
|
||||
cell_state=[
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.70997983),
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=1.4652209)],
|
||||
attention=ResultSummary(
|
||||
shape=(5, 6), dtype=np.dtype("float32"), mean=6.3075728),
|
||||
time=3,
|
||||
@ -520,11 +521,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
sample_id=ResultSummary(
|
||||
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
|
||||
expected_final_state = wrapper.AttentionWrapperState(
|
||||
cell_state=rnn_cell.LSTMStateTuple(
|
||||
c=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.88403547),
|
||||
h=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.37819088)),
|
||||
cell_state=[
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.37819088),
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.88403547)],
|
||||
attention=ResultSummary(
|
||||
shape=(5, 6), dtype=np.dtype("float32"), mean=4.084631),
|
||||
time=3,
|
||||
@ -550,11 +551,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
sample_id=ResultSummary(
|
||||
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
|
||||
expected_final_state = wrapper.AttentionWrapperState(
|
||||
cell_state=rnn_cell.LSTMStateTuple(
|
||||
c=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.88403547),
|
||||
h=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.37819088)),
|
||||
cell_state=[
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.37819088),
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.88403547)],
|
||||
attention=ResultSummary(
|
||||
shape=(5, 6), dtype=np.dtype("float32"), mean=4.0846314),
|
||||
time=3,
|
||||
@ -581,11 +582,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
sample_id=ResultSummary(
|
||||
shape=(5, 3), dtype=np.dtype("int32"), mean=3.86666666))
|
||||
expected_final_state = wrapper.AttentionWrapperState(
|
||||
cell_state=rnn_cell.LSTMStateTuple(
|
||||
c=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=1.032002),
|
||||
h=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.61177742)),
|
||||
cell_state=[
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.61177742),
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=1.032002)],
|
||||
attention=ResultSummary(
|
||||
shape=(5, 10), dtype=np.dtype("float32"), mean=0.011346335),
|
||||
time=3,
|
||||
@ -613,11 +614,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
sample_id=ResultSummary(
|
||||
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
|
||||
expected_final_state = wrapper.AttentionWrapperState(
|
||||
cell_state=rnn_cell.LSTMStateTuple(
|
||||
c=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=1.6752492),
|
||||
h=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.76052248)),
|
||||
cell_state=[
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.76052248),
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=1.6752492)],
|
||||
attention=ResultSummary(
|
||||
shape=(5, 6), dtype=np.dtype("float32"), mean=8.361186),
|
||||
time=3,
|
||||
@ -648,11 +649,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
sample_id=ResultSummary(
|
||||
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
|
||||
expected_final_state = wrapper.AttentionWrapperState(
|
||||
cell_state=rnn_cell.LSTMStateTuple(
|
||||
c=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=1.6005473),
|
||||
h=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.77863038)),
|
||||
cell_state=[
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.77863038),
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=1.6005473)],
|
||||
attention=ResultSummary(
|
||||
shape=(5, 6), dtype=np.dtype("float32"), mean=7.3326721),
|
||||
time=3,
|
||||
@ -682,11 +683,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
sample_id=ResultSummary(
|
||||
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
|
||||
expected_final_state = wrapper.AttentionWrapperState(
|
||||
cell_state=rnn_cell.LSTMStateTuple(
|
||||
c=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=1.072384),
|
||||
h=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.50331038)),
|
||||
cell_state=[
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.50331038),
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=1.072384)],
|
||||
attention=ResultSummary(
|
||||
shape=(5, 6), dtype=np.dtype("float32"), mean=5.3079605),
|
||||
time=3,
|
||||
@ -716,11 +717,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
|
||||
sample_id=ResultSummary(
|
||||
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
|
||||
expected_final_state = wrapper.AttentionWrapperState(
|
||||
cell_state=rnn_cell.LSTMStateTuple(
|
||||
c=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=1.072384),
|
||||
h=ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.50331038)),
|
||||
cell_state=[
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=0.50331038),
|
||||
ResultSummary(
|
||||
shape=(5, 9), dtype=np.dtype("float32"), mean=1.072384)],
|
||||
attention=ResultSummary(
|
||||
shape=(5, 6), dtype=np.dtype("float32"), mean=5.3079605),
|
||||
time=3,
|
||||
|
@ -13,31 +13,30 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for contrib.seq2seq.python.seq2seq.basic_decoder."""
|
||||
# pylint: disable=unused-import,g-bad-import-order
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
# pylint: enable=unused-import
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
|
||||
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
|
||||
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.layers import core as layers_core
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
|
||||
@test_util.run_v1_only
|
||||
class BasicDecoderTest(test.TestCase):
|
||||
|
||||
def _testStepWithTrainingHelper(self, use_output_layer):
|
||||
|
@ -187,14 +187,23 @@ class TestArrayShapeChecks(test.TestCase):
|
||||
shape=dynamic_shape)
|
||||
|
||||
batch_size = array_ops.constant(batch_size)
|
||||
check_op = beam_search_decoder._check_batch_beam(t, batch_size, beam_width) # pylint: disable=protected-access
|
||||
|
||||
with self.cached_session() as sess:
|
||||
if is_valid:
|
||||
sess.run(check_op)
|
||||
def _test_body():
|
||||
# pylint: disable=protected-access
|
||||
if context.executing_eagerly():
|
||||
beam_search_decoder._check_batch_beam(t, batch_size, beam_width)
|
||||
else:
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(check_op)
|
||||
with self.cached_session():
|
||||
check_op = beam_search_decoder._check_batch_beam(
|
||||
t, batch_size, beam_width)
|
||||
self.evaluate(check_op)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
if is_valid:
|
||||
_test_body()
|
||||
else:
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
_test_body()
|
||||
|
||||
def test_array_shape_dynamic_checks(self):
|
||||
self._test_array_shape_dynamic_checks(
|
||||
@ -463,6 +472,7 @@ class TestLargeBeamStep(test.TestCase):
|
||||
self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0], [0, 0, 0]])
|
||||
|
||||
|
||||
@test_util.run_v1_only
|
||||
class BeamSearchDecoderTest(test.TestCase):
|
||||
|
||||
def _testDynamicDecodeRNN(self, time_major, has_attention,
|
||||
|
@ -49,8 +49,8 @@ class GatherTreeTest(test.TestCase):
|
||||
parent_ids=parent_ids,
|
||||
max_sequence_lengths=max_sequence_lengths,
|
||||
end_token=end_token)
|
||||
with self.session(use_gpu=True):
|
||||
self.assertAllEqual(expected_result, beams.eval())
|
||||
with self.cached_session(use_gpu=True):
|
||||
self.assertAllEqual(expected_result, self.evaluate(beams))
|
||||
|
||||
def testBadParentValuesOnCPU(self):
|
||||
# (batch_size = 1, max_time = 4, beams = 3)
|
||||
@ -62,15 +62,14 @@ class GatherTreeTest(test.TestCase):
|
||||
[[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
|
||||
max_sequence_lengths = [3]
|
||||
with ops.device("/cpu:0"):
|
||||
beams = beam_search_ops.gather_tree(
|
||||
step_ids=step_ids,
|
||||
parent_ids=parent_ids,
|
||||
max_sequence_lengths=max_sequence_lengths,
|
||||
end_token=end_token)
|
||||
with self.cached_session():
|
||||
with self.assertRaisesOpError(
|
||||
r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"):
|
||||
_ = beams.eval()
|
||||
beams = beam_search_ops.gather_tree(
|
||||
step_ids=step_ids,
|
||||
parent_ids=parent_ids,
|
||||
max_sequence_lengths=max_sequence_lengths,
|
||||
end_token=end_token)
|
||||
self.evaluate(beams)
|
||||
|
||||
def testBadParentValuesOnGPU(self):
|
||||
# Only want to run this test on CUDA devices, as gather_tree is not
|
||||
@ -93,8 +92,7 @@ class GatherTreeTest(test.TestCase):
|
||||
parent_ids=parent_ids,
|
||||
max_sequence_lengths=max_sequence_lengths,
|
||||
end_token=end_token)
|
||||
with self.session(use_gpu=True):
|
||||
self.assertAllEqual(expected_result, beams.eval())
|
||||
self.assertAllEqual(expected_result, self.evaluate(beams))
|
||||
|
||||
def testGatherTreeBatch(self):
|
||||
batch_size = 10
|
||||
@ -103,7 +101,7 @@ class GatherTreeTest(test.TestCase):
|
||||
max_sequence_lengths = [0, 1, 2, 4, 7, 8, 9, 10, 11, 0]
|
||||
end_token = 5
|
||||
|
||||
with self.session(use_gpu=True):
|
||||
with self.cached_session(use_gpu=True):
|
||||
step_ids = np.random.randint(
|
||||
0, high=end_token + 1, size=(max_time, batch_size, beam_width))
|
||||
parent_ids = np.random.randint(
|
||||
@ -116,7 +114,7 @@ class GatherTreeTest(test.TestCase):
|
||||
end_token=end_token)
|
||||
|
||||
self.assertEqual((max_time, batch_size, beam_width), beams.shape)
|
||||
beams_value = beams.eval()
|
||||
beams_value = self.evaluate(beams)
|
||||
for b in range(batch_size):
|
||||
# Past max_sequence_lengths[b], we emit all end tokens.
|
||||
b_value = beams_value[max_sequence_lengths[b]:, b, :]
|
||||
|
@ -13,26 +13,25 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for contrib.seq2seq.python.seq2seq.decoder."""
|
||||
# pylint: disable=unused-import,g-bad-import-order
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
# pylint: enable=unused-import
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
|
||||
from tensorflow.contrib.seq2seq.python.ops import decoder
|
||||
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
|
||||
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import rnn
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
from tensorflow.python.platform import test
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
|
||||
@test_util.run_v1_only
|
||||
class DynamicDecodeRNNTest(test.TestCase):
|
||||
|
||||
def _testDynamicDecodeRNN(self, time_major, maximum_iterations=None):
|
||||
|
@ -31,7 +31,7 @@ from tensorflow.python.platform import test
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class LossTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
def config_default_values(self):
|
||||
self.batch_size = 2
|
||||
self.sequence_length = 3
|
||||
self.number_of_classes = 5
|
||||
@ -56,7 +56,8 @@ class LossTest(test.TestCase):
|
||||
self.expected_loss = 1.60944
|
||||
|
||||
def testSequenceLoss(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
self.config_default_values()
|
||||
with self.cached_session(use_gpu=True):
|
||||
average_loss_per_example = loss.sequence_loss(
|
||||
self.logits, self.targets, self.weights,
|
||||
average_across_timesteps=True,
|
||||
@ -90,7 +91,8 @@ class LossTest(test.TestCase):
|
||||
self.assertAllClose(compare_total, res)
|
||||
|
||||
def testSequenceLossClass(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
self.config_default_values()
|
||||
with self.cached_session(use_gpu=True):
|
||||
seq_loss = loss.SequenceLoss(average_across_timesteps=True,
|
||||
average_across_batch=True,
|
||||
sum_over_timesteps=False,
|
||||
@ -132,7 +134,8 @@ class LossTest(test.TestCase):
|
||||
self.assertAllClose(compare_total, res)
|
||||
|
||||
def testSumReduction(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
self.config_default_values()
|
||||
with self.cached_session(use_gpu=True):
|
||||
seq_loss = loss.SequenceLoss(average_across_timesteps=False,
|
||||
average_across_batch=False,
|
||||
sum_over_timesteps=True,
|
||||
@ -174,6 +177,7 @@ class LossTest(test.TestCase):
|
||||
self.assertAllClose(compare_total, res)
|
||||
|
||||
def testWeightedSumReduction(self):
|
||||
self.config_default_values()
|
||||
weights = [
|
||||
constant_op.constant(1.0, shape=[self.batch_size])
|
||||
for _ in range(self.sequence_length)
|
||||
@ -181,7 +185,7 @@ class LossTest(test.TestCase):
|
||||
# Make the last element in the sequence to have zero weights.
|
||||
weights[-1] = constant_op.constant(0.0, shape=[self.batch_size])
|
||||
self.weights = array_ops.stack(weights, axis=1)
|
||||
with self.test_session(use_gpu=True):
|
||||
with self.cached_session(use_gpu=True):
|
||||
seq_loss = loss.SequenceLoss(average_across_timesteps=False,
|
||||
average_across_batch=False,
|
||||
sum_over_timesteps=True,
|
||||
@ -225,12 +229,13 @@ class LossTest(test.TestCase):
|
||||
self.assertAllClose(compare_total, res)
|
||||
|
||||
def testZeroWeights(self):
|
||||
self.config_default_values()
|
||||
weights = [
|
||||
constant_op.constant(0.0, shape=[self.batch_size])
|
||||
for _ in range(self.sequence_length)
|
||||
]
|
||||
weights = array_ops.stack(weights, axis=1)
|
||||
with self.test_session(use_gpu=True):
|
||||
with self.cached_session(use_gpu=True):
|
||||
average_loss_per_example = loss.sequence_loss(
|
||||
self.logits, self.targets, weights,
|
||||
average_across_timesteps=True,
|
||||
|
@ -2347,7 +2347,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
|
||||
if self._initial_cell_state is not None:
|
||||
cell_state = self._initial_cell_state
|
||||
else:
|
||||
cell_state = self._cell.zero_state(batch_size, dtype)
|
||||
cell_state = self._cell.get_initial_state(batch_size=batch_size,
|
||||
dtype=dtype)
|
||||
error_message = (
|
||||
"When calling zero_state of AttentionWrapper %s: " % self._base_name +
|
||||
"Non-matching batch sizes between the memory "
|
||||
|
@ -218,7 +218,7 @@ def _check_batch_beam(t, batch_size, beam_width):
|
||||
"incompatible with the dynamic shape of %s elements. "
|
||||
"Consider setting reorder_tensor_arrays to False to disable "
|
||||
"TensorArray reordering during the beam search."
|
||||
% (t.name))
|
||||
% (t if context.executing_eagerly() else t.name))
|
||||
rank = t.shape.ndims
|
||||
shape = array_ops.shape(t)
|
||||
if rank == 2:
|
||||
|
@ -233,7 +233,6 @@ CORE_PROTO_SRCS = COMMON_PROTO_SRCS + ERROR_CODES_PROTO_SRCS
|
||||
ADDITIONAL_CORE_PROTO_SRCS = [
|
||||
"example/example_parser_configuration.proto",
|
||||
"protobuf/trackable_object_graph.proto",
|
||||
"protobuf/autotuning.proto",
|
||||
"protobuf/control_flow.proto",
|
||||
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
|
||||
# "protobuf/critical_section.proto",
|
||||
@ -926,6 +925,7 @@ tf_cuda_library(
|
||||
"framework/tensor_slice.h",
|
||||
"framework/tensor_types.h",
|
||||
"framework/tensor_util.h",
|
||||
"framework/thread_factory.h",
|
||||
"framework/tracking_allocator.h",
|
||||
"framework/type_index.h",
|
||||
"framework/type_traits.h",
|
||||
@ -1671,6 +1671,7 @@ filegroup(
|
||||
":protos_all_proto_text_srcs",
|
||||
":error_codes_proto_text_srcs",
|
||||
"//tensorflow/core/platform/default/build_config:android_srcs",
|
||||
"//tensorflow/c:srcs",
|
||||
] + glob(
|
||||
[
|
||||
"client/**/*.cc",
|
||||
|
@ -0,0 +1,32 @@
op {
  graph_op_name: "ExperimentalAutoShardDataset"
  visibility: HIDDEN
  in_arg {
    name: "input_dataset"
    description: <<END
A variant tensor representing the input dataset.
END
  }
  in_arg {
    name: "num_workers"
    description: <<END
A scalar representing the number of workers to distribute this dataset across.
END
  }
  in_arg {
    name: "index"
    description: <<END
A scalar representing the index of the current worker out of num_workers.
END
  }
  summary: "Creates a dataset that shards the input dataset."
  description: <<END
Creates a dataset that shards the input dataset by num_workers, returning a
sharded dataset for the index-th worker. This attempts to automatically shard
a dataset by examining the Dataset graph and inserting a shard op before the
inputs to a reader Dataset (e.g. CSVDataset, TFRecordDataset).

This dataset will throw a NotFound error if we cannot shard the dataset
automatically.
END
}
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/dataset_stateful_op_whitelist.h"
|
||||
@ -28,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/thread_factory.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/framework/variant_encode_decode.h"
|
||||
#include "tensorflow/core/framework/variant_tensor_data.h"
|
||||
@ -287,7 +289,8 @@ class IteratorContext {
|
||||
model(ctx->model()),
|
||||
runner(*(ctx->runner())),
|
||||
runner_threadpool_size(ctx->runner_threadpool_size()),
|
||||
stats_aggregator(ctx->stats_aggregator()) {}
|
||||
stats_aggregator(ctx->stats_aggregator()),
|
||||
thread_factory(ctx->thread_factory()) {}
|
||||
|
||||
explicit Params(OpKernelContext* ctx)
|
||||
: env(ctx->env()),
|
||||
@ -338,6 +341,10 @@ class IteratorContext {
|
||||
|
||||
// The `StatsAggregator` object to record statistics about the iterator.
|
||||
std::shared_ptr<StatsAggregator> stats_aggregator = nullptr;
|
||||
|
||||
// A `ThreadFactory` for creating threads used by iterators to perform
|
||||
// blocking work.
|
||||
std::shared_ptr<ThreadFactory> thread_factory = nullptr;
|
||||
};
|
||||
|
||||
explicit IteratorContext(IteratorContext* ctx) : params_(Params{ctx}) {}
|
||||
@ -374,6 +381,20 @@ class IteratorContext {
|
||||
return ¶ms_.runner;
|
||||
}
|
||||
|
||||
const std::shared_ptr<ThreadFactory>& thread_factory() {
|
||||
return params_.thread_factory;
|
||||
}
|
||||
|
||||
std::unique_ptr<Thread> StartThread(const string& name,
|
||||
std::function<void()> fn) {
|
||||
if (params_.thread_factory) {
|
||||
return params_.thread_factory->StartThread(name, std::move(fn));
|
||||
} else {
|
||||
return absl::WrapUnique(
|
||||
Env::Default()->StartThread({}, name, std::move(fn)));
|
||||
}
|
||||
}
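  // Example use from an iterator implementation (a sketch; the dataset kernel
  // changes later in this commit follow exactly this pattern):
  //
  //   runner_thread_ = ctx->StartThread(
  //       "tf_data_map_and_batch",
  //       std::bind(&Iterator::RunnerThread, this, ctx_copy));
  //
  // If the surrounding context supplies a ThreadFactory (e.g. an unbounded
  // thread pool), that factory is used; otherwise the thread comes from
  // Env::Default().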
|
||||
|
||||
int32 runner_threadpool_size() { return params_.runner_threadpool_size; }
|
||||
|
||||
std::shared_ptr<StatsAggregator> stats_aggregator() {
|
||||
|
tensorflow/core/framework/thread_factory.h (new file, 42 lines)
@ -0,0 +1,42 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_THREAD_FACTORY_H_
#define TENSORFLOW_CORE_FRAMEWORK_THREAD_FACTORY_H_

#include <functional>
#include <memory>

#include "tensorflow/core/platform/types.h"

namespace tensorflow {

class Thread;

// Virtual interface for an object that creates threads.
class ThreadFactory {
 public:
  virtual ~ThreadFactory() {}

  // Runs `fn` asynchronously in a different thread. `fn` may block.
  //
  // NOTE: The caller is responsible for ensuring that this `ThreadFactory`
  // outlives the returned `Thread`.
  virtual std::unique_ptr<Thread> StartThread(const string& name,
                                              std::function<void()> fn) = 0;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_THREAD_FACTORY_H_
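For context, a minimal ThreadFactory need only forward to Env::Default(), matching the fallback path in IteratorContext::StartThread earlier in this change. The sketch below is not part of the change; DefaultThreadFactory is a hypothetical name, and it assumes tensorflow/core/platform/env.h for Env and Thread.

// Sketch of a trivial factory that creates plain platform threads.
class DefaultThreadFactory : public ThreadFactory {
 public:
  std::unique_ptr<Thread> StartThread(const string& name,
                                      std::function<void()> fn) override {
    // Default ThreadOptions, same as the IteratorContext fallback.
    return std::unique_ptr<Thread>(
        Env::Default()->StartThread({}, name, std::move(fn)));
  }
};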
|
@ -979,6 +979,41 @@ class SymbolicShapeRefiner {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Return true if the annotated shape is compatible with shape inference
|
||||
// result. Examples:
|
||||
// Inferred shape: ?, annotated shape: [10, 10] -> true;
|
||||
// Inferred shape: [-1, 10], annotated shape: [10, 10] -> true;
|
||||
// Inferred shape: [-1, 100], annotated shape: [10, 10] -> false;
|
||||
// Inferred shape: [-1, 10, 10], annotated shape: [10, 10] -> false.
|
||||
bool CompatibleShapes(ShapeHandle inferred_shape,
|
||||
ShapeHandle annotated_shape) const {
|
||||
if (inferred_shape.SameHandle(annotated_shape)) {
|
||||
return true;
|
||||
}
|
||||
if (!InferenceContext::RankKnown(inferred_shape)) {
|
||||
return true;
|
||||
}
|
||||
if (InferenceContext::Rank(inferred_shape) !=
|
||||
InferenceContext::Rank(annotated_shape)) {
|
||||
return false;
|
||||
}
|
||||
const int rank = InferenceContext::Rank(inferred_shape);
|
||||
for (int i = 0; i < rank; ++i) {
|
||||
if (!InferenceContext::DimKnownRank(inferred_shape, i)
|
||||
.SameHandle(
|
||||
InferenceContext::DimKnownRank(annotated_shape, i))) {
|
||||
int64 val1 = InferenceContext::Value(
|
||||
InferenceContext::DimKnownRank(inferred_shape, i));
|
||||
int64 val2 = InferenceContext::Value(
|
||||
InferenceContext::DimKnownRank(annotated_shape, i));
|
||||
if (val1 >= 0 && val1 != val2) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool EquivalentShapesAndTypes(const std::vector<ShapeAndType>& st1,
|
||||
const std::vector<ShapeAndType>& st2) const {
|
||||
if (st1.size() != st2.size()) {
|
||||
@ -1139,9 +1174,9 @@ class SymbolicShapeRefiner {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns true if we want to update output values with running EvaluateNode()
|
||||
// for this op, based on op type, data type, and size.
|
||||
bool ShouldUpdateOutputValues(NodeContext* c, int64 max_size) {
|
||||
// Returns true if we want to update output shapes and values with running
|
||||
// EvaluateNode() for this op, based on op type, data type, and size.
|
||||
bool ShouldUpdateOutputShapesAndValues(NodeContext* c, int64 max_size) {
|
||||
InferenceContext* ic = c->inference_context.get();
|
||||
|
||||
// Due to the cost of running EvaluateNode(), we limit only to white listed
|
||||
@ -1232,8 +1267,9 @@ class SymbolicShapeRefiner {
|
||||
}
|
||||
}
|
||||
|
||||
// Run a node to infer output values, and add it to the NodeContext.
|
||||
Status UpdateOutputValues(const NodeDef& node, NodeContext* c) {
|
||||
// Run a node to infer output shapes and values, and add it to the
|
||||
// NodeContext.
|
||||
Status UpdateOutputShapesAndValues(const NodeDef& node, NodeContext* c) {
|
||||
InferenceContext* ic = c->inference_context.get();
|
||||
|
||||
// Input to EvaluateNode()
|
||||
@ -1264,7 +1300,7 @@ class SymbolicShapeRefiner {
|
||||
ic->MakeShapeFromTensorShape(t->shape(), &output_shape));
|
||||
if (ic->FullyDefined(ic->output(k)) &&
|
||||
!EquivalentShapes(ic->output(k), output_shape)) {
|
||||
LOG(WARNING) << "UpdateOutputValues() -- node: " << node.name()
|
||||
LOG(WARNING) << "UpdateOutputShapesAndValues() -- node: " << node.name()
|
||||
<< ", inferred output shape "
|
||||
<< "doesn't match for k=" << k << ": "
|
||||
<< "ic->output(k): " << ic->DebugString(ic->output(k))
|
||||
@ -1284,6 +1320,54 @@ class SymbolicShapeRefiner {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Update output shapes with annotated information.
|
||||
// Currently only handle nodes with static shapes, i.e. shapes do not change
|
||||
// during execution.
|
||||
// TODO(andiryxu): Use annotated shapes in Enter/Merge etc as well.
|
||||
Status UpdateOutputShapesUsingAnnotatedInformation(const NodeDef& node,
|
||||
NodeContext* c) const {
|
||||
const auto& attr = node.attr();
|
||||
if (attr.count(kOutputSame) == 0 || !attr.at(kOutputSame).b() ||
|
||||
attr.count(kOutputShapes) == 0)
|
||||
return Status::OK();
|
||||
|
||||
InferenceContext* ic = c->inference_context.get();
|
||||
int output_size = attr.at(kOutputShapes).list().shape_size();
|
||||
|
||||
for (int i = 0; i < ic->num_outputs(); i++) {
|
||||
// Annotated Switch node has only one output. Propagate the shape to all
|
||||
// the outputs.
|
||||
int shape_index = IsSwitch(node) ? 0 : i;
|
||||
if (shape_index >= output_size) {
|
||||
LOG(WARNING)
|
||||
<< "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
|
||||
<< node.name() << ", inferred output shape size "
|
||||
<< ic->num_outputs() << ", annotated output shape size "
|
||||
<< output_size;
|
||||
break;
|
||||
}
|
||||
|
||||
const TensorShapeProto& shape =
|
||||
attr.at(kOutputShapes).list().shape(shape_index);
|
||||
ShapeHandle output_shape;
|
||||
TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &output_shape));
|
||||
|
||||
// Only use annotated shapes if the inference shape is unknown and
|
||||
// compatible with annotated shapes.
|
||||
if (!ic->FullyDefined(ic->output(i)) &&
|
||||
CompatibleShapes(ic->output(i), output_shape)) {
|
||||
VLOG(3) << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
|
||||
<< node.name() << ", inferred output shape " << i << ": "
|
||||
<< "ic->output(i): " << ic->DebugString(ic->output(i))
|
||||
<< ", annotated output shape: " << ic->DebugString(output_shape)
|
||||
<< " -- " << node.ShortDebugString();
|
||||
ic->set_output(i, output_shape);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MaybeUpdateNodeContextOutput(const NodeDef& node, const bool is_fed,
|
||||
NodeContext* c) {
|
||||
// Propagate tensors and shape tensors unless the node is fed.
|
||||
@ -1476,16 +1560,19 @@ class SymbolicShapeRefiner {
|
||||
}
|
||||
|
||||
if (aggressive_shape_inference_) {
|
||||
// Update output shapes with annotated information. This is optional.
|
||||
UpdateOutputShapesUsingAnnotatedInformation(node, c).IgnoreError();
|
||||
|
||||
// Update output tensor values using EvaluateNode() if we can.
|
||||
// Due to the cost of EvaluateNode(), we run it only for certain op types
|
||||
// (white listed) and small integer tensors.
|
||||
|
||||
const int max_element_size = 17; // Max up to 4x4 matrix or similar.
|
||||
if (AllOutputValuesKnown(c) || !AllInputValuesKnown(c) ||
|
||||
!ShouldUpdateOutputValues(c, max_element_size)) {
|
||||
!ShouldUpdateOutputShapesAndValues(c, max_element_size)) {
|
||||
return Status::OK();
|
||||
}
|
||||
UpdateOutputValues(node, c).IgnoreError(); // This is optional.
|
||||
UpdateOutputShapesAndValues(node, c).IgnoreError(); // This is optional.
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -1797,6 +1884,7 @@ Status GraphProperties::UpdateShapes(
|
||||
// UpdateNode calls UpdateFunction if a function node is detected.
|
||||
TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -27,6 +27,45 @@ namespace tensorflow {
|
||||
|
||||
namespace grappler {
|
||||
|
||||
// Optional attributes that tell about node output information.
|
||||
// We use this side information, if provided, for static shape inference
|
||||
// and VirtualScheduler scheduling.
|
||||
|
||||
// Switch op attribute as a vector of int that tells which branch the
|
||||
// Switch output is taken on every round of execution.
|
||||
// Used for scheduling ops after Switch correctly (e.g., While loop).
|
||||
ABSL_CONST_INIT const char kOutputSlots[] = "_output_slot_vector";
|
||||
|
||||
// Example:
|
||||
// Assume a node has two outputs and is executed three times. Then it has:
|
||||
// _execution_count = 3
|
||||
// _output_sizes_vector = [2, 2, 2]
|
||||
// _output_dtype_vector.size = 6
|
||||
// _output_shape_vector.size = 6
|
||||
|
||||
// If all the iterations have same output shapes, then
|
||||
// _execution_count = 3
|
||||
// _same_output_for_iterations = true
|
||||
// _output_sizes_vector = [2]
|
||||
// _output_dtype_vector.size = 2
|
||||
// _output_shape_vector.size = 2
|
||||
|
||||
// How many times this node has been executed.
|
||||
ABSL_CONST_INIT const char kExecutionCount[] = "_execution_count";
|
||||
|
||||
// Records the output sizes for each round of execution.
|
||||
ABSL_CONST_INIT const char kOutputSizes[] = "_output_sizes_vector";
|
||||
|
||||
// The node has been scheduled multiple times with outputs that have the same
|
||||
// shape.
|
||||
ABSL_CONST_INIT const char kOutputSame[] = "_same_output_for_iterations";
|
||||
|
||||
// Outputs DataType vector.
|
||||
ABSL_CONST_INIT const char kOutputTypes[] = "_output_dtype_vector";
|
||||
|
||||
// Outputs TensorShapeProto vector.
|
||||
ABSL_CONST_INIT const char kOutputShapes[] = "_output_shape_vector";
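For illustration, a node carrying these annotations could be built as in the sketch below; it mirrors the ShapeAnnotation tests added later in this change, and the node name, dtype, and shape values are only examples.

// Sketch: annotate an Identity node so that aggressive shape inference can
// pick up its recorded output shape (kOutputSame / kOutputShapes above).
NodeDef annotated;
TF_CHECK_OK(NodeDefBuilder("Identity", "Identity")
                .Attr("dtype", DT_FLOAT)
                .Attr("_same_output_for_iterations", true)
                .Attr("_output_shape_vector", {TensorShape({5, 7})})
                .Input("Input", 0, DT_FLOAT)
                .Finalize(&annotated));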
|
||||
|
||||
class SymbolicShapeRefiner;
|
||||
class TopoQueue;
|
||||
|
||||
|
@ -1793,6 +1793,103 @@ TEST_F(GraphPropertiesTest, ValuePropagationThroughArithmeticOps) {
|
||||
ExpectTensorValues({20, 24}, c_plus_b_plus_2a_prop.value());
|
||||
}
|
||||
|
||||
TEST_F(GraphPropertiesTest, ShapeAnnotation) {
|
||||
GrapplerItem item;
|
||||
TF_CHECK_OK(NodeDefBuilder("Input", "Placeholder")
|
||||
.Attr("dtype", DT_FLOAT)
|
||||
.Attr("shape", PartialTensorShape({-1, -1}))
|
||||
.Finalize(item.graph.add_node()));
|
||||
// Annotate shapes.
|
||||
TF_CHECK_OK(NodeDefBuilder("Identity", "Identity")
|
||||
.Attr("dtype", DT_FLOAT)
|
||||
.Attr("_same_output_for_iterations", true)
|
||||
.Attr("_output_shape_vector", {TensorShape({5, 7})})
|
||||
.Input("Input", 0, DT_FLOAT)
|
||||
.Finalize(item.graph.add_node()));
|
||||
{
|
||||
GraphProperties properties(item);
|
||||
// Without aggressive_shape_inference, ignore annotated information.
|
||||
TF_CHECK_OK(properties.InferStatically(
|
||||
/*assume_valid_feeds=*/false,
|
||||
/*aggressive_shape_inference=*/false));
|
||||
const auto props = properties.GetOutputProperties("Identity");
|
||||
EXPECT_EQ(1, props.size());
|
||||
const OpInfo::TensorProperties& prop = props[0];
|
||||
EXPECT_EQ(DT_FLOAT, prop.dtype());
|
||||
EXPECT_EQ(2, prop.shape().dim_size());
|
||||
// Get unknown shapes without using annotated information.
|
||||
EXPECT_EQ("float: [-1,-1]", PropToString(prop));
|
||||
}
|
||||
{
|
||||
GraphProperties properties(item);
|
||||
// Use annotated information.
|
||||
TF_CHECK_OK(properties.InferStatically(
|
||||
/*assume_valid_feeds=*/false,
|
||||
/*aggressive_shape_inference=*/true));
|
||||
const auto props = properties.GetOutputProperties("Identity");
|
||||
EXPECT_EQ(1, props.size());
|
||||
const OpInfo::TensorProperties& prop = props[0];
|
||||
EXPECT_EQ(DT_FLOAT, prop.dtype());
|
||||
EXPECT_EQ(2, prop.shape().dim_size());
|
||||
// Update output shape using annotated shapes.
|
||||
EXPECT_EQ("float: [5,7]", PropToString(prop));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(GraphPropertiesTest, ShapeAnnotationWithCompatibleShapes) {
|
||||
GrapplerItem item;
|
||||
TF_CHECK_OK(NodeDefBuilder("Input", "Placeholder")
|
||||
.Attr("dtype", DT_FLOAT)
|
||||
.Attr("shape", PartialTensorShape({-1, 100}))
|
||||
.Finalize(item.graph.add_node()));
|
||||
// Annotate shapes.
|
||||
TF_CHECK_OK(NodeDefBuilder("Identity", "Identity")
|
||||
.Attr("dtype", DT_FLOAT)
|
||||
.Attr("_same_output_for_iterations", true)
|
||||
.Attr("_output_shape_vector", {TensorShape({10, 100})})
|
||||
.Input("Input", 0, DT_FLOAT)
|
||||
.Finalize(item.graph.add_node()));
|
||||
GraphProperties properties(item);
|
||||
// Use annotated information.
|
||||
TF_CHECK_OK(properties.InferStatically(
|
||||
/*assume_valid_feeds=*/false,
|
||||
/*aggressive_shape_inference=*/true));
|
||||
const auto props = properties.GetOutputProperties("Identity");
|
||||
EXPECT_EQ(1, props.size());
|
||||
const OpInfo::TensorProperties& prop = props[0];
|
||||
EXPECT_EQ(DT_FLOAT, prop.dtype());
|
||||
EXPECT_EQ(2, prop.shape().dim_size());
|
||||
// Compatible shapes. Update output shape using annotated shapes.
|
||||
EXPECT_EQ("float: [10,100]", PropToString(prop));
|
||||
}
|
||||
|
||||
TEST_F(GraphPropertiesTest, ShapeAnnotationWithIncompatibleShapes) {
|
||||
GrapplerItem item;
|
||||
TF_CHECK_OK(NodeDefBuilder("Input", "Placeholder")
|
||||
.Attr("dtype", DT_FLOAT)
|
||||
.Attr("shape", PartialTensorShape({-1, 100}))
|
||||
.Finalize(item.graph.add_node()));
|
||||
// Annotate shapes.
|
||||
TF_CHECK_OK(NodeDefBuilder("Identity", "Identity")
|
||||
.Attr("dtype", DT_FLOAT)
|
||||
.Attr("_same_output_for_iterations", true)
|
||||
.Attr("_output_shape_vector", {TensorShape({10, 10})})
|
||||
.Input("Input", 0, DT_FLOAT)
|
||||
.Finalize(item.graph.add_node()));
|
||||
GraphProperties properties(item);
|
||||
// Use annotated information.
|
||||
TF_CHECK_OK(properties.InferStatically(
|
||||
/*assume_valid_feeds=*/false,
|
||||
/*aggressive_shape_inference=*/true));
|
||||
const auto props = properties.GetOutputProperties("Identity");
|
||||
EXPECT_EQ(1, props.size());
|
||||
const OpInfo::TensorProperties& prop = props[0];
|
||||
EXPECT_EQ(DT_FLOAT, prop.dtype());
|
||||
EXPECT_EQ(2, prop.shape().dim_size());
|
||||
// Incompatible shapes. Do not use annotated shapes.
|
||||
EXPECT_EQ("float: [-1,100]", PropToString(prop));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
@ -36,12 +36,6 @@ namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
// Optional attribute name for Switch op as a vector of int that tells
|
||||
// which branch the Switch output is taken on every round of execution.
|
||||
// We use this side information, if provided, for scheduling ops after Switch
|
||||
// correctly (e.g., While loop).
|
||||
constexpr char kOutputSlots[] = "_output_slot_vector";
|
||||
|
||||
Costs CombineCosts(const Costs& left, const Costs& right) {
|
||||
CHECK_NE(left.max_memory, kMemoryUnknown);
|
||||
CHECK_NE(left.max_per_op_buffers, kMemoryUnknown);
|
||||
|
@ -29,6 +29,27 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "auto_shard",
|
||||
srcs = ["auto_shard.cc"],
|
||||
hdrs = ["auto_shard.h"],
|
||||
deps = [
|
||||
":graph_utils",
|
||||
":optimizer_base",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:mutable_graph_view",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler/clusters:cluster",
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/grappler/utils:functions",
|
||||
] + tf_protos_all(),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "filter_fusion",
|
||||
srcs = ["filter_fusion.cc"],
|
||||
|
tensorflow/core/grappler/optimizers/data/auto_shard.cc (new file, 300 lines)
@ -0,0 +1,300 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/data/auto_shard.h"
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/grappler/clusters/cluster.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/mutable_graph_view.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
|
||||
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
|
||||
#include "tensorflow/core/grappler/utils/functions.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
// clang-format off
|
||||
constexpr char kShardDatasetOpName[] = "ShardDataset";
|
||||
constexpr char kShuffleDatasetOpName[] = "ShuffleDataset";
|
||||
|
||||
constexpr std::array<const char*, 4> kReaderDatasetOps = {
|
||||
"FixedLengthRecordDataset",
|
||||
"FixedLengthRecordDatasetV2",
|
||||
"TextLineDataset",
|
||||
"TFRecordDataset"
|
||||
};
|
||||
|
||||
constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
|
||||
"ConcatenateDataset",
|
||||
"ZipDataset"
|
||||
};
|
||||
|
||||
constexpr std::array<const char*, 22> kPassThroughOps = {
|
||||
"BatchDataset",
|
||||
"BatchDatasetV2",
|
||||
"ExperimentalMapAndBatchDataset",
|
||||
"PaddedBatchDataset",
|
||||
"PaddedBatchDatasetV2",
|
||||
"CacheDataset",
|
||||
"FilterDataset",
|
||||
"FilterByLastComponentDataset",
|
||||
"Identity",
|
||||
"MapDataset",
|
||||
"ModelDataset",
|
||||
"OptimizeDataset",
|
||||
"ParallelMapDataset",
|
||||
"PrefetchDataset",
|
||||
"ReduceDataset",
|
||||
"RepeatDataset",
|
||||
"ShardDataset",
|
||||
"ShuffleAndRepeatDataset",
|
||||
"ShuffleDataset",
|
||||
"SkipDataset",
|
||||
"TakeDataset",
|
||||
"WindowDataset"
|
||||
};
|
||||
|
||||
// TODO(frankchn): Process functions within kFuncDatasetOps as well.
|
||||
constexpr std::array<const char*, 4> kFuncDatasetOps = {
|
||||
"ExperimentalParallelInterleaveDataset",
|
||||
"FlatMapDataset",
|
||||
"InterleaveDataset",
|
||||
"ParallelInterleaveDatasetV2"
|
||||
};
|
||||
|
||||
constexpr std::array<const char*, 5> kUnshardableSourceDatasetOps = {
|
||||
"GeneratorDataset",
|
||||
"RangeDataset",
|
||||
"SparseTensorsSliceDataset",
|
||||
"TensorDataset",
|
||||
"TensorSliceDataset",
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
Status OptimizeGraph(const GrapplerItem& item, int64 num_workers, int64 index,
|
||||
GraphDef* output);
|
||||
|
||||
template <std::size_t SIZE>
|
||||
bool IsDatasetNodeOfType(const NodeDef& node,
|
||||
const std::array<const char*, SIZE>& arr) {
|
||||
for (const auto& dataset_op_name : arr) {
|
||||
if (node.op() == dataset_op_name) return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Status AddShardNode(MutableGraphView* graph, const NodeDef& add_before,
|
||||
int64 num_workers, int64 index) {
|
||||
NodeDef new_node;
|
||||
new_node.set_op(kShardDatasetOpName);
|
||||
graph_utils::SetUniqueGraphNodeName(kShardDatasetOpName, graph->graph(),
|
||||
&new_node);
|
||||
|
||||
// Construct argument nodes
|
||||
NodeDef* num_shards_node =
|
||||
graph_utils::AddScalarConstNode<int64>(num_workers, graph);
|
||||
NodeDef* index_node = graph_utils::AddScalarConstNode<int64>(index, graph);
|
||||
|
||||
// Add inputs to new node
|
||||
new_node.add_input(add_before.input(0));
|
||||
new_node.add_input(num_shards_node->name());
|
||||
new_node.add_input(index_node->name());
|
||||
|
||||
// Add shapes and other attributes
|
||||
NodeDef* add_after = graph->GetNode(add_before.input(0));
|
||||
graph_utils::CopyAttribute("output_shapes", *add_after, &new_node);
|
||||
|
||||
if (add_after->attr().find("Toutput_types") != add_after->attr().end()) {
|
||||
(*(new_node.mutable_attr()))["output_types"] =
|
||||
add_after->attr().at("Toutput_types");
|
||||
} else {
|
||||
graph_utils::CopyAttribute("output_types", *add_after, &new_node);
|
||||
}
|
||||
|
||||
// Add new node into graph and update edges
|
||||
NodeDef* new_node_graph = graph->AddNode(std::move(new_node));
|
||||
TF_RETURN_IF_ERROR(
|
||||
graph->UpdateFanouts(add_after->name(), new_node_graph->name()));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool ReaderOpInFunction(const NodeDef& node,
|
||||
const FunctionLibraryDefinition& flib) {
|
||||
const FunctionDef* func = flib.Find(node.attr().at("f").func().name());
|
||||
for (int i = 0; i < func->node_def_size(); i++) {
|
||||
NodeDef node_in_func = func->node_def(i);
|
||||
if (IsDatasetNodeOfType(node_in_func, kReaderDatasetOps) &&
|
||||
node_in_func.input_size() > 0 &&
|
||||
str_util::StartsWith(node_in_func.input(0), "args_0")) {
|
||||
return true;
|
||||
}
|
||||
if (IsDatasetNodeOfType(func->node_def(i), kFuncDatasetOps) &&
|
||||
ReaderOpInFunction(func->node_def(i), flib)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Status RemoveShuffleDataset(MutableGraphView* graph, const NodeDef& node,
|
||||
absl::flat_hash_set<string>* nodes_to_delete) {
|
||||
if (node.op() == kShuffleDatasetOpName) {
|
||||
TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
|
||||
nodes_to_delete->insert(node.name());
|
||||
}
|
||||
|
||||
for (const auto& fanin : graph->GetFanins(node, true)) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
RemoveShuffleDataset(graph, *fanin.node, nodes_to_delete));
|
||||
}
|
||||
|
||||
// TODO(frankchn): Traverse functions too.
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, int64 index,
|
||||
FunctionLibraryDefinition* flib,
|
||||
MutableGraphView* graph,
|
||||
absl::flat_hash_set<string>* nodes_to_delete) {
|
||||
if (IsDatasetNodeOfType(node, kUnshardableSourceDatasetOps)) {
|
||||
return errors::NotFound("Found an unshardable source dataset: ",
|
||||
node.DebugString());
|
||||
}
|
||||
|
||||
if (IsDatasetNodeOfType(node, kMultipleInputsDatasetOps)) {
|
||||
for (int i = 0; i < node.input_size(); ++i) {
|
||||
const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, i);
|
||||
TF_RETURN_IF_ERROR(RecursivelyHandleOp(*input_node, num_workers, index,
|
||||
flib, graph, nodes_to_delete));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// This handles the case where a reader Dataset is contained within a
|
||||
// FuncDataset (e.g. FlatMap, ParallelInterleave, etc...). For example:
|
||||
//
|
||||
// dataset = Dataset.list_files("/path/to/data")
|
||||
// dataset = dataset.flat_map(core_readers.TFRecordDataset)
|
||||
//
|
||||
// where the list of files is passed in one-by-one as an argument to the
|
||||
// function in flat_map.
|
||||
if (IsDatasetNodeOfType(node, kFuncDatasetOps) &&
|
||||
ReaderOpInFunction(node, *flib)) {
|
||||
TF_RETURN_IF_ERROR(AddShardNode(graph, node, num_workers, index));
|
||||
TF_RETURN_IF_ERROR(RemoveShuffleDataset(graph, node, nodes_to_delete));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (IsDatasetNodeOfType(node, kReaderDatasetOps)) {
|
||||
// We reached a reader dataset directly and we try to shard input 0.
|
||||
TF_RETURN_IF_ERROR(AddShardNode(graph, node, num_workers, index));
|
||||
TF_RETURN_IF_ERROR(RemoveShuffleDataset(graph, node, nodes_to_delete));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (!IsDatasetNodeOfType(node, kPassThroughOps)) {
|
||||
return errors::NotFound(
|
||||
"Did not find a shardable source, walked to ",
|
||||
"a node which is not a dataset: ", node.DebugString());
|
||||
}
|
||||
|
||||
const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
|
||||
return RecursivelyHandleOp(*input_node, num_workers, index, flib, graph,
|
||||
nodes_to_delete);
|
||||
}
|
||||
|
||||
Status OptimizeGraph(const GrapplerItem& item, int64 num_workers, int64 index,
|
||||
GraphDef* output) {
|
||||
*output = item.graph;
|
||||
MutableGraphView graph(output);
|
||||
FunctionLibraryDefinition flib(OpRegistry::Global(), item.graph.library());
|
||||
|
||||
NodeDef target_node;
|
||||
absl::flat_hash_set<string> nodes_to_delete;
|
||||
|
||||
  // The basic approach here is to walk the graph from sink to source, and find
  // the latest occurrence of a ReaderDataset (e.g. CSVDataset, TFRecordDataset,
  // etc...). We then add a shard after that dataset to shard the outputs of
  // that dataset, in effect giving a piece to each worker. Finally, we remove
  // occurrences of randomness before that point in the graph (e.g. things
  // like ShuffleDataset) to ensure that `shard` returns a sensible result.
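  // Illustrative example (node names are hypothetical, not taken from a real
  // graph):
  //
  //   before: list_files -> ShuffleDataset -> FlatMapDataset(TFRecordDataset) -> sink
  //   after:  list_files -> ShardDataset(num_workers, index)
  //                      -> FlatMapDataset(TFRecordDataset) -> sink
  //
  // The upstream ShuffleDataset is dropped so that `shard` still returns a
  // sensible, consistent split; see the comment above.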
|
||||
|
||||
NodeDef sink_node;
|
||||
TF_RETURN_IF_ERROR(graph_utils::FindSinkNode(item.graph, &sink_node));
|
||||
TF_RETURN_IF_ERROR(RecursivelyHandleOp(sink_node, num_workers, index, &flib,
|
||||
&graph, &nodes_to_delete));
|
||||
|
||||
TF_RETURN_IF_ERROR(graph.DeleteNodes(nodes_to_delete));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
Status AutoShard::Init(
|
||||
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) {
|
||||
if (!config) return errors::InvalidArgument("RewriterConfig not found.");
|
||||
|
||||
if ((config->parameter_map().find("num_workers") ==
|
||||
config->parameter_map().end())) {
|
||||
return errors::InvalidArgument("num_workers parameter missing.");
|
||||
}
|
||||
|
||||
if ((config->parameter_map().find("index") ==
|
||||
config->parameter_map().end())) {
|
||||
return errors::InvalidArgument("index parameter missing.");
|
||||
}
|
||||
|
||||
num_workers_ = config->parameter_map().at("num_workers").i();
|
||||
index_ = config->parameter_map().at("index").i();
|
||||
|
||||
if (num_workers_ < 1) {
|
||||
return errors::InvalidArgument("num_workers should be >= 1, currently ",
|
||||
num_workers_);
|
||||
}
|
||||
|
||||
if (index_ < 0 || index_ >= num_workers_) {
|
||||
return errors::InvalidArgument("index should be >= 0 and < ", num_workers_,
|
||||
", currently ", index_);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AutoShard::OptimizeAndCollectStats(Cluster* /* cluster */,
|
||||
const GrapplerItem& item,
|
||||
GraphDef* output,
|
||||
OptimizationStats* stats) {
|
||||
*output = item.graph;
|
||||
MutableGraphView graph(output);
|
||||
|
||||
TF_RETURN_IF_ERROR(OptimizeGraph(item, num_workers_, index_, output));
|
||||
stats->num_changes++;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
REGISTER_GRAPH_OPTIMIZER_AS(AutoShard, "tf_auto_shard");
|
||||
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
tensorflow/core/grappler/optimizers/data/auto_shard.h (new file, 53 lines)
@ -0,0 +1,53 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_AUTO_SHARD_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_AUTO_SHARD_H_

#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h"

namespace tensorflow {
namespace grappler {

// AutoShard takes a Dataset graph and tries to insert a shard node
// automatically before a ReaderDataset (e.g. a CSVDataset or a TFRecordDataset)
// such that the dataset is sharded without any modifications to the original
// dataset-based input pipeline.
class AutoShard : public TFDataOptimizerBase {
 public:
  AutoShard() = default;
  ~AutoShard() override = default;

  string name() const override { return "tf_auto_shard"; }

  Status Init(
      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override;

  Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item,
                                 GraphDef* output,
                                 OptimizationStats* stats) override;

  void Feedback(Cluster* cluster, const GrapplerItem& item,
                const GraphDef& optimize_output, double result) override {}

 private:
  int64 num_workers_;
  int64 index_;
};

}  // namespace grappler
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_AUTO_SHARD_H_
|
@ -300,6 +300,40 @@ Status EnsureNodeNamesUnique(Graph* g) {
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Tries to find a Sink node in the graph. A sink node is defined as a node
// that has at least one input and no outputs. If there are multiple of these,
// this might return any one of them. This is useful to identify the final
// Dataset op in the graph; in some cases there may be multiple Identity ops
// added to the end, in which case this returns the last Identity op.
|
||||
|
||||
Status FindSinkNode(const GraphDef& graph_def, NodeDef* sink_node) {
|
||||
absl::flat_hash_map<string, int> all_node_names;
|
||||
absl::flat_hash_map<string, int> node_input_map;
|
||||
for (int i = 0; i < graph_def.node_size(); ++i) {
|
||||
all_node_names.insert_or_assign(graph_def.node(i).name(), i);
|
||||
node_input_map.insert_or_assign(graph_def.node(i).name(), 0);
|
||||
}
|
||||
  // Counts, for every node name, how many nodes take it as an input.
  // Candidate sink nodes are those that are an input to zero nodes.
|
||||
for (const NodeDef& node : graph_def.node()) {
|
||||
for (const string& input_name : node.input()) {
|
||||
node_input_map[input_name]++;
|
||||
}
|
||||
}
|
||||
for (const auto& it : node_input_map) {
|
||||
if (it.second == 0) {
|
||||
const NodeDef& sink_graph_node = graph_def.node(all_node_names[it.first]);
|
||||
if (sink_graph_node.input_size() == 0) {
|
||||
continue;
|
||||
}
|
||||
*sink_node = sink_graph_node;
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
return errors::InvalidArgument("Failed to find a sink node");
|
||||
}
|
||||
|
||||
} // namespace graph_utils
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
@ -144,6 +144,9 @@ void ConcatAttributeList(const string& attribute_name, const NodeDef& first,
|
||||
// and renaming nodes does not mutate any edges.
|
||||
Status EnsureNodeNamesUnique(Graph* g);
|
||||
|
||||
// Returns the sink node (i.e. last node) in the graph.
|
||||
Status FindSinkNode(const GraphDef& graph_def, NodeDef* sink_node);
|
||||
|
||||
} // namespace graph_utils
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
@ -270,6 +270,40 @@ TEST(GraphUtilsTest, EnsureNodeNamesUnique) {
|
||||
EXPECT_NE(const_0->name(), const_2->name());
|
||||
}
|
||||
|
||||
TEST(GraphUtilsTest, TestFindSinkNodeStandard) {
|
||||
GraphDef graph_def;
|
||||
MutableGraphView graph(&graph_def);
|
||||
|
||||
AddNode("node1", "Identity", {}, {}, &graph);
|
||||
AddNode("node2", "Identity", {"node1"}, {}, &graph);
|
||||
NodeDef* node3 = AddNode("node3", "Identity", {"node2"}, {}, &graph);
|
||||
|
||||
NodeDef sink_node;
|
||||
TF_EXPECT_OK(FindSinkNode(graph_def, &sink_node));
|
||||
EXPECT_EQ(sink_node.name(), node3->name());
|
||||
}
|
||||
|
||||
TEST(GraphUtilsTest, TestFindSinkNodeNoSingleSink) {
|
||||
GraphDef graph_def;
|
||||
MutableGraphView graph(&graph_def);
|
||||
|
||||
AddNode("node1", "Identity", {}, {}, &graph);
|
||||
AddNode("node2", "Identity", {}, {}, &graph);
|
||||
|
||||
NodeDef sink_node;
|
||||
Status s = FindSinkNode(graph_def, &sink_node);
|
||||
EXPECT_FALSE(s.ok());
|
||||
}
|
||||
|
||||
TEST(GraphUtilsTest, TestFindSinkNodeGraphDefEmpty) {
|
||||
GraphDef graph_def;
|
||||
MutableGraphView graph(&graph_def);
|
||||
|
||||
NodeDef sink_node;
|
||||
Status s = FindSinkNode(graph_def, &sink_node);
|
||||
EXPECT_FALSE(s.ok());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace graph_utils
|
||||
} // namespace grappler
|
||||
|
@ -172,39 +172,6 @@ Status MutateBatchSize(const NodeDef& node, int64 num_workers,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// There is one Sink node at least that is added to the end of the graph. We
|
||||
// find that node and return it. It is possible that there are multiple
|
||||
// Identity ops from the final Dataset op to that Sink node, but the recursive
|
||||
// graph traversal handles that.
|
||||
Status FindSinkNode(const GraphDef& graph_def, NodeDef* sink_node) {
|
||||
absl::flat_hash_map<string, int> all_node_names;
|
||||
absl::flat_hash_map<string, int> node_input_map;
|
||||
for (int i = 0; i < graph_def.node_size(); ++i) {
|
||||
all_node_names.insert_or_assign(graph_def.node(i).name(), i);
|
||||
node_input_map.insert_or_assign(graph_def.node(i).name(), 0);
|
||||
}
|
||||
// Counts how many graph nodes is this node the input to. Candidate sink
|
||||
// nodes are ones which are inputs into zero nodes.
|
||||
for (const NodeDef& node : graph_def.node()) {
|
||||
for (const string& input_name : node.input()) {
|
||||
node_input_map[input_name]++;
|
||||
}
|
||||
}
|
||||
for (const auto& it : node_input_map) {
|
||||
if (it.second == 0) {
|
||||
const NodeDef& sink_graph_node = graph_def.node(all_node_names[it.first]);
|
||||
// Sometimes the searching surfaces Arg nodes in function cases that
|
||||
// have no input. This check rejects those.
|
||||
if (sink_graph_node.input_size() == 0) {
|
||||
continue;
|
||||
}
|
||||
*sink_node = sink_graph_node;
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
return errors::InvalidArgument("Failed to find a sink node");
|
||||
}
|
||||
|
||||
Status OptimizeGraph(const GrapplerItem& item, int64 num_workers,
|
||||
GraphDef* output);
|
||||
|
||||
@ -282,7 +249,7 @@ Status OptimizeGraph(const GrapplerItem& item, int64 num_workers,
|
||||
FunctionLibraryDefinition flib(OpRegistry::Global(), item.graph.library());
|
||||
|
||||
NodeDef sink_node;
|
||||
TF_RETURN_IF_ERROR(FindSinkNode(item.graph, &sink_node));
|
||||
TF_RETURN_IF_ERROR(graph_utils::FindSinkNode(item.graph, &sink_node));
|
||||
TF_RETURN_IF_ERROR(
|
||||
RecursivelyHandleOp(sink_node, num_workers, &flib, &graph));
|
||||
*output->mutable_library() = flib.ToProto();
|
||||
|
@ -21,10 +21,6 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
constexpr float kLayerByLayerTreeWeight = 1.0;
|
||||
} // namespace
|
||||
|
||||
// Constructor.
|
||||
BoostedTreesEnsembleResource::BoostedTreesEnsembleResource()
|
||||
: tree_ensemble_(
|
||||
|
@ -129,6 +129,29 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "unbounded_thread_pool",
|
||||
srcs = ["unbounded_thread_pool.cc"],
|
||||
hdrs = ["unbounded_thread_pool.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "unbounded_thread_pool_test",
|
||||
srcs = ["unbounded_thread_pool_test.cc"],
|
||||
deps = [
|
||||
":unbounded_thread_pool",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "window_dataset",
|
||||
srcs = ["window_dataset.cc"],
|
||||
@ -595,6 +618,7 @@ tf_kernel_library(
|
||||
deps = [
|
||||
":dataset_utils",
|
||||
":optional_ops",
|
||||
":unbounded_thread_pool",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:dataset_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
@ -612,6 +636,7 @@ tf_kernel_library(
|
||||
srcs = ["multi_device_iterator_ops.cc"],
|
||||
deps = [
|
||||
":dataset_utils",
|
||||
":unbounded_thread_pool",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -54,6 +54,21 @@ tf_kernel_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "auto_shard_dataset_op",
|
||||
srcs = ["auto_shard_dataset_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:dataset_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler/optimizers/data:auto_shard",
|
||||
"//tensorflow/core/kernels/data:graph_rewrite_dataset",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "group_by_reducer_dataset_op",
|
||||
srcs = ["group_by_reducer_dataset_op.cc"],
|
||||
@ -390,6 +405,7 @@ tf_kernel_library(
|
||||
name = "dataset_kernels",
|
||||
deps = [
|
||||
":assert_next_dataset_op",
|
||||
":auto_shard_dataset_op",
|
||||
":choose_fastest_dataset_op",
|
||||
":csv_dataset_op",
|
||||
":dense_to_sparse_batch_dataset_op",
|
||||
|
@ -0,0 +1,118 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/kernels/data/graph_rewrite_dataset.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
|
||||
constexpr char kOptimizerName[] = "tf_auto_shard";
|
||||
|
||||
class AutoShardDatasetOp : public UnaryDatasetOpKernel {
|
||||
public:
|
||||
explicit AutoShardDatasetOp(OpKernelConstruction* ctx)
|
||||
: UnaryDatasetOpKernel(ctx),
|
||||
graph_def_version_(ctx->graph_def_version()) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
|
||||
}
|
||||
|
||||
protected:
|
||||
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
DatasetBase** output) override {
|
||||
int64 index;
|
||||
int64 num_workers;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_workers", &num_workers));
|
||||
OP_REQUIRES(
|
||||
ctx, num_workers > 0,
|
||||
errors::InvalidArgument("num_workers must be greater than zero."));
|
||||
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "index", &index));
|
||||
OP_REQUIRES(ctx, index >= 0 && index < num_workers,
|
||||
errors::InvalidArgument("index must be between 0 and ",
|
||||
num_workers - 1));
|
||||
|
||||
Dataset* dataset = new Dataset(ctx, input, num_workers, index,
|
||||
output_types_, output_shapes_);
|
||||
const Status s = dataset->Optimize(ctx);
|
||||
|
||||
if (s.ok()) {
|
||||
*output = dataset;
|
||||
} else {
|
||||
dataset->Unref();
|
||||
OP_REQUIRES_OK(ctx, s);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
class Dataset : public GraphRewriteDataset {
|
||||
public:
|
||||
Dataset(OpKernelContext* ctx, const DatasetBase* input,
|
||||
const int64 num_workers, const int64 index,
|
||||
const DataTypeVector& output_types,
|
||||
const std::vector<PartialTensorShape>& output_shapes)
|
||||
: GraphRewriteDataset(ctx, input, output_types, output_shapes),
|
||||
num_workers_(num_workers),
|
||||
index_(index) {}
|
||||
|
||||
string DebugString() const override {
|
||||
return "AutoShardDatasetOp::Dataset";
|
||||
}
|
||||
|
||||
private:
|
||||
bool ShouldOptimizeFunctions() override {
|
||||
// We only want to optimize functions for some particular datasets like
|
||||
// FlatMapDataset, InterleaveDataset etc. So we disable generalized
|
||||
// function optimization and explicitly handle function modifications
|
||||
// for those datasets in the rewrite.
|
||||
return false;
|
||||
}
|
||||
|
||||
RewriterConfig CreateGrapplerRewriteConfig() override {
|
||||
RewriterConfig rewriter_config;
|
||||
rewriter_config.set_fail_on_optimizer_errors(true);
|
||||
rewriter_config.add_optimizers(kOptimizerName);
|
||||
rewriter_config.set_meta_optimizer_iterations(
|
||||
RewriterConfig_NumIterationsType_ONE);
|
||||
auto custom_optimizer = rewriter_config.add_custom_optimizers();
|
||||
custom_optimizer->set_name(kOptimizerName);
|
||||
AttrValue num_workers_attr;
|
||||
num_workers_attr.set_i(num_workers_);
|
||||
(*custom_optimizer->mutable_parameter_map())["num_workers"] =
|
||||
num_workers_attr;
|
||||
|
||||
AttrValue index_attr;
|
||||
index_attr.set_i(index_);
|
||||
(*custom_optimizer->mutable_parameter_map())["index"] = index_attr;
|
||||
|
||||
return rewriter_config;
|
||||
}
|
||||
|
||||
const int64 num_workers_;
|
||||
const int64 index_;
|
||||
};
|
||||
|
||||
const int graph_def_version_;
|
||||
DataTypeVector output_types_;
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ExperimentalAutoShardDataset").Device(DEVICE_CPU),
|
||||
AutoShardDatasetOp);
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
@ -292,10 +292,10 @@ class ChooseFastestDatasetOp : public DatasetOpKernel {
|
||||
for (size_t i = 0, num_inputs = dataset()->inputs_.size();
|
||||
i < num_inputs; ++i) {
|
||||
threads[i].result = absl::make_unique<InvocationResult>();
|
||||
threads[i].thread.reset(ctx->env()->StartThread(
|
||||
{}, strings::StrCat("tf_data_merge_", i),
|
||||
threads[i].thread = ctx->StartThread(
|
||||
strings::StrCat("tf_data_merge_", i),
|
||||
std::bind(&ChooseFastestIterator::RunnerThread, this, ctx,
|
||||
threads[i].result.get(), i)));
|
||||
threads[i].result.get(), i));
|
||||
}
|
||||
return threads;
|
||||
}
|
||||
|
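The hunk above, and the matching hunks later in this commit, replace direct `ctx->env()->StartThread(...)` calls with `ctx->StartThread(...)` on the `IteratorContext`. The actual helper lives in the framework headers and is not part of this diff; the sketch below is only an assumption about its shape, based on the `thread_factory` field this commit threads through `IteratorContext::Params`.

```cpp
// Hedged sketch (not the actual framework code): route thread creation through
// the configured ThreadFactory when one is present, otherwise fall back to a
// raw Env thread, as the old call sites did.
std::unique_ptr<Thread> StartThread(const string& name,
                                    std::function<void()> fn) {
  if (params_.thread_factory) {
    return params_.thread_factory->StartThread(name, std::move(fn));
  }
  return absl::WrapUnique(
      params_.env->StartThread({} /* thread options */, name, std::move(fn)));
}
```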
@ -514,9 +514,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
|
||||
if (!runner_thread_) {
|
||||
auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
|
||||
runner_thread_.reset(ctx->env()->StartThread(
|
||||
{}, "tf_data_map_and_batch",
|
||||
std::bind(&Iterator::RunnerThread, this, ctx_copy)));
|
||||
runner_thread_ = ctx->StartThread(
|
||||
"tf_data_map_and_batch",
|
||||
std::bind(&Iterator::RunnerThread, this, ctx_copy));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -926,8 +926,8 @@ class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
if (!new_ctx) {
|
||||
new_ctx = std::make_shared<IteratorContext>(*ctx);
|
||||
}
|
||||
workers_[i]->threads.emplace_back(ctx->env()->StartThread(
|
||||
{}, strings::StrCat("tf_data_numa_map_and_batch_", i, "_", j),
|
||||
workers_[i]->threads.emplace_back(ctx->StartThread(
|
||||
strings::StrCat("tf_data_numa_map_and_batch_", i, "_", j),
|
||||
[this, new_ctx, i, j]() { WorkerThread(new_ctx, i, j); }));
|
||||
VLOG(3) << "Worker " << i << ", " << j << " successfully started.";
|
||||
}
|
||||
@ -936,9 +936,9 @@ class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
if (!new_ctx) {
|
||||
new_ctx = std::make_shared<IteratorContext>(*ctx);
|
||||
}
|
||||
runner_thread_.reset(ctx->env()->StartThread(
|
||||
{}, "tf_data_numa_map_and_batch",
|
||||
[this, new_ctx] { RunnerThread(new_ctx); }));
|
||||
runner_thread_ =
|
||||
ctx->StartThread("tf_data_numa_map_and_batch",
|
||||
[this, new_ctx] { RunnerThread(new_ctx); });
|
||||
}
|
||||
VLOG(3) << "All workers & runner thread started.";
|
||||
return Status::OK();
|
||||
|
@ -493,8 +493,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
worker_threads_.reserve(dataset()->num_threads());
|
||||
for (size_t i = 0; i < dataset()->num_threads(); ++i) {
|
||||
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
|
||||
worker_threads_.emplace_back(ctx->env()->StartThread(
|
||||
{}, strings::StrCat("tf_data_parallel_interleave_worker_", i),
|
||||
worker_threads_.emplace_back(ctx->StartThread(
|
||||
strings::StrCat("tf_data_parallel_interleave_worker_", i),
|
||||
[this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
|
||||
}
|
||||
}
|
||||
@ -592,8 +592,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
workers_[i].SetInputs(s, std::move(args));
|
||||
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
|
||||
worker_threads_.emplace_back(ctx->env()->StartThread(
|
||||
{}, strings::StrCat("tf_data_parallel_interleave_worker_", i),
|
||||
worker_threads_.push_back(ctx->StartThread(
|
||||
strings::StrCat("tf_data_parallel_interleave_worker_", i),
|
||||
[this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
|
||||
if (i < dataset()->cycle_length_) {
|
||||
interleave_indices_.push_back(i);
|
||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/kernels/data/iterator_ops.h"
|
||||
#include <memory>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/core/common_runtime/graph_runner.h"
|
||||
@ -28,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
||||
#include "tensorflow/core/kernels/data/optional_ops.h"
|
||||
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
@ -51,14 +53,15 @@ const char kIteratorVariantTypeName[] = "tensorflow::Iterator";
|
||||
|
||||
class IteratorResource : public ResourceBase {
|
||||
public:
|
||||
IteratorResource(const DataTypeVector& output_dtypes,
|
||||
IteratorResource(Env* env, const DataTypeVector& output_dtypes,
|
||||
const std::vector<PartialTensorShape>& output_shapes,
|
||||
const int /*unused: graph_def_version*/,
|
||||
std::unique_ptr<DeviceMgr> device_mgr,
|
||||
std::unique_ptr<FunctionLibraryDefinition> flib_def,
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
|
||||
FunctionLibraryRuntime* lib)
|
||||
: device_mgr_(std::move(device_mgr)),
|
||||
: unbounded_thread_pool_(env, "tf_data_iterator_resource"),
|
||||
device_mgr_(std::move(device_mgr)),
|
||||
iterator_state_(std::make_shared<State>(
|
||||
std::move(flib_def), std::move(pflr), lib, nullptr /* iterator */)),
|
||||
output_dtypes_(output_dtypes),
|
||||
@ -77,6 +80,7 @@ class IteratorResource : public ResourceBase {
|
||||
params.function_handle_cache =
|
||||
captured_state->function_handle_cache.get();
|
||||
params.resource_mgr = &captured_state->resource_mgr;
|
||||
params.thread_factory = unbounded_thread_pool_.get_thread_factory();
|
||||
return captured_state->iterator->GetNext(
|
||||
IteratorContext(std::move(params)), out_tensors, end_of_sequence);
|
||||
} else {
|
||||
@ -163,6 +167,8 @@ class IteratorResource : public ResourceBase {
|
||||
params.lib = new_state->lib;
|
||||
params.function_handle_cache = new_state->function_handle_cache.get();
|
||||
params.resource_mgr = &new_state->resource_mgr;
|
||||
params.thread_factory = unbounded_thread_pool_.get_thread_factory();
|
||||
|
||||
TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)),
|
||||
"Iterator", &new_state->iterator));
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -179,6 +185,7 @@ class IteratorResource : public ResourceBase {
|
||||
params.allocator_getter = [device](AllocatorAttributes attrs) {
|
||||
return device->GetAllocator(attrs);
|
||||
};
|
||||
params.thread_factory = unbounded_thread_pool_.get_thread_factory();
|
||||
IteratorContext iter_ctx(std::move(params));
|
||||
TF_RETURN_IF_ERROR(new_state->iterator->Restore(&iter_ctx, reader));
|
||||
}
|
||||
@ -233,6 +240,7 @@ class IteratorResource : public ResourceBase {
|
||||
params.lib = new_state->lib;
|
||||
params.function_handle_cache = new_state->function_handle_cache.get();
|
||||
params.resource_mgr = &new_state->resource_mgr;
|
||||
params.thread_factory = unbounded_thread_pool_.get_thread_factory();
|
||||
TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)),
|
||||
"Iterator", &iterator));
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -284,6 +292,7 @@ class IteratorResource : public ResourceBase {
|
||||
std::unique_ptr<IteratorBase> iterator;
|
||||
};
|
||||
|
||||
UnboundedThreadPool unbounded_thread_pool_;
|
||||
mutex mu_;
|
||||
const std::unique_ptr<DeviceMgr> device_mgr_ GUARDED_BY(mu_);
|
||||
std::shared_ptr<State> iterator_state_ GUARDED_BY(mu_);
|
||||
@ -432,14 +441,14 @@ void IteratorHandleOp::Compute(OpKernelContext* context) LOCKS_EXCLUDED(mu_) {
|
||||
context,
|
||||
mgr->LookupOrCreate<IteratorResource>(
|
||||
cinfo_.container(), cinfo_.name(), &resource,
|
||||
[lib, &device_mgr, &flib_def, &pflr, this](IteratorResource** ret)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
*ret = new IteratorResource(
|
||||
output_dtypes_, output_shapes_, graph_def_version_,
|
||||
std::move(device_mgr), std::move(flib_def),
|
||||
std::move(pflr), lib);
|
||||
return Status::OK();
|
||||
}));
|
||||
[context, lib, &device_mgr, &flib_def, &pflr,
|
||||
this](IteratorResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
*ret = new IteratorResource(
|
||||
context->env(), output_dtypes_, output_shapes_,
|
||||
graph_def_version_, std::move(device_mgr),
|
||||
std::move(flib_def), std::move(pflr), lib);
|
||||
return Status::OK();
|
||||
}));
|
||||
|
||||
Status s = VerifyResource(resource);
|
||||
if (TF_PREDICT_FALSE(!s.ok())) {
|
||||
@ -522,7 +531,7 @@ void AnonymousIteratorHandleOp::Compute(OpKernelContext* context) {
|
||||
existing_resource->Unref();
|
||||
}
|
||||
IteratorResource* new_resource = new IteratorResource(
|
||||
output_dtypes_, output_shapes_, graph_def_version_,
|
||||
context->env(), output_dtypes_, output_shapes_, graph_def_version_,
|
||||
std::move(device_mgr), std::move(flib_def), std::move(pflr), lib);
|
||||
// Create the resource with our chosen name under the resource lookup
|
||||
// mutex to avoid another kernel racily creating a resource with this
|
||||
@ -837,11 +846,12 @@ class OneShotIteratorOp : public AsyncOpKernel {
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx->resource_manager()->LookupOrCreate<IteratorResource>(
|
||||
cinfo->container(), cinfo->name(), iterator,
|
||||
[lib, this, &flib_def, &pflr](IteratorResource** ret)
|
||||
[ctx, lib, this, &flib_def, &pflr](IteratorResource** ret)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
*ret = new IteratorResource(
|
||||
output_dtypes_, output_shapes_, graph_def_version_,
|
||||
nullptr, std::move(flib_def), std::move(pflr), lib);
|
||||
ctx->env(), output_dtypes_, output_shapes_,
|
||||
graph_def_version_, nullptr, std::move(flib_def),
|
||||
std::move(pflr), lib);
|
||||
return Status::OK();
|
||||
}));
|
||||
|
||||
|
@ -140,9 +140,8 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
|
||||
if (!optimize_thread_) {
|
||||
std::shared_ptr<IteratorContext> new_ctx =
|
||||
std::make_shared<IteratorContext>(*ctx);
|
||||
optimize_thread_.reset(ctx->env()->StartThread(
|
||||
{}, "tf_data_model",
|
||||
[this, new_ctx]() { OptimizeThread(new_ctx); }));
|
||||
optimize_thread_ = ctx->StartThread(
|
||||
"tf_data_model", [this, new_ctx]() { OptimizeThread(new_ctx); });
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/resource_op_kernel.h"
|
||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
||||
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/lib/random/random.h"
|
||||
@ -42,14 +43,15 @@ using MultiDeviceIteratorCallback =
|
||||
class MultiDeviceIterator : public ResourceBase {
|
||||
public:
|
||||
MultiDeviceIterator(
|
||||
const DataTypeVector& output_types,
|
||||
Env* env, const DataTypeVector& output_types,
|
||||
const std::vector<PartialTensorShape>& output_shapes,
|
||||
const std::vector<string>& devices,
|
||||
std::unique_ptr<FunctionLibraryDefinition> flib_def,
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
|
||||
FunctionLibraryRuntime* lib,
|
||||
std::unique_ptr<FunctionHandleCache> function_handle_cache)
|
||||
: output_types_(output_types),
|
||||
: unbounded_thread_pool_(env, "tf_data_multi_device_iterator_resource"),
|
||||
output_types_(output_types),
|
||||
output_shapes_(output_shapes),
|
||||
devices_(devices),
|
||||
flib_def_(std::move(flib_def)),
|
||||
@ -82,27 +84,25 @@ class MultiDeviceIterator : public ResourceBase {
|
||||
*incarnation_id = incarnation_id_;
|
||||
|
||||
multi_device_buffer_ = absl::make_unique<MultiDeviceBuffer>(
|
||||
devices_.size(), max_buffer_size, incarnation_id_, std::move(iterator));
|
||||
devices_.size(), max_buffer_size, incarnation_id_, std::move(iterator),
|
||||
this);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void GetNextFromShard(IteratorContext* ctx, int shard_num,
|
||||
void GetNextFromShard(OpKernelContext* ctx, int shard_num,
|
||||
int64 incarnation_id,
|
||||
MultiDeviceIteratorCallback callback) {
|
||||
if (ctx->lib() == lib_) {
|
||||
tf_shared_lock l(mu_);
|
||||
multi_device_buffer_->GetNextFromShard(ctx, shard_num, incarnation_id,
|
||||
std::move(callback));
|
||||
} else {
|
||||
IteratorContext::Params params(ctx);
|
||||
params.lib = lib_;
|
||||
params.function_handle_cache = function_handle_cache_.get();
|
||||
params.resource_mgr = &resource_mgr_;
|
||||
IteratorContext iter_ctx(std::move(params));
|
||||
tf_shared_lock l(mu_);
|
||||
multi_device_buffer_->GetNextFromShard(
|
||||
&iter_ctx, shard_num, incarnation_id, std::move(callback));
|
||||
}
|
||||
tf_shared_lock l(mu_);
|
||||
IteratorContext::Params params(ctx);
|
||||
params.function_library = lib_def_;
|
||||
params.lib = lib_;
|
||||
params.function_handle_cache = function_handle_cache_.get();
|
||||
params.resource_mgr = &resource_mgr_;
|
||||
params.thread_factory = unbounded_thread_pool_.get_thread_factory();
|
||||
|
||||
IteratorContext iter_ctx(std::move(params));
|
||||
multi_device_buffer_->GetNextFromShard(&iter_ctx, shard_num, incarnation_id,
|
||||
std::move(callback));
|
||||
}
|
||||
|
||||
const DataTypeVector& output_types() const { return output_types_; }
|
||||
@ -133,12 +133,14 @@ class MultiDeviceIterator : public ResourceBase {
|
||||
class MultiDeviceBuffer {
|
||||
public:
|
||||
MultiDeviceBuffer(size_t size, int64 max_buffer_size, int64 incarnation_id,
|
||||
std::unique_ptr<IteratorBase> host_iterator)
|
||||
std::unique_ptr<IteratorBase> host_iterator,
|
||||
MultiDeviceIterator* parent)
|
||||
: buffer_(size),
|
||||
size_(size),
|
||||
max_buffer_size_(max_buffer_size),
|
||||
incarnation_id_(incarnation_id),
|
||||
host_iterator_(std::move(host_iterator)) {}
|
||||
host_iterator_(std::move(host_iterator)),
|
||||
parent_(parent) {}
|
||||
|
||||
~MultiDeviceBuffer() {
|
||||
{
|
||||
@ -217,10 +219,12 @@ class MultiDeviceIterator : public ResourceBase {
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
if (!background_thread_) {
|
||||
auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
|
||||
background_thread_ = absl::WrapUnique<Thread>(ctx->env()->StartThread(
|
||||
{}, "tf_data_multi_device_iterator",
|
||||
std::bind(&MultiDeviceIterator::MultiDeviceBuffer::BackgroundThread,
|
||||
this, std::move(ctx_copy))));
|
||||
background_thread_ =
|
||||
parent_->unbounded_thread_pool_.get_thread_factory()->StartThread(
|
||||
"tf_data_multi_device_iterator",
|
||||
std::bind(
|
||||
&MultiDeviceIterator::MultiDeviceBuffer::BackgroundThread,
|
||||
this, std::move(ctx_copy)));
|
||||
}
|
||||
}
|
||||
|
||||
@ -342,8 +346,10 @@ class MultiDeviceIterator : public ResourceBase {
|
||||
const int64 max_buffer_size_;
|
||||
const int64 incarnation_id_;
|
||||
const std::unique_ptr<IteratorBase> host_iterator_;
|
||||
MultiDeviceIterator* const parent_; // Not owned.
|
||||
};
|
||||
|
||||
UnboundedThreadPool unbounded_thread_pool_;
|
||||
mutex mu_;
|
||||
const DataTypeVector output_types_;
|
||||
const std::vector<PartialTensorShape> output_shapes_;
|
||||
@ -413,8 +419,9 @@ class MultiDeviceIteratorHandleOp : public OpKernel {
|
||||
current_id_.fetch_add(1));
|
||||
container_name = "AnonymousMultiDeviceIterator";
|
||||
resource = new MultiDeviceIterator(
|
||||
output_types_, output_shapes_, devices_, std::move(flib_def),
|
||||
std::move(pflr), lib, std::move(function_handle_cache));
|
||||
context->env(), output_types_, output_shapes_, devices_,
|
||||
std::move(flib_def), std::move(pflr), lib,
|
||||
std::move(function_handle_cache));
|
||||
// NOTE: `mgr->Create()` transfers the one reference on `resource` to
|
||||
// `mgr`.
|
||||
OP_REQUIRES_OK(context, mgr->Create<MultiDeviceIterator>(
|
||||
@ -425,11 +432,12 @@ class MultiDeviceIteratorHandleOp : public OpKernel {
|
||||
OP_REQUIRES_OK(context,
|
||||
mgr->LookupOrCreate<MultiDeviceIterator>(
|
||||
container_name, unique_name, &resource,
|
||||
[this, lib, &flib_def, &pflr,
|
||||
[this, context, lib, &flib_def, &pflr,
|
||||
&function_handle_cache](MultiDeviceIterator** ret)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
*ret = new MultiDeviceIterator(
|
||||
output_types_, output_shapes_, devices_,
|
||||
context->env(), output_types_,
|
||||
output_shapes_, devices_,
|
||||
std::move(flib_def), std::move(pflr),
|
||||
lib, std::move(function_handle_cache));
|
||||
return Status::OK();
|
||||
@ -557,11 +565,8 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel {
|
||||
},
|
||||
std::placeholders::_1, std::move(done));
|
||||
|
||||
IteratorContext::Params params(ctx);
|
||||
params.function_library = iterator->function_library();
|
||||
IteratorContext iter_ctx(std::move(params));
|
||||
iterator->GetNextFromShard(&iter_ctx, shard_num, incarnation_id,
|
||||
callback);
|
||||
iterator->GetNextFromShard(ctx, shard_num, incarnation_id,
|
||||
std::move(callback));
|
||||
iterator->Unref();
|
||||
},
|
||||
std::move(done)));
|
||||
|
@ -517,17 +517,15 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
|
||||
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
|
||||
if (!current_elements_manager_) {
|
||||
auto new_ctx = std::make_shared<IteratorContext>(*ctx);
|
||||
current_elements_manager_ =
|
||||
absl::WrapUnique<Thread>(ctx->env()->StartThread(
|
||||
{}, "tf_data_parallel_interleave_current",
|
||||
[this, new_ctx]() { CurrentElementsManager(new_ctx); }));
|
||||
current_elements_manager_ = ctx->StartThread(
|
||||
"tf_data_parallel_interleave_current",
|
||||
[this, new_ctx]() { CurrentElementsManager(new_ctx); });
|
||||
}
|
||||
if (!future_elements_manager_) {
|
||||
auto new_ctx = std::make_shared<IteratorContext>(*ctx);
|
||||
future_elements_manager_ =
|
||||
absl::WrapUnique<Thread>(ctx->env()->StartThread(
|
||||
{}, "tf_data_parallel_interleave_future",
|
||||
[this, new_ctx]() { FutureElementsManager(new_ctx); }));
|
||||
future_elements_manager_ = ctx->StartThread(
|
||||
"tf_data_parallel_interleave_future",
|
||||
[this, new_ctx]() { FutureElementsManager(new_ctx); });
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -191,9 +191,9 @@ class ParallelMapIterator : public DatasetBaseIterator {
|
||||
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
|
||||
if (!runner_thread_) {
|
||||
auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
|
||||
runner_thread_.reset(ctx->env()->StartThread(
|
||||
{}, "tf_data_parallel_map",
|
||||
std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy)));
|
||||
runner_thread_ = ctx->StartThread(
|
||||
"tf_data_parallel_map",
|
||||
std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -269,9 +269,8 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
if (!prefetch_thread_) {
|
||||
std::shared_ptr<IteratorContext> new_ctx =
|
||||
std::make_shared<IteratorContext>(*ctx);
|
||||
prefetch_thread_ = absl::WrapUnique<Thread>(ctx->env()->StartThread(
|
||||
{}, "tf_data_prefetch",
|
||||
[this, new_ctx]() { PrefetchThread(new_ctx); }));
|
||||
prefetch_thread_ = ctx->StartThread(
|
||||
"tf_data_prefetch", [this, new_ctx]() { PrefetchThread(new_ctx); });
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
tensorflow/core/kernels/data/unbounded_thread_pool.cc (new file, 156 lines)
@ -0,0 +1,156 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"

#include "absl/memory/memory.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {
namespace data {

// A lightweight wrapper for creating logical threads in an
// `UnboundedThreadPool` that can be shared (e.g.) in an `IteratorContext`.
class UnboundedThreadPool::LogicalThreadFactory : public ThreadFactory {
 public:
  explicit LogicalThreadFactory(UnboundedThreadPool* pool) : pool_(pool) {}

  std::unique_ptr<Thread> StartThread(const string& name,
                                      std::function<void()> fn) override {
    return pool_->RunOnPooledThread(std::move(fn));
  }

 private:
  UnboundedThreadPool* const pool_;  // Not owned.
};

// A logical implementation of the `tensorflow::Thread` interface that uses
// physical threads in an `UnboundedThreadPool` to perform the work.
//
// NOTE: This object represents a logical thread of control that may be mapped
// onto the same physical thread as other work items that are submitted to the
// same `UnboundedThreadPool`.
class UnboundedThreadPool::LogicalThreadWrapper : public Thread {
 public:
  explicit LogicalThreadWrapper(std::shared_ptr<Notification> join_notification)
      : join_notification_(std::move(join_notification)) {}

  ~LogicalThreadWrapper() override {
    // NOTE: The `Thread` destructor is expected to "join" the created thread,
    // but the physical thread may continue to execute after the work for this
    // thread is complete. We simulate this by waiting on a notification that
    // `PooledThreadFunc` will notify when the thread's work function is
    // complete.
    join_notification_->WaitForNotification();
  }

 private:
  std::shared_ptr<Notification> join_notification_;
};

UnboundedThreadPool::~UnboundedThreadPool() {
  {
    mutex_lock l(work_queue_mu_);
    // Wake up all `PooledThreadFunc` threads and cause them to terminate
    // before joining them when `thread_pool_` is cleared.
    cancelled_ = true;
    work_queue_cv_.notify_all();
    if (!work_queue_.empty()) {
      LOG(ERROR) << "UnboundedThreadPool named \"" << thread_name_ << "\" was "
                 << "deleted with pending work in its queue. This may indicate "
                 << "a potential use-after-free bug.";
    }
  }

  {
    mutex_lock l(thread_pool_mu_);
    // Clear the list of pooled threads, which will eventually terminate due to
    // the previous notification.
    //
    // NOTE: It is safe to do this while holding `thread_pool_mu_`, because
    // no subsequent calls to `this->StartThread()` should be issued after the
    // destructor starts.
    thread_pool_.clear();
  }
}

std::shared_ptr<ThreadFactory> UnboundedThreadPool::get_thread_factory() {
  return std::make_shared<LogicalThreadFactory>(this);
}

size_t UnboundedThreadPool::size() {
  tf_shared_lock l(thread_pool_mu_);
  return thread_pool_.size();
}

std::unique_ptr<Thread> UnboundedThreadPool::RunOnPooledThread(
    std::function<void()> fn) {
  auto join_notification = std::make_shared<Notification>();
  bool all_threads_busy;
  {
    // Enqueue a work item for the new thread's function, and wake up a
    // pooled thread to process it.
    mutex_lock l(work_queue_mu_);
    work_queue_.push_back({std::move(fn), join_notification});
    work_queue_cv_.notify_one();
    // NOTE: The queue may be non-empty, so we must account for queued work
    // when considering how many threads are free.
    all_threads_busy = work_queue_.size() > num_idle_threads_;
  }

  if (all_threads_busy) {
    // Spawn a new physical thread to process the given function.
    // NOTE: `PooledThreadFunc` will eventually increment `num_idle_threads_`
    // at the beginning of its work loop.
    Thread* new_thread = env_->StartThread(
        {}, thread_name_,
        std::bind(&UnboundedThreadPool::PooledThreadFunc, this));

    mutex_lock l(thread_pool_mu_);
    thread_pool_.emplace_back(new_thread);
  }

  return absl::make_unique<LogicalThreadWrapper>(std::move(join_notification));
}

void UnboundedThreadPool::PooledThreadFunc() {
  while (true) {
    WorkItem work_item;
    {
      mutex_lock l(work_queue_mu_);
      ++num_idle_threads_;
      while (!cancelled_ && work_queue_.empty()) {
        // Wait for a new work function to be submitted, or the pool to be
        // destroyed.
        work_queue_cv_.wait(l);
      }
      if (cancelled_) {
        return;
      }
      work_item = std::move(work_queue_.front());
      work_queue_.pop_front();
      --num_idle_threads_;
    }

    work_item.work_function();

    // Notify any thread that has "joined" the logical thread for this work
    // item.
    work_item.done_notification->Notify();
  }
}

}  // namespace data
}  // namespace tensorflow
tensorflow/core/kernels/data/unbounded_thread_pool.h (new file, 77 lines)
@ -0,0 +1,77 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_UNBOUNDED_THREAD_POOL_H_
#define TENSORFLOW_CORE_KERNELS_DATA_UNBOUNDED_THREAD_POOL_H_

#include <deque>
#include <memory>
#include <vector>

#include "tensorflow/core/framework/thread_factory.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {
namespace data {

// An `UnboundedThreadPool` provides a mechanism for temporally multiplexing a
// potentially large number of "logical" threads onto a smaller number of
// "physical" threads. The multiplexing is achieved by maintaining an internal
// pool of long-running "physical" threads that are used to execute the
// "logical" threads. Like a regular thread, a "logical" thread may block on
// other threads, and the size of the pool will increase to ensure that
// progress is made. This mechanism is recommended in situations where
// short-lived threads are created repeatedly, to avoid the overhead and
// memory fragmentation that can result from excessive thread creation.
class UnboundedThreadPool {
 public:
  UnboundedThreadPool(Env* env, const string& thread_name)
      : env_(env), thread_name_(thread_name) {}
  ~UnboundedThreadPool();

  // Returns an implementation of `ThreadFactory` that can be used to create
  // logical threads in this pool.
  std::shared_ptr<ThreadFactory> get_thread_factory();

  // Returns the current number of threads in this pool.
  size_t size();

 private:
  class LogicalThreadFactory;
  class LogicalThreadWrapper;
  struct WorkItem {
    std::function<void()> work_function;
    std::shared_ptr<Notification> done_notification;
  };

  std::unique_ptr<Thread> RunOnPooledThread(std::function<void()> fn);
  void PooledThreadFunc();

  Env* const env_;  // Not owned.
  const string thread_name_;
  mutex work_queue_mu_;
  condition_variable work_queue_cv_ GUARDED_BY(work_queue_mu_);
  size_t num_idle_threads_ GUARDED_BY(work_queue_mu_) = 0;
  bool cancelled_ GUARDED_BY(work_queue_mu_) = false;
  std::deque<WorkItem> work_queue_ GUARDED_BY(work_queue_mu_);
  mutex thread_pool_mu_;
  std::vector<std::unique_ptr<Thread>> thread_pool_ GUARDED_BY(thread_pool_mu_);
};

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_DATA_UNBOUNDED_THREAD_POOL_H_
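The pool declared above is wired into the tf.data runtime through the `thread_factory` field on `IteratorContext::Params`, as the `iterator_ops.cc` and `multi_device_iterator_ops.cc` hunks in this commit show. A minimal sketch of that wiring follows; the variable names are illustrative and `ctx` stands for an `OpKernelContext` available at the call site.

```cpp
// Hedged sketch: expose an UnboundedThreadPool to downstream iterators via the
// IteratorContext, so that their worker "threads" become logical threads in
// the pool rather than fresh Env threads.
UnboundedThreadPool unbounded_thread_pool(ctx->env(),
                                          "tf_data_example_resource");

IteratorContext::Params params(ctx);
params.thread_factory = unbounded_thread_pool.get_thread_factory();
IteratorContext iter_ctx(std::move(params));

// Any iterator driven with `iter_ctx` that starts a worker thread will now be
// multiplexed onto the pool's physical threads.
```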
tensorflow/core/kernels/data/unbounded_thread_pool_test.cc (new file, 143 lines)
@ -0,0 +1,143 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
|
||||
|
||||
#include "tensorflow/core/lib/core/blocking_counter.h"
|
||||
#include "tensorflow/core/lib/random/random.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
|
||||
TEST(UnboundedThreadPool, SingleThread) {
|
||||
UnboundedThreadPool pool(Env::Default(), "test");
|
||||
auto thread_factory = pool.get_thread_factory();
|
||||
|
||||
// Create a thread that updates a variable, and ensure that it runs to
|
||||
// completion.
|
||||
std::atomic<int> i(0);
|
||||
auto thread = thread_factory->StartThread("", [&i]() { ++i; });
|
||||
thread.reset();
|
||||
|
||||
EXPECT_GE(pool.size(), 1);
|
||||
EXPECT_EQ(1, i);
|
||||
}
|
||||
|
||||
TEST(UnboundedThreadPool, MultipleThreads) {
|
||||
UnboundedThreadPool pool(Env::Default(), "test");
|
||||
auto thread_factory = pool.get_thread_factory();
|
||||
|
||||
// Create ten threads that update a variable, and ensure that they all run
|
||||
// to completion.
|
||||
std::vector<std::unique_ptr<Thread>> threads;
|
||||
const int kNumThreadsToCreate = 10;
|
||||
std::atomic<int> i(0);
|
||||
for (int j = 0; j < kNumThreadsToCreate; ++j) {
|
||||
threads.push_back(thread_factory->StartThread("", [&i]() { ++i; }));
|
||||
}
|
||||
threads.clear();
|
||||
|
||||
EXPECT_GE(pool.size(), 1);
|
||||
EXPECT_EQ(i, kNumThreadsToCreate);
|
||||
}
|
||||
|
||||
TEST(UnboundedThreadPool, MultipleThreadsSleepingRandomly) {
|
||||
UnboundedThreadPool pool(Env::Default(), "test");
|
||||
auto thread_factory = pool.get_thread_factory();
|
||||
|
||||
// Create 1000 threads that sleep for a random period of time then update a
|
||||
// variable, and ensure that they all run to completion.
|
||||
std::vector<std::unique_ptr<Thread>> threads;
|
||||
const int kNumThreadsToCreate = 1000;
|
||||
std::atomic<int> i(0);
|
||||
for (int j = 0; j < kNumThreadsToCreate; ++j) {
|
||||
threads.push_back(thread_factory->StartThread("", [&i]() {
|
||||
Env::Default()->SleepForMicroseconds(random::New64() % 10);
|
||||
++i;
|
||||
}));
|
||||
}
|
||||
threads.clear();
|
||||
|
||||
EXPECT_GE(pool.size(), 1);
|
||||
EXPECT_EQ(i, kNumThreadsToCreate);
|
||||
}
|
||||
|
||||
TEST(UnboundedThreadPool, ConcurrentThreadCreation) {
|
||||
UnboundedThreadPool pool(Env::Default(), "test");
|
||||
auto thread_factory = pool.get_thread_factory();
|
||||
|
||||
// Create ten threads that each create ten threads that update a variable, and
|
||||
// ensure that they all run to completion.
|
||||
std::vector<std::unique_ptr<Thread>> threads;
|
||||
const int kNumThreadsToCreate = 10;
|
||||
std::atomic<int> i(0);
|
||||
for (int j = 0; j < kNumThreadsToCreate; ++j) {
|
||||
threads.push_back(thread_factory->StartThread("", [&i, thread_factory]() {
|
||||
std::vector<std::unique_ptr<Thread>> nested_threads;
|
||||
for (int k = 0; k < kNumThreadsToCreate; ++k) {
|
||||
nested_threads.push_back(
|
||||
thread_factory->StartThread("", [&i]() { ++i; }));
|
||||
}
|
||||
nested_threads.clear();
|
||||
}));
|
||||
}
|
||||
threads.clear();
|
||||
|
||||
EXPECT_GE(pool.size(), 1);
|
||||
EXPECT_EQ(i, kNumThreadsToCreate * kNumThreadsToCreate);
|
||||
}
|
||||
|
||||
TEST(UnboundedThreadPool, MultipleBlockingThreads) {
|
||||
UnboundedThreadPool pool(Env::Default(), "test");
|
||||
auto thread_factory = pool.get_thread_factory();
|
||||
|
||||
std::vector<std::unique_ptr<Thread>> threads;
|
||||
|
||||
// Create multiple waves (with increasing sizes) of threads that all block
// before returning, and ensure that we create the appropriate number of
// threads and terminate correctly.
std::vector<int> round_sizes = {5, 10, 15, 20};
|
||||
|
||||
for (const int round_size : round_sizes) {
|
||||
Notification n;
|
||||
BlockingCounter bc(round_size);
|
||||
for (int j = 0; j < round_size; ++j) {
|
||||
threads.push_back(thread_factory->StartThread("", [&bc, &n]() {
|
||||
bc.DecrementCount();
|
||||
// Block until `n` is notified, so that all threads in this round must have
// been created before the first one completes.
n.WaitForNotification();
|
||||
}));
|
||||
}
|
||||
|
||||
// Wait until all threads have started. Since the number of threads in each
|
||||
// wave is increasing, we should have at least that number of threads in the
|
||||
// pool.
|
||||
bc.Wait();
|
||||
// NOTE: There is a benign race between a new round starting and the
|
||||
// physical threads from the previous round returning to the pool, so we may
|
||||
// create more threads than the round_size.
|
||||
EXPECT_GE(pool.size(), round_size);
|
||||
n.Notify();
|
||||
threads.clear();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
@ -133,6 +133,17 @@ int VarintLength(uint64_t v) {
|
||||
return len;
|
||||
}
|
||||
|
||||
const char* GetVarint32Ptr(const char* p, const char* limit, uint32* value) {
|
||||
if (p < limit) {
|
||||
uint32 result = *(reinterpret_cast<const unsigned char*>(p));
|
||||
if ((result & 128) == 0) {
|
||||
*value = result;
|
||||
return p + 1;
|
||||
}
|
||||
}
|
||||
return GetVarint32PtrFallback(p, limit, value);
|
||||
}
|
||||
|
||||
const char* GetVarint32PtrFallback(const char* p, const char* limit,
|
||||
uint32* value) {
|
||||
uint32 result = 0;
|
||||
|
@ -55,18 +55,8 @@ extern const char* GetVarint64Ptr(const char* p, const char* limit, uint64* v);
|
||||
// Internal routine for use by fallback path of GetVarint32Ptr
|
||||
extern const char* GetVarint32PtrFallback(const char* p, const char* limit,
|
||||
uint32* value);
|
||||
inline const char* GetVarint32Ptr(const char* p, const char* limit,
|
||||
uint32* value) {
|
||||
if (p < limit) {
|
||||
uint32 result = *(reinterpret_cast<const unsigned char*>(p));
|
||||
if ((result & 128) == 0) {
|
||||
*value = result;
|
||||
return p + 1;
|
||||
}
|
||||
}
|
||||
return GetVarint32PtrFallback(p, limit, value);
|
||||
}
|
||||
|
||||
extern const char* GetVarint32Ptr(const char* p, const char* limit,
|
||||
uint32* value);
|
||||
extern char* EncodeVarint32(char* dst, uint32 v);
|
||||
extern char* EncodeVarint64(char* dst, uint64 v);
|
||||
|
||||
|
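The change above moves the inline `GetVarint32Ptr` definition out of `coding.h` and into `coding.cc`, keeping the common single-byte fast path there. For reference, these helpers encode unsigned integers in a base-128 format where the high bit of each byte marks continuation. A small usage sketch, assuming the `tensorflow::core` namespace that `coding.h` uses; the buffer size and values are illustrative.

```cpp
// Hedged usage sketch of the varint helpers declared above.
char buf[5];
char* end = tensorflow::core::EncodeVarint32(buf, 300);  // 300 encodes as two bytes.

tensorflow::uint32 value = 0;
const char* next = tensorflow::core::GetVarint32Ptr(buf, end, &value);
// value == 300 and `next` points just past the bytes consumed; values below
// 128 take the single-byte fast path shown in the coding.cc hunk above.
```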
@ -22251,6 +22251,89 @@ op {
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "EnqueueTPUEmbeddingSparseBatch"
|
||||
input_arg {
|
||||
name: "sample_indices"
|
||||
type_attr: "T1"
|
||||
number_attr: "N"
|
||||
}
|
||||
input_arg {
|
||||
name: "embedding_indices"
|
||||
type_attr: "T2"
|
||||
number_attr: "N"
|
||||
}
|
||||
input_arg {
|
||||
name: "aggregation_weights"
|
||||
type_attr: "T3"
|
||||
number_attr: "N"
|
||||
}
|
||||
input_arg {
|
||||
name: "mode_override"
|
||||
type: DT_STRING
|
||||
}
|
||||
attr {
|
||||
name: "T1"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "T2"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "T3"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "N"
|
||||
type: "int"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "device_ordinal"
|
||||
type: "int"
|
||||
default_value {
|
||||
i: -1
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "combiners"
|
||||
type: "list(string)"
|
||||
default_value {
|
||||
list {
|
||||
}
|
||||
}
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "EnqueueTPUEmbeddingSparseTensorBatch"
|
||||
input_arg {
|
||||
@ -22299,6 +22382,93 @@ op {
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "EnqueueTPUEmbeddingSparseTensorBatch"
|
||||
input_arg {
|
||||
name: "sample_indices"
|
||||
type_attr: "T1"
|
||||
number_attr: "N"
|
||||
}
|
||||
input_arg {
|
||||
name: "embedding_indices"
|
||||
type_attr: "T2"
|
||||
number_attr: "N"
|
||||
}
|
||||
input_arg {
|
||||
name: "aggregation_weights"
|
||||
type_attr: "T3"
|
||||
number_attr: "N"
|
||||
}
|
||||
input_arg {
|
||||
name: "mode_override"
|
||||
type: DT_STRING
|
||||
}
|
||||
attr {
|
||||
name: "T1"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "T2"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "T3"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "N"
|
||||
type: "int"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "device_ordinal"
|
||||
type: "int"
|
||||
default_value {
|
||||
i: -1
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "combiners"
|
||||
type: "list(string)"
|
||||
default_value {
|
||||
list {
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "table_ids"
|
||||
type: "list(int)"
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "EnsureShape"
|
||||
input_arg {
|
||||
@ -22814,6 +22984,37 @@ op {
|
||||
minimum: 1
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ExperimentalAutoShardDataset"
|
||||
input_arg {
|
||||
name: "input_dataset"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
input_arg {
|
||||
name: "num_workers"
|
||||
type: DT_INT64
|
||||
}
|
||||
input_arg {
|
||||
name: "index"
|
||||
type: DT_INT64
|
||||
}
|
||||
output_arg {
|
||||
name: "handle"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
attr {
|
||||
name: "output_types"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "output_shapes"
|
||||
type: "list(shape)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ExperimentalBytesProducedStatsDataset"
|
||||
input_arg {
|
||||
|
@ -17,6 +17,15 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER_OP("ExperimentalAutoShardDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("num_workers: int64")
|
||||
.Input("index: int64")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalBytesProducedStatsDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("tag: string")
|
||||
|
@ -10421,23 +10421,62 @@ op {
|
||||
name: "EnqueueTPUEmbeddingSparseBatch"
|
||||
input_arg {
|
||||
name: "sample_indices"
|
||||
type: DT_INT32
|
||||
type_attr: "T1"
|
||||
number_attr: "N"
|
||||
}
|
||||
input_arg {
|
||||
name: "embedding_indices"
|
||||
type: DT_INT32
|
||||
type_attr: "T2"
|
||||
number_attr: "N"
|
||||
}
|
||||
input_arg {
|
||||
name: "aggregation_weights"
|
||||
type: DT_FLOAT
|
||||
type_attr: "T3"
|
||||
number_attr: "N"
|
||||
}
|
||||
input_arg {
|
||||
name: "mode_override"
|
||||
type: DT_STRING
|
||||
}
|
||||
attr {
|
||||
name: "T1"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "T2"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "T3"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "N"
|
||||
type: "int"
|
||||
@ -10465,23 +10504,62 @@ op {
|
||||
name: "EnqueueTPUEmbeddingSparseTensorBatch"
|
||||
input_arg {
|
||||
name: "sample_indices"
|
||||
type: DT_INT32
|
||||
type_attr: "T1"
|
||||
number_attr: "N"
|
||||
}
|
||||
input_arg {
|
||||
name: "embedding_indices"
|
||||
type: DT_INT32
|
||||
type_attr: "T2"
|
||||
number_attr: "N"
|
||||
}
|
||||
input_arg {
|
||||
name: "aggregation_weights"
|
||||
type: DT_FLOAT
|
||||
type_attr: "T3"
|
||||
number_attr: "N"
|
||||
}
|
||||
input_arg {
|
||||
name: "mode_override"
|
||||
type: DT_STRING
|
||||
}
|
||||
attr {
|
||||
name: "T1"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "T2"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "T3"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "N"
|
||||
type: "int"
|
||||
@ -10806,6 +10884,37 @@ op {
|
||||
minimum: 1
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ExperimentalAutoShardDataset"
|
||||
input_arg {
|
||||
name: "input_dataset"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
input_arg {
|
||||
name: "num_workers"
|
||||
type: DT_INT64
|
||||
}
|
||||
input_arg {
|
||||
name: "index"
|
||||
type: DT_INT64
|
||||
}
|
||||
output_arg {
|
||||
name: "handle"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
attr {
|
||||
name: "output_types"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "output_shapes"
|
||||
type: "list(shape)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ExperimentalBytesProducedStatsDataset"
|
||||
input_arg {
|
||||
|
@ -393,10 +393,13 @@ REGISTER_OP("EnqueueTPUEmbeddingIntegerBatch")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
REGISTER_OP("EnqueueTPUEmbeddingSparseBatch")
|
||||
.Input("sample_indices: N * int32")
|
||||
.Input("embedding_indices: N * int32")
|
||||
.Input("aggregation_weights: N * float32")
|
||||
.Input("sample_indices: N * T1")
|
||||
.Input("embedding_indices: N * T2")
|
||||
.Input("aggregation_weights: N * T3")
|
||||
.Input("mode_override: string")
|
||||
.Attr("T1: {int32,int64} = DT_INT32")
|
||||
.Attr("T2: {int32,int64} = DT_INT32")
|
||||
.Attr("T3: {float32,float64} = DT_FLOAT")
|
||||
.Attr("N: int >= 1")
|
||||
.Attr("device_ordinal: int = -1")
|
||||
.Attr("combiners: list(string) = []")
|
||||
@ -416,10 +419,13 @@ REGISTER_OP("EnqueueTPUEmbeddingSparseBatch")
|
||||
});
|
||||
|
||||
REGISTER_OP("EnqueueTPUEmbeddingSparseTensorBatch")
|
||||
.Input("sample_indices: N * int32")
|
||||
.Input("embedding_indices: N * int32")
|
||||
.Input("aggregation_weights: N * float32")
|
||||
.Input("sample_indices: N * T1")
|
||||
.Input("embedding_indices: N * T2")
|
||||
.Input("aggregation_weights: N * T3")
|
||||
.Input("mode_override: string")
|
||||
.Attr("T1: {int32,int64} = DT_INT32")
|
||||
.Attr("T2: {int32,int64} = DT_INT32")
|
||||
.Attr("T3: {float32,float64} = DT_FLOAT")
|
||||
.Attr("N: int >= 1")
|
||||
.Attr("device_ordinal: int = -1")
|
||||
.Attr("combiners: list(string) = []")
|
||||
|
@ -70,8 +70,6 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:platform_base",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/time",
|
||||
"@protobuf_archive//:protobuf_headers",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -16,9 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_UTIL_PROTO_PROTO_UTILS_H_
|
||||
#define TENSORFLOW_CORE_UTIL_PROTO_PROTO_UTILS_H_
|
||||
|
||||
#include "google/protobuf/duration.pb.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/time/time.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
@ -60,20 +58,6 @@ class StringErrorCollector : public protobuf::io::ErrorCollector {
|
||||
const int index_offset_;
|
||||
};
|
||||
|
||||
// Converts an absl::Duration to a google::protobuf::Duration.
|
||||
inline google::protobuf::Duration ToDurationProto(absl::Duration duration) {
|
||||
google::protobuf::Duration proto;
|
||||
proto.set_seconds(absl::IDivDuration(duration, absl::Seconds(1), &duration));
|
||||
proto.set_nanos(
|
||||
absl::IDivDuration(duration, absl::Nanoseconds(1), &duration));
|
||||
return proto;
|
||||
}
|
||||
|
||||
// Converts a google::protobuf::Duration to an absl::Duration.
|
||||
inline absl::Duration FromDurationProto(google::protobuf::Duration proto) {
|
||||
return absl::Seconds(proto.seconds()) + absl::Nanoseconds(proto.nanos());
|
||||
}
|
||||
|
||||
} // namespace proto_utils
|
||||
} // namespace tensorflow
|
||||
|
||||
|
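The two inline helpers deleted from `proto_utils.h` above convert between `absl::Duration` and `google::protobuf::Duration` by splitting the value into whole seconds plus leftover nanoseconds. A short worked sketch of that conversion, purely illustrative of the arithmetic (the helpers are being removed from this header in this commit, so this is not a claim about where they live afterwards):

```cpp
// Hedged sketch: the arithmetic performed by ToDurationProto() /
// FromDurationProto() for a 1.5 second duration.
absl::Duration d = absl::Milliseconds(1500);
// absl::IDivDuration(d, absl::Seconds(1), &d) returns 1 and leaves d == 500ms.
// absl::IDivDuration(d, absl::Nanoseconds(1), &d) then returns 500000000.
// The resulting proto therefore holds seconds == 1 and nanos == 500000000, and
// FromDurationProto() reassembles absl::Seconds(1) + absl::Nanoseconds(500000000).
```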
@ -7570,6 +7570,21 @@ func OptionalGetValue(scope *Scope, optional tf.Output, output_types []tf.DataTy
|
||||
return components
|
||||
}
|
||||
|
||||
// Returns true if and only if the given Optional variant has a value.
|
||||
func OptionalHasValue(scope *Scope, optional tf.Output) (has_value tf.Output) {
|
||||
if scope.Err() != nil {
|
||||
return
|
||||
}
|
||||
opspec := tf.OpSpec{
|
||||
Type: "OptionalHasValue",
|
||||
Input: []tf.Input{
|
||||
optional,
|
||||
},
|
||||
}
|
||||
op := scope.AddOperation(opspec)
|
||||
return op.Output(0)
|
||||
}
|
||||
|
||||
// Outputs a tensor containing the reduction across all input tensors.
|
||||
//
|
||||
// Outputs a tensor containing the reduction across all input tensors passed to ops
|
||||
@ -10562,6 +10577,38 @@ func FusedBatchNorm(scope *Scope, x tf.Output, scale tf.Output, offset tf.Output
|
||||
return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
|
||||
}
|
||||
|
||||
// Creates a dataset that shards the input dataset.
|
||||
//
|
||||
// Creates a dataset that shards the input dataset by num_workers, returning a
|
||||
// sharded dataset for the index-th worker. This attempts to automatically shard
|
||||
// a dataset by examining the Dataset graph and inserting a shard op before the
|
||||
// inputs to a reader Dataset (e.g. CSVDataset, TFRecordDataset).
|
||||
//
|
||||
// This dataset will throw a NotFound error if we cannot shard the dataset
|
||||
// automatically.
|
||||
//
|
||||
// Arguments:
|
||||
// input_dataset: A variant tensor representing the input dataset.
|
||||
// num_workers: A scalar representing the number of workers to distribute this dataset across.
|
||||
// index: A scalar representing the index of the current worker out of num_workers.
|
||||
//
|
||||
//
|
||||
func ExperimentalAutoShardDataset(scope *Scope, input_dataset tf.Output, num_workers tf.Output, index tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
|
||||
if scope.Err() != nil {
|
||||
return
|
||||
}
|
||||
attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
|
||||
opspec := tf.OpSpec{
|
||||
Type: "ExperimentalAutoShardDataset",
|
||||
Input: []tf.Input{
|
||||
input_dataset, num_workers, index,
|
||||
},
|
||||
Attrs: attrs,
|
||||
}
|
||||
op := scope.AddOperation(opspec)
|
||||
return op.Output(0)
|
||||
}
|
||||
|
||||
// RandomStandardNormalAttr is an optional argument to RandomStandardNormal.
|
||||
type RandomStandardNormalAttr func(optionalAttr)
|
||||
|
||||
@ -38809,18 +38856,3 @@ func OptionalNone(scope *Scope) (optional tf.Output) {
|
||||
op := scope.AddOperation(opspec)
|
||||
return op.Output(0)
|
||||
}
|
||||
|
||||
// Returns true if and only if the given Optional variant has a value.
|
||||
func OptionalHasValue(scope *Scope, optional tf.Output) (has_value tf.Output) {
|
||||
if scope.Err() != nil {
|
||||
return
|
||||
}
|
||||
opspec := tf.OpSpec{
|
||||
Type: "OptionalHasValue",
|
||||
Input: []tf.Input{
|
||||
optional,
|
||||
},
|
||||
}
|
||||
op := scope.AddOperation(opspec)
|
||||
return op.Output(0)
|
||||
}
|
||||
|
@ -33,13 +33,12 @@ tf_cuda_library(
|
||||
"//tensorflow:android": [],
|
||||
"//conditions:default": ["."],
|
||||
}),
|
||||
deps = [
|
||||
"//tensorflow/c:c_api",
|
||||
] + select({
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/core:all_kernels",
|
||||
"//tensorflow/core:direct_session",
|
||||
"//tensorflow/core:ops",
|
||||
|
@ -33,7 +33,7 @@ GEMMLOWP_URL="https://github.com/google/gemmlowp/archive/719139ce755a0f31cbf1c37
|
||||
FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/1f5eae5d6a135ff6811724f6c57f911d1f46bb15.tar.gz"
|
||||
CMSIS_URL="https://github.com/ARM-software/CMSIS_5/archive/5.4.0.zip"
|
||||
STM32_BARE_LIB_URL="https://github.com/google/stm32_bare_lib/archive/c07d611fb0af58450c5a3e0ab4d52b47f99bc82d.zip"
|
||||
10_LIB_URL="https://github.com/sifive/freedom-e-sdk/archive/baeeb8fd497a99b3c141d7494309ec2e64f19bdf.zip"
|
||||
SIFIVE_FE310_LIB_URL="https://github.com/sifive/freedom-e-sdk/archive/baeeb8fd497a99b3c141d7494309ec2e64f19bdf.zip"
|
||||
RISCV_TOOLCHAIN_URL="https://static.dev.sifive.com/dev-tools/riscv64-unknown-elf-gcc-20181030-x86_64-linux-ubuntu14.tar.gz"
|
||||
AM_SDK_URL="http://s3.asia.ambiqmicro.com/downloads/AmbiqSuite-Rel2.0.0.zip"
|
||||
AP3_URL="https://github.com/AmbiqMicro/TFLiteMicro_Apollo3/archive/dfbcef9a57276c087d95aab7cb234f1d4c9eaaba.zip"
|
||||
@ -101,7 +101,7 @@ patch_am_sdk() {
|
||||
|
||||
# Workaround for bug in 2.0.0 SDK, remove once that's fixed.
|
||||
sed -i -e 's/#ifndef AM_HAL_GPIO_H/#ifdef __cplusplus\nextern "C" {\n#endif\n#ifndef AM_HAL_GPIO_H/g' ${am_dir}/mcu/apollo3/hal/am_hal_gpio.h
|
||||
|
||||
|
||||
echo "Finished preparing Apollo3 files"
|
||||
}
|
||||
|
||||
@ -109,8 +109,8 @@ patch_kissfft() {
|
||||
sed -i -E "s@#ifdef FIXED_POINT@// Patched automatically by download_dependencies.sh so default is 16 bit.\n#ifndef FIXED_POINT\n#define FIXED_POINT (16)\n#endif\n// End patch.\n\n#ifdef FIXED_POINT@g" tensorflow/lite/experimental/micro/tools/make/downloads/kissfft/kiss_fft.h
|
||||
sed -i -E "s@#define KISS_FFT_MALLOC malloc@#define KISS_FFT_MALLOC(X) (void*)(0) /* Patched. */@g" tensorflow/lite/experimental/micro/tools/make/downloads/kissfft/kiss_fft.h
|
||||
sed -i -E "s@#define KISS_FFT_FREE free@#define KISS_FFT_FREE(X) /* Patched. */@g" tensorflow/lite/experimental/micro/tools/make/downloads/kissfft/kiss_fft.h
|
||||
sed -i -E "s@(fprintf.*\);)@/* \1 */@g" tensorflow/lite/experimental/micro/tools/make/downloads/kissfft/tools/kiss_fftr.c
|
||||
sed -i -E "s@(exit.*\);)@return; /* \1 */@g" tensorflow/lite/experimental/micro/tools/make/downloads/kissfft/tools/kiss_fftr.c
|
||||
sed -ir -E "s@(fprintf.*\);)@/* \1 */@g" tensorflow/lite/experimental/micro/tools/make/downloads/kissfft/tools/kiss_fftr.c
|
||||
sed -ir -E "s@(exit.*\);)@return; /* \1 */@g" tensorflow/lite/experimental/micro/tools/make/downloads/kissfft/tools/kiss_fftr.c
|
||||
echo "Finished patching kissfft"
|
||||
}
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
# TensorFlow Lite Objective-C API.
|
||||
# TensorFlow Lite for Objective-C
|
||||
|
||||
package(default_visibility = ["//visibility:private"])
|
||||
|
||||
@ -83,11 +83,13 @@ ios_unit_test(
|
||||
"notsan",
|
||||
"nomsan",
|
||||
],
|
||||
deps = [":TensorFlowLiteTestsLib"],
|
||||
deps = [
|
||||
":TestsLib",
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "TensorFlowLiteTestsLib",
|
||||
name = "TestsLib",
|
||||
testonly = 1,
|
||||
srcs = glob([
|
||||
"tests/*.m",
|
||||
|
@ -1,4 +1,4 @@
|
||||
# TensorFlow Lite Objective-C Library
|
||||
# TensorFlow Lite for Objective-C
|
||||
|
||||
[TensorFlow Lite](https://www.tensorflow.org/lite/) is TensorFlow's lightweight
|
||||
solution for Objective-C developers. It enables low-latency inference of
|
||||
@ -44,9 +44,11 @@ bazel test tensorflow/lite/experimental/objc:TensorFlowLiteTests
|
||||
|
||||
### Tulsi
|
||||
|
||||
Open the `TensorFlowLiteObjc.tulsiproj` using the Tulsi application on Mac or by
|
||||
running the following command in Terminal from the root source directory:
|
||||
Open the `TensorFlowLite.tulsiproj` using the
|
||||
[TulsiApp](https://github.com/bazelbuild/tulsi) or by running the
|
||||
[`generate_xcodeproj.sh`](https://github.com/bazelbuild/tulsi/blob/master/src/tools/generate_xcodeproj.sh)
|
||||
script from the root `tensorflow` directory:
|
||||
|
||||
```shell
|
||||
generate_xcodeproj.sh --genconfig tensorflow/lite/experimental/objc/TensorFlowLiteObjc.tulsiproj:TensorFlowLiteObjC --outputfolder ~/path/to/xcodeproj
|
||||
generate_xcodeproj.sh --genconfig tensorflow/lite/experimental/objc/TensorFlowLite.tulsiproj:TensorFlowLite --outputfolder ~/path/to/generated/TensorFlowLite.xcodeproj
|
||||
```
|
||||
|
@ -15,7 +15,7 @@
|
||||
"//tensorflow/lite/experimental/objc:TensorFlowLite",
|
||||
"//tensorflow/lite/experimental/objc:TensorFlowLiteTests",
|
||||
],
|
||||
"projectName" : "TensorFlowLiteObjC",
|
||||
"projectName" : "TensorFlowLite",
|
||||
"optionSet" : {
|
||||
"LaunchActionPreActionScript" : {
|
||||
"p" : "$(inherited)"
|
@ -9,7 +9,7 @@
|
||||
},
|
||||
}
|
||||
},
|
||||
"projectName" : "TensorFlowLiteObjC",
|
||||
"projectName" : "TensorFlowLite",
|
||||
"packages" : [
|
||||
"tensorflow/lite/experimental/objc"
|
||||
],
|
@ -1,4 +1,4 @@
|
||||
# TensorFlow Lite for Swift.
|
||||
# TensorFlow Lite for Swift
|
||||
|
||||
package(default_visibility = ["//visibility:private"])
|
||||
|
||||
@ -11,10 +11,6 @@ load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library")
|
||||
|
||||
MINIMUM_OS_VERSION = "9.0"
|
||||
|
||||
SWIFT_COPTS = [
|
||||
"-wmo",
|
||||
]
|
||||
|
||||
# Default tags for filtering targets. Targets in this file are restricted to Apple platforms.
|
||||
DEFAULT_TAGS = [
|
||||
"apple",
|
||||
@ -24,7 +20,6 @@ DEFAULT_TAGS = [
|
||||
swift_library(
|
||||
name = "TensorFlowLite",
|
||||
srcs = glob(["Sources/*.swift"]),
|
||||
copts = SWIFT_COPTS,
|
||||
module_name = "TensorFlowLite",
|
||||
tags = DEFAULT_TAGS,
|
||||
deps = [
|
||||
@ -42,31 +37,11 @@ ios_unit_test(
|
||||
"nomsan",
|
||||
"notsan",
|
||||
],
|
||||
deps = [":TensorFlowLiteTestsLib"],
|
||||
)
|
||||
|
||||
swift_library(
|
||||
name = "TensorFlowLiteTestsLib",
|
||||
testonly = 1,
|
||||
srcs = glob(["Tests/*.swift"]),
|
||||
copts = SWIFT_COPTS,
|
||||
tags = DEFAULT_TAGS,
|
||||
deps = [
|
||||
":TensorFlowLite",
|
||||
":TestResources",
|
||||
":TestsLib",
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "TestResources",
|
||||
data = [
|
||||
"//tensorflow/lite:testdata/add.bin",
|
||||
"//tensorflow/lite:testdata/add_quantized.bin",
|
||||
"//tensorflow/lite:testdata/multi_add.bin",
|
||||
],
|
||||
tags = DEFAULT_TAGS,
|
||||
)
|
||||
|
||||
ios_application(
|
||||
name = "TensorFlowLiteApp",
|
||||
app_icons = glob(["TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Assets.xcassets/AppIcon.appiconset/**"]),
|
||||
@ -82,25 +57,50 @@ ios_application(
|
||||
"CoreGraphics",
|
||||
],
|
||||
tags = DEFAULT_TAGS,
|
||||
deps = [":TensorFlowLiteAppLib"],
|
||||
deps = [
|
||||
":AppLib",
|
||||
],
|
||||
)
|
||||
|
||||
swift_library(
|
||||
name = "TensorFlowLiteAppLib",
|
||||
name = "TestsLib",
|
||||
testonly = 1,
|
||||
srcs = glob(["Tests/*.swift"]),
|
||||
tags = DEFAULT_TAGS,
|
||||
deps = [
|
||||
":Resources",
|
||||
":TensorFlowLite",
|
||||
],
|
||||
)
|
||||
|
||||
swift_library(
|
||||
name = "AppLib",
|
||||
srcs = glob(["TestApps/TensorFlowLiteApp/TensorFlowLiteApp/*.swift"]),
|
||||
module_name = "TensorFlowLiteAppLib",
|
||||
tags = DEFAULT_TAGS,
|
||||
deps = [
|
||||
":AppResources",
|
||||
":TensorFlowLite",
|
||||
":TensorFlowLiteAppResources",
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "TensorFlowLiteAppResources",
|
||||
name = "Resources",
|
||||
data = [
|
||||
"//tensorflow/lite:testdata/add.bin",
|
||||
"//tensorflow/lite:testdata/add_quantized.bin",
|
||||
"//tensorflow/lite:testdata/multi_add.bin",
|
||||
],
|
||||
tags = DEFAULT_TAGS,
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "AppResources",
|
||||
data = glob([
|
||||
"TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Base.lproj/*.storyboard",
|
||||
]),
|
||||
tags = DEFAULT_TAGS,
|
||||
deps = [":TestResources"],
|
||||
deps = [
|
||||
":Resources",
|
||||
],
|
||||
)
|
||||
|
@ -52,10 +52,12 @@ Note that `--swiftcopt=-enable-testing` is required for optimized builds (`-c op
|
||||
|
||||
### Tulsi
|
||||
|
||||
Open the `TensorFlowLite.tulsiproj` using the [TulsiApp](https://github.com/bazelbuild/tulsi) or by
|
||||
running the [`generate_xcodeproj.sh`](https://github.com/bazelbuild/tulsi/blob/master/src/tools/generate_xcodeproj.sh)
|
||||
script:
|
||||
Open the `TensorFlowLite.tulsiproj` using the
|
||||
[TulsiApp](https://github.com/bazelbuild/tulsi)
|
||||
or by running the
|
||||
[`generate_xcodeproj.sh`](https://github.com/bazelbuild/tulsi/blob/master/src/tools/generate_xcodeproj.sh)
|
||||
script from the root `tensorflow` directory:
|
||||
|
||||
```shell
|
||||
generate_xcodeproj.sh --genconfig tensorflow/lite/swift/TensorFlowLite.tulsiproj:TensorFlowLite --outputfolder ~/path/to/generated/TensorFlowLite.xcodeproj
|
||||
generate_xcodeproj.sh --genconfig tensorflow/lite/experimental/swift/TensorFlowLite.tulsiproj:TensorFlowLite --outputfolder ~/path/to/generated/TensorFlowLite.xcodeproj
|
||||
```
|
||||
|
@ -12,114 +12,108 @@ upper_tabs:
|
||||
lower_tabs:
|
||||
# Subsite tabs
|
||||
other:
|
||||
- name: Guide
|
||||
- name: "Guide"
|
||||
contents:
|
||||
- title: Overview
|
||||
path: /lite/overview
|
||||
- title: Developer guide
|
||||
path: /lite/devguide
|
||||
- title: Android demo app
|
||||
path: /lite/demo_android
|
||||
- title: iOS demo app
|
||||
path: /lite/demo_ios
|
||||
- break: true
|
||||
- title: TensorFlow Lite inference
|
||||
path: /lite/inference
|
||||
- title: Custom operators
|
||||
path: /lite/custom_operators
|
||||
- title: TensorFlow Lite ops versioning
|
||||
path: /lite/ops_versioning
|
||||
- title: TensorFlow Lite compatibility guide
|
||||
path: /lite/tf_ops_compatibility
|
||||
- title: TensorFlow Lite for iOS
|
||||
path: /lite/ios
|
||||
- title: TensorFlow Lite for Raspberry Pi
|
||||
path: /lite/rpi
|
||||
- title: "TensorFlow Lite guide"
|
||||
path: /lite/guide
|
||||
|
||||
- heading: TF Lite converter
|
||||
- title: Overview
|
||||
- heading: "Get started"
|
||||
- title: "Overview"
|
||||
path: /lite/guide/get_started
|
||||
- title: "Android quickstart"
|
||||
path: /lite/guide/android
|
||||
- title: "iOS quickstart"
|
||||
path: /lite/guide/ios
|
||||
- title: "TensorFlow Lite FAQ"
|
||||
path: /lite/guide/faq
|
||||
|
||||
- heading: "Convert a model"
|
||||
- title: "TensorFlow Lite converter"
|
||||
path: /lite/convert/
|
||||
- title: Python API guide
|
||||
path: /lite/convert/python_api
|
||||
- title: Command line examples
|
||||
- title: "Command line examples"
|
||||
path: /lite/convert/cmdline_examples
|
||||
- title: Command line reference
|
||||
- title: "Command line reference"
|
||||
path: /lite/convert/cmdline_reference
|
||||
- title: "Python API"
|
||||
path: /lite/convert/python_api
|
||||
|
||||
- heading: Performance
|
||||
- title: Best practices
|
||||
- heading: "Inference"
|
||||
- title: "Overview"
|
||||
path: /lite/guide/inference
|
||||
- title: "Custom operators"
|
||||
path: /lite/guide/ops_custom
|
||||
- title: "Operator versions"
|
||||
path: /lite/guide/ops_version
|
||||
- title: "Operator compatibility"
|
||||
path: /lite/guide/ops_compatibility
|
||||
- title: "Select operators from TensorFlow"
|
||||
path: /lite/guide/ops_select
|
||||
- title: "List of hosted models"
|
||||
path: /lite/guide/hosted_models
|
||||
|
||||
- heading: "Performance"
|
||||
- title: "Best practices"
|
||||
path: /lite/performance/best_practices
|
||||
- title: Benchmarks
|
||||
- title: "Benchmarks"
|
||||
path: /lite/performance/benchmarks
|
||||
- title: Model optimization
|
||||
- title: "Model optimization"
|
||||
path: /lite/performance/model_optimization
|
||||
- title: Post-training quantization
|
||||
- title: "Post-training quantization"
|
||||
path: /lite/performance/post_training_quantization
|
||||
- title: Post-training quantization example
|
||||
- title: "Post-training quantization example"
|
||||
path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tutorials/post_training_quant.ipynb
|
||||
status: external
|
||||
- title: GPU delegate
|
||||
- title: "GPU delegate"
|
||||
path: /lite/performance/gpu
|
||||
- title: Advanced GPU
|
||||
- title: "Advanced GPU"
|
||||
path: /lite/performance/gpu_advanced
|
||||
|
||||
- title: TF Mobile
|
||||
style: accordion
|
||||
status: deprecated
|
||||
section:
|
||||
- title: Overview
|
||||
path: /lite/tfmobile/
|
||||
- title: Building TensorFlow on Android
|
||||
path: /lite/tfmobile/android_build
|
||||
- title: Building TensorFlow on iOS
|
||||
path: /lite/tfmobile/ios_build
|
||||
- title: Integrating TensorFlow libraries
|
||||
path: /lite/tfmobile/linking_libs
|
||||
- title: Preparing models for mobile deployment
|
||||
path: /lite/tfmobile/prepare_models
|
||||
- title: Optimizing for mobile
|
||||
path: /lite/tfmobile/optimizing
|
||||
- heading: "Build TensorFlow Lite"
|
||||
- title: "Build for iOS"
|
||||
path: /lite/guide/build_ios
|
||||
- title: "Build for ARM64"
|
||||
path: /lite/guide/build_arm64
|
||||
- title: "Build for Raspberry Pi"
|
||||
path: /lite/guide/build_rpi
|
||||
|
||||
- name: Examples
|
||||
- name: "Examples"
|
||||
contents:
|
||||
- title: Examples
|
||||
- title: "Examples"
|
||||
path: /lite/examples
|
||||
|
||||
- name: Models
|
||||
- name: "Models"
|
||||
contents:
|
||||
- title: Overview
|
||||
- title: "Overview"
|
||||
path: /lite/models/
|
||||
- title: Hosted models
|
||||
path: /lite/models/hosted
|
||||
- title: Image classification
|
||||
- title: "Image classification"
|
||||
section:
|
||||
- title: Overview
|
||||
- title: "Overview"
|
||||
path: /lite/models/image_classification/overview
|
||||
- title: Android
|
||||
- title: "Android"
|
||||
path: /lite/models/image_classification/android
|
||||
- title: iOS
|
||||
- title: "iOS"
|
||||
path: /lite/models/image_classification/ios
|
||||
- title: Object detection
|
||||
- title: "Object detection"
|
||||
section:
|
||||
- title: Overview
|
||||
- title: "Overview"
|
||||
path: /lite/models/object_detection/overview
|
||||
- title: Pose estimation
|
||||
- title: "Pose estimation"
|
||||
section:
|
||||
- title: Overview
|
||||
- title: "Overview"
|
||||
path: /lite/models/pose_estimation/overview
|
||||
- title: Segmentation
|
||||
- title: "Segmentation"
|
||||
section:
|
||||
- title: Overview
|
||||
- title: "Overview"
|
||||
path: /lite/models/segmentation/overview
|
||||
- title: Smart reply
|
||||
- title: "Smart reply"
|
||||
section:
|
||||
- title: Overview
|
||||
- title: "Overview"
|
||||
path: /lite/models/smart_reply/overview
|
||||
|
||||
- name: API
|
||||
- name: "API"
|
||||
skip_translation: true
|
||||
contents:
|
||||
- title: API
|
||||
- title: "API"
|
||||
path: /api_docs/python/tf/lite
|
||||
|
||||
- include: /api_docs/_upper_tabs_api.yaml
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Converter command-line examples
|
||||
# Converter command line examples
|
||||
|
||||
This page shows how to use the TensorFlow Lite Converter in the command line.
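For example, a minimal invocation looks roughly like the following (a sketch only; the file paths are placeholders, and the flags assume the `tflite_convert` tool that ships with recent TensorFlow releases):

```shell
# Convert a SavedModel directory into a TensorFlow Lite FlatBuffer.
tflite_convert \
  --saved_model_dir=/tmp/my_saved_model \
  --output_file=/tmp/converted_model.tflite
```

Adjust the paths to match your own model.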
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Converter command-line reference
|
||||
# Converter command line reference
|
||||
|
||||
This page is a complete reference of command-line flags used by the TensorFlow
|
||||
Lite Converter's command line starting from TensorFlow 1.9 up until the most
|
||||
|
@ -1,4 +1,4 @@
|
||||
# TensorFlow Lite Converter
|
||||
# TensorFlow Lite converter
|
||||
|
||||
TensorFlow Lite uses the optimized
|
||||
[FlatBuffer](https://google.github.io/flatbuffers/) format to represent graphs.
|
||||
|
@ -1,5 +1,4 @@
|
||||
|
||||
# Android Demo App
|
||||
# Android quickstart
|
||||
|
||||
An example Android application using TensorFlow Lite is available
|
||||
[on GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/java/demo).
|
@ -1,4 +1,4 @@
|
||||
# TensorFlow Lite for generic ARM64 boards
|
||||
# Build TensorFlow Lite for ARM64 boards
|
||||
|
||||
## Cross compiling
|
||||
|
@ -1,4 +1,4 @@
|
||||
# TensorFlow Lite for Raspberry Pi
|
||||
# Build TensorFlow Lite for Raspberry Pi
|
||||
|
||||
## Cross compiling
|
||||
|
@ -1,4 +1,4 @@
|
||||
# TF Lite Developer Guide
|
||||
# Get started with TensorFlow Lite
|
||||
|
||||
Using a TensorFlow Lite model in your mobile app requires multiple
|
||||
considerations: you must choose a pre-trained or custom model, convert the model
|
@ -1,5 +1,5 @@
|
||||
|
||||
# Introduction to TensorFlow Lite
|
||||
# TensorFlow Lite guide
|
||||
|
||||
TensorFlow Lite is TensorFlow’s lightweight solution for mobile and embedded
|
||||
devices. It enables on-device machine learning inference with low latency and a
|
@ -1,4 +1,4 @@
|
||||
# TensorFlow Lite Inference
|
||||
# TensorFlow Lite inference
|
||||
|
||||
[TOC]
|
||||
|
@ -1,4 +1,4 @@
|
||||
# iOS Demo App
|
||||
# iOS quickstart
|
||||
|
||||
This tutorial provides a simple iOS mobile application to classify images using
|
||||
the iOS device camera. In this tutorial, you will download the demo application
|
@ -1,4 +1,4 @@
|
||||
# TensorFlow Lite & TensorFlow Compatibility Guide
|
||||
# TensorFlow Lite and TensorFlow operator compatibility
|
||||
|
||||
TensorFlow Lite supports a number of TensorFlow operations used in common
|
||||
inference models. As they are processed by the TensorFlow Lite Optimizing
|
@ -1,4 +1,6 @@
|
||||
# [Experimental] Using TensorFlow Lite with select TensorFlow ops
|
||||
# Select TensorFlow operators to use in TensorFlow Lite
|
||||
|
||||
Caution: This feature is experimental.
|
||||
|
||||
The TensorFlow Lite builtin op library has grown rapidly, and will continue to
|
||||
grow, but there remains a long tail of TensorFlow ops that are not yet natively
|
||||
@ -196,9 +198,7 @@ Python support is actively under development.
|
||||
|
||||
When using a mixture of both builtin and select TensorFlow ops, all of the same
|
||||
TensorFlow Lite optimizations and optimized builtin kernels will be available
|
||||
and usable with the converted model. For the TensorFlow ops, performance should
|
||||
generally be comparable to that of
|
||||
[TensorFlow Mobile](https://www.tensorflow.org/lite/tfmobile/).
|
||||
and usable with the converted model.
|
||||
|
||||
The following table describes the average time taken to run inference on
|
||||
MobileNet on a Pixel 2. The listed times are an average of 100 runs. These
|
@ -1,5 +1,4 @@
|
||||
|
||||
# TensorFlow Lite Ops Versioning
|
||||
# TensorFlow Lite operator versions
|
||||
|
||||
This document describes TensorFlow Lite's op versioning schema. Op
|
||||
versioning enables developers to add new functionalities and parameters into
|
@ -1,5 +1,4 @@
|
||||
|
||||
# Performance
|
||||
# Performance benchmarks
|
||||
|
||||
This document lists TensorFlow Lite performance benchmarks when running well
|
||||
known models on some Android and iOS devices.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# TensorFlow Lite GPU Delegate Tutorial
|
||||
# TensorFlow Lite GPU delegate
|
||||
|
||||
[TensorFlow Lite](https://www.tensorflow.org/lite) supports several hardware
|
||||
accelerators. This document describes how to preview the experimental GPU backend using the
|
||||
|
@ -1,195 +0,0 @@
|
||||
# Building TensorFlow on Android
|
||||
|
||||
Warning: We expect to deprecate TensorFlow Mobile in early 2019
|
||||
|
||||
<div class="caution">
|
||||
<p>
|
||||
<a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
|
||||
working hard to close the feature gap between TensorFlow Mobile and
|
||||
TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
|
||||
will give ample notice to our users when we get to that point and will
|
||||
provide help and support to ensure easy migrations.
|
||||
</p>
|
||||
<p>
|
||||
In the meantime, please use TensorFlow Lite. If you have a feature request,
|
||||
such as a missing op, please post to our <a
|
||||
href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
To get you started working with TensorFlow on Android, we'll walk through two
|
||||
ways to build our TensorFlow mobile demos and deploying them on an Android
|
||||
device. The first is Android Studio, which lets you build and deploy in an
|
||||
IDE. The second is building with Bazel and deploying with ADB on the command
|
||||
line.
|
||||
|
||||
Why choose one or the other of these methods?
|
||||
|
||||
The simplest way to use TensorFlow on Android is to use Android Studio. If you
|
||||
aren't planning to customize your TensorFlow build at all, or if you want to use
|
||||
Android Studio's editor and other features to build an app and just want to add
|
||||
TensorFlow to it, we recommend using Android Studio.
|
||||
|
||||
If you are using custom ops, or have some other reason to build TensorFlow from
|
||||
scratch, scroll down and see our instructions
|
||||
for [building the demo with Bazel](#build_the_demo_using_bazel).
|
||||
|
||||
## Build the demo using Android Studio
|
||||
|
||||
**Prerequisites**
|
||||
|
||||
If you haven't already, do the following two things:
|
||||
|
||||
- Install [Android Studio](https://developer.android.com/studio/index.html),
|
||||
following the instructions on their website.
|
||||
|
||||
- Clone the TensorFlow repository from GitHub:
|
||||
|
||||
git clone https://github.com/tensorflow/tensorflow
|
||||
|
||||
**Building**
|
||||
|
||||
1. Open Android Studio, and from the Welcome screen, select **Open an existing
|
||||
Android Studio project**.
|
||||
|
||||
2. From the **Open File or Project** window that appears, navigate to and select
|
||||
the `tensorflow/examples/android` directory from wherever you cloned the
|
||||
TensorFlow GitHub repo. Click OK.
|
||||
|
||||
If it asks you to do a Gradle Sync, click OK.
|
||||
|
||||
You may also need to install various platforms and tools, if you get
|
||||
errors like "Failed to find target with hash string 'android-23'" and similar.
|
||||
|
||||
3. Open the `build.gradle` file (you can go to **1:Project** in the side panel
|
||||
and find it under the **Gradle Scripts** zippy under **Android**). Look for
|
||||
the `nativeBuildSystem` variable and set it to `none` if it isn't already:
|
||||
|
||||
// set to 'bazel', 'cmake', 'makefile', 'none'
|
||||
def nativeBuildSystem = 'none'
|
||||
|
||||
4. Click the *Run* button (the green arrow) or select *Run > Run 'android'* from the
|
||||
top menu. You may need to rebuild the project using *Build > Rebuild Project*.
|
||||
|
||||
If it asks you to use Instant Run, click **Proceed Without Instant Run**.
|
||||
|
||||
Also, you need to have an Android device plugged in with developer options
|
||||
enabled at this
|
||||
point. See [here](https://developer.android.com/studio/run/device.html) for
|
||||
more details on setting up developer devices.
|
||||
|
||||
This installs three apps on your phone that are all part of the TensorFlow
|
||||
Demo. See [Android Sample Apps](#android_sample_apps) for more information about
|
||||
them.
|
||||
|
||||
## Adding TensorFlow to your apps using Android Studio
|
||||
|
||||
To add TensorFlow to your own apps on Android, the simplest way is to add the
|
||||
following lines to your Gradle build file:
|
||||
|
||||
allprojects {
|
||||
repositories {
|
||||
jcenter()
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation 'org.tensorflow:tensorflow-android:+'
|
||||
}
|
||||
|
||||
This automatically downloads the latest stable version of TensorFlow as an AAR
|
||||
and installs it in your project.
|
||||
|
||||
## Build the demo using Bazel
|
||||
|
||||
Another way to use TensorFlow on Android is to build an APK
|
||||
using [Bazel](https://bazel.build/) and load it onto your device
|
||||
using [ADB](https://developer.android.com/studio/command-line/adb.html). This
|
||||
requires some knowledge of build systems and Android developer tools, but we'll
|
||||
guide you through the basics here.
|
||||
|
||||
- First, follow our instructions for
|
||||
<a href="http://www.tensorflow.org/install/source">installing from sources</a>.
|
||||
This will also guide you through installing Bazel and cloning the
|
||||
TensorFlow code.
|
||||
|
||||
- Download the Android [SDK](https://developer.android.com/studio/index.html)
|
||||
and [NDK](https://developer.android.com/ndk/downloads/index.html) if you do
|
||||
not already have them. You need at least version 12b of the NDK, and 23 of the
|
||||
SDK.
|
||||
|
||||
- In your copy of the TensorFlow source, update the
|
||||
[WORKSPACE](https://github.com/tensorflow/tensorflow/blob/master/WORKSPACE)
|
||||
file with the location of your SDK and NDK, where it says <PATH_TO_NDK>
|
||||
and <PATH_TO_SDK>.
|
||||
|
||||
- Run Bazel to build the demo APK:
|
||||
|
||||
bazel build -c opt //tensorflow/examples/android:tensorflow_demo
|
||||
|
||||
- Use [ADB](https://developer.android.com/studio/command-line/adb.html#move) to
|
||||
install the APK onto your device:
|
||||
|
||||
adb install -r bazel-bin/tensorflow/examples/android/tensorflow_demo.apk
|
||||
|
||||
Note: In general when compiling for Android with Bazel you need
|
||||
`--config=android` on the Bazel command line, though this
|
||||
particular example is Android-only, so you don't need it here.
|
||||
|
||||
This installs three apps on your phone that are all part of the TensorFlow
|
||||
Demo. See [Android Sample Apps](#android_sample_apps) for more information about
|
||||
them.
|
||||
|
||||
## Android Sample Apps
|
||||
|
||||
The
|
||||
[Android example code](https://www.tensorflow.org/code/tensorflow/examples/android/) is
|
||||
a single project that builds and installs three sample apps which all use the
|
||||
same underlying code. The sample apps all take video input from a phone's
|
||||
camera:
|
||||
|
||||
- **TF Classify** uses the Inception v3 model to label the objects it’s pointed
|
||||
at with classes from Imagenet. There are only 1,000 categories in Imagenet,
|
||||
which misses most everyday objects and includes many things you’re unlikely to
|
||||
encounter often in real life, so the results can often be quite amusing. For
|
||||
example there’s no ‘person’ category, so instead it will often guess things it
|
||||
does know that are often associated with pictures of people, like a seat belt
|
||||
or an oxygen mask. If you do want to customize this example to recognize
|
||||
objects you care about, you can use
|
||||
the
|
||||
[TensorFlow for Poets codelab](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/index.html#0) as
|
||||
an example for how to train a model based on your own data.
|
||||
|
||||
- **TF Detect** uses a multibox model to try to draw bounding boxes around the
|
||||
locations of people in the camera. These boxes are annotated with the
|
||||
confidence for each detection result. Results will not be perfect, as this
|
||||
kind of object detection is still an active research topic. The demo also
|
||||
includes optical tracking for when objects move between frames, which runs
|
||||
more frequently than the TensorFlow inference. This improves the user
|
||||
experience since the apparent frame rate is faster, but it also gives the
|
||||
ability to estimate which boxes refer to the same object between frames, which
|
||||
is important for counting objects over time.
|
||||
|
||||
- **TF Stylize** implements a real-time style transfer algorithm on the camera
|
||||
feed. You can select which styles to use and mix between them using the
|
||||
palette at the bottom of the screen, and also switch the resolution of the
|
||||
processing to be higher or lower.
|
||||
|
||||
When you build and install the demo, you'll see three app icons on your phone,
|
||||
one for each of the demos. Tapping on them should open up the app and let you
|
||||
explore what they do. You can enable profiling statistics on-screen by tapping
|
||||
the volume up button while they’re running.
|
||||
|
||||
### Android Inference Library
|
||||
|
||||
Because Android apps need to be written in Java, and core TensorFlow is in C++,
|
||||
TensorFlow has a JNI library to interface between the two. Its interface is aimed
|
||||
only at inference, so it provides the ability to load a graph, set up inputs,
|
||||
and run the model to calculate particular outputs. You can see the full
|
||||
documentation for the minimal set of methods in
|
||||
[TensorFlowInferenceInterface.java](https://www.tensorflow.org/code/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java)
|
||||
|
||||
The demo applications use this interface, so they’re a good place to look for
|
||||
example usage. You can download prebuilt binary jars
|
||||
at
|
||||
[ci.tensorflow.org](https://ci.tensorflow.org/view/Nightly/job/nightly-android/).
|
@ -1,298 +0,0 @@
|
||||
# Overview
|
||||
|
||||
Warning: We expect to deprecate TensorFlow Mobile in early 2019
|
||||
|
||||
<div class="caution">
|
||||
<p>
|
||||
<a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
|
||||
working hard to close the feature gap between TensorFlow Mobile and
|
||||
TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
|
||||
will give ample notice to our users when we get to that point and will
|
||||
provide help and support to ensure easy migrations.
|
||||
</p>
|
||||
<p>
|
||||
In the meantime, please use TensorFlow Lite. If you have a feature request,
|
||||
such as a missing op, please post to our <a
|
||||
href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
TensorFlow was designed to be a good deep learning solution for mobile
|
||||
platforms. Currently we have two solutions for deploying machine learning
|
||||
applications on mobile and embedded devices: TensorFlow for Mobile and
|
||||
<a href="../../lite">TensorFlow Lite</a>.
|
||||
|
||||
## TensorFlow Lite versus TensorFlow Mobile
|
||||
|
||||
Here are a few of the differences between the two:
|
||||
|
||||
- TensorFlow Lite is an evolution of TensorFlow Mobile. In most cases, apps
|
||||
developed with TensorFlow Lite will have a smaller binary size, fewer
|
||||
dependencies, and better performance.
|
||||
|
||||
- TensorFlow Lite is in developer preview, so not all use cases are covered yet.
|
||||
We expect you to use TensorFlow Mobile to cover production cases.
|
||||
|
||||
- TensorFlow Lite supports only a limited set of operators, so not all models
|
||||
will work on it by default. TensorFlow for Mobile has a fuller set of
|
||||
supported functionality.
|
||||
|
||||
TensorFlow Lite provides better performance and a small binary size on mobile
|
||||
platforms as well as the ability to leverage hardware acceleration if available
|
||||
on the platform. In addition, it has many fewer dependencies so it can be
|
||||
built and hosted on simpler, more constrained device scenarios. TensorFlow Lite
|
||||
also allows targeting accelerators through the [Neural Networks
|
||||
API](https://developer.android.com/ndk/guides/neuralnetworks/index.html).
|
||||
|
||||
TensorFlow Lite currently has coverage for a limited set of operators. While
|
||||
TensorFlow for Mobile supports only a constrained set of ops by default, in
|
||||
principle it can be customized to build the kernel for any TensorFlow operator
|
||||
you use. Thus use cases which are not currently supported by
|
||||
TensorFlow Lite should continue to use TensorFlow for Mobile. As TensorFlow Lite
|
||||
evolves, it will gain additional operators, and the decision will be easier to
|
||||
make.
|
||||
|
||||
|
||||
## Introduction to TensorFlow Mobile
|
||||
|
||||
TensorFlow was designed from the ground up to be a good deep learning solution
|
||||
for mobile platforms like Android and iOS. This mobile guide should help you
|
||||
understand how machine learning can work on mobile platforms and how to
|
||||
integrate TensorFlow into your mobile apps effectively and efficiently.
|
||||
|
||||
## About this Guide
|
||||
|
||||
This guide is aimed at developers who have a TensorFlow model that’s
|
||||
successfully working in a desktop environment, who want to integrate it into
|
||||
a mobile application, and cannot use TensorFlow Lite. Here are the
|
||||
main challenges you’ll face during that process:
|
||||
|
||||
- Understanding how to use TensorFlow for mobile.
|
||||
- Building TensorFlow for your platform.
|
||||
- Integrating the TensorFlow library into your application.
|
||||
- Preparing your model file for mobile deployment.
|
||||
- Optimizing for latency, RAM usage, model file size, and binary size.
|
||||
|
||||
## Common use cases for mobile machine learning
|
||||
|
||||
**Why run TensorFlow on mobile?**
|
||||
|
||||
Traditionally, deep learning has been associated with data centers and giant
|
||||
clusters of high-powered GPU machines. However, it can be very expensive and
|
||||
time-consuming to send all of the data a device has access to across a network
|
||||
connection. Running on mobile makes it possible to deliver very interactive
|
||||
applications in a way that’s not possible when you have to wait for a network
|
||||
round trip.
|
||||
|
||||
Here are some common use cases for on-device deep learning:
|
||||
|
||||
### Speech Recognition
|
||||
|
||||
There are a lot of interesting applications that can be built with a
|
||||
speech-driven interface, and many of these require on-device processing. Most of
|
||||
the time a user isn’t giving commands, and so streaming audio continuously to a
|
||||
remote server would be a waste of bandwidth, since it would mostly be silence or
|
||||
background noises. To solve this problem it’s common to have a small neural
|
||||
network running on-device
|
||||
[listening out for a particular keyword](../tutorials/sequences/audio_recognition).
|
||||
Once that keyword has been spotted, the rest of the
|
||||
conversation can be transmitted over to the server for further processing if
|
||||
more computing power is needed.
|
||||
|
||||
### Image Recognition
|
||||
|
||||
It can be very useful for a mobile app to be able to make sense of a camera
|
||||
image. If your users are taking photos, recognizing what’s in them can help your
|
||||
camera apps apply appropriate filters, or label the photos so they’re easily
|
||||
findable. It’s important for embedded applications too, since you can use image
|
||||
sensors to detect all sorts of interesting conditions, whether it’s spotting
|
||||
endangered animals in the wild
|
||||
or
|
||||
[reporting how late your train is running](https://svds.com/tensorflow-image-recognition-raspberry-pi/).
|
||||
|
||||
TensorFlow comes with several examples of recognizing the types of objects
|
||||
inside images along with a variety of different pre-trained models, and they can
|
||||
all be run on mobile devices. You can try out
|
||||
our
|
||||
[Tensorflow for Poets](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/index.html#0) and
|
||||
[Tensorflow for Poets 2: Optimize for Mobile](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2/index.html#0) codelabs to
|
||||
see how to take a pretrained model and run some very fast and lightweight
|
||||
training to teach it to recognize specific objects, and then optimize it to
|
||||
run on mobile.
|
||||
|
||||
### Object Localization
|
||||
|
||||
Sometimes it’s important to know where objects are in an image as well as what
|
||||
they are. There are lots of augmented reality use cases that could benefit a
|
||||
mobile app, such as guiding users to the right component when offering them
|
||||
help fixing their wireless network or providing informative overlays on top of
|
||||
landscape features. Embedded applications often need to count objects that are
|
||||
passing by them, whether it’s pests in a field of crops, or people, cars and
|
||||
bikes going past a street lamp.
|
||||
|
||||
TensorFlow offers a pretrained model for drawing bounding boxes around people
|
||||
detected in images, together with tracking code to follow them over time. The
|
||||
tracking is especially important for applications where you’re trying to count
|
||||
how many objects are present over time, since it gives you a good idea when a
|
||||
new object enters or leaves the scene. We have some sample code for this
|
||||
available for Android [on
|
||||
GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android),
|
||||
and also a [more general object detection
|
||||
model](https://github.com/tensorflow/models/tree/master/research/object_detection/README.md)
|
||||
available as well.
|
||||
|
||||
### Gesture Recognition
|
||||
|
||||
It can be useful to be able to control applications with hand or other
|
||||
gestures, either recognized from images or through analyzing accelerometer
|
||||
sensor data. Creating those models is beyond the scope of this guide, but
|
||||
TensorFlow is an effective way of deploying them.
|
||||
|
||||
### Optical Character Recognition
|
||||
|
||||
Google Translate’s live camera view is a great example of how effective
|
||||
interactive on-device detection of text can be.
|
||||
|
||||
<div class="video-wrapper">
|
||||
<iframe class="devsite-embedded-youtube-video" data-video-id="06olHmcJjS0"
|
||||
data-autohide="1" data-showinfo="0" frameborder="0" allowfullscreen>
|
||||
</iframe>
|
||||
</div>
|
||||
|
||||
There are multiple steps involved in recognizing text in images. You first have
|
||||
to identify the areas where the text is present, which is a variation on the
|
||||
object localization problem, and can be solved with similar techniques. Once you
|
||||
have an area of text, you then need to interpret it as letters, and then use a
|
||||
language model to help guess what words they represent. The simplest way to
|
||||
estimate what letters are present is to segment the line of text into individual
|
||||
letters, and then apply a simple neural network to the bounding box of each. You
|
||||
can get good results with the kind of models used for MNIST, which you can find
|
||||
in TensorFlow’s tutorials, though you may want a higher-resolution input. A
|
||||
more advanced alternative is to use an LSTM model to process a whole line of
|
||||
text at once, with the model itself handling the segmentation into different
|
||||
characters.
|
||||
|
||||
### Translation
|
||||
|
||||
Translating from one language to another quickly and accurately, even if you
|
||||
don’t have a network connection, is an important use case. Deep networks are
|
||||
very effective at this sort of task, and you can find descriptions of a lot of
|
||||
different models in the literature. Often these are sequence-to-sequence
|
||||
recurrent models where you’re able to run a single graph to do the whole
|
||||
translation, without needing to run separate parsing stages.
|
||||
|
||||
### Text Classification
|
||||
|
||||
If you want to suggest relevant prompts to users based on what they’re typing or
|
||||
reading, it can be very useful to understand the meaning of the text. This is
|
||||
where text classification comes in. Text classification is an umbrella term
|
||||
that covers everything from sentiment analysis to topic discovery. You’re likely
|
||||
to have your own categories or labels that you want to apply, so the best place
|
||||
to start is with an example
|
||||
like
|
||||
[Skip-Thoughts](https://github.com/tensorflow/models/tree/master/research/skip_thoughts/),
|
||||
and then train on your own examples.
|
||||
|
||||
### Voice Synthesis
|
||||
|
||||
A synthesized voice can be a great way of giving users feedback or aiding
|
||||
accessibility, and recent advances such as
|
||||
[WaveNet](https://deepmind.com/blog/wavenet-generative-model-raw-audio/) show
|
||||
that deep learning can offer very natural-sounding speech.
|
||||
|
||||
## Mobile machine learning and the cloud
|
||||
|
||||
These examples of use cases give an idea of how on-device networks can
|
||||
complement cloud services. Cloud has a great deal of computing power in a
|
||||
controlled environment, but running on devices can offer higher interactivity.
|
||||
In situations where the cloud is unavailable, or your cloud capacity is limited,
|
||||
you can provide an offline experience, or reduce cloud workload by processing
|
||||
easy cases on device.
|
||||
|
||||
Doing on-device computation can also signal when it's time to switch to working
|
||||
on the cloud. A good example of this is hotword detection in speech. Since
|
||||
devices are able to constantly listen out for the keywords, this then triggers a
|
||||
lot of traffic to cloud-based speech recognition once one is recognized. Without
|
||||
the on-device component, the whole application wouldn’t be feasible, and this
|
||||
pattern exists across several other applications as well. Recognizing that some
|
||||
sensor input is interesting enough for further processing makes a lot of
|
||||
interesting products possible.
|
||||
|
||||
## What hardware and software should you have?
|
||||
|
||||
TensorFlow runs on Ubuntu Linux, Windows 10, and OS X. For a list of all
|
||||
supported operating systems and instructions to install TensorFlow, see
|
||||
<a href="https://www.tensorflow.org/install">Installing Tensorflow</a>.
|
||||
|
||||
Note that some of the sample code we provide for mobile TensorFlow requires you
|
||||
to compile TensorFlow from source, so you’ll need more than just `pip install`
|
||||
to work through all the sample code.
|
||||
|
||||
To try out the mobile examples, you’ll need a device set up for development,
|
||||
using
|
||||
either [Android Studio](https://developer.android.com/studio/install.html),
|
||||
or [Xcode](https://developer.apple.com/xcode/) if you're developing for iOS.
|
||||
|
||||
## What should you do before you get started?
|
||||
|
||||
Before thinking about how to get your solution on mobile:
|
||||
|
||||
1. Determine whether your problem is solvable by mobile machine learning
|
||||
2. Create a labelled dataset to define your problem
|
||||
3. Pick an effective model for the problem
|
||||
|
||||
We'll discuss these in more detail below.
|
||||
|
||||
### Is your problem solvable by mobile machine learning?
|
||||
|
||||
Once you have an idea of the problem you want to solve, you need to make a plan
|
||||
of how to build your solution. The most important first step is making sure that
|
||||
your problem is actually solvable, and the best way to do that is to mock it up
|
||||
using humans in the loop.
|
||||
|
||||
For example, if you want to drive a robot toy car using voice commands, try
|
||||
recording some audio from the device and listen back to it to see if you can
|
||||
make sense of what’s being said. Often you’ll find there are problems in the
|
||||
capture process, such as the motor drowning out speech or not being able to hear
|
||||
at a distance, and you should tackle these problems before investing in the
|
||||
modeling process.
|
||||
|
||||
Another example would be giving photos taken from your app to people to see if they
|
||||
can classify what’s in them, in the way you’re looking for. If they can’t do
|
||||
that (for example, trying to estimate calories in food from photos may be
|
||||
impossible because all white soups look the same), then you’ll need to redesign
|
||||
your experience to cope with that. A good rule of thumb is that if a human can’t
|
||||
handle the task then it will be difficult to train a computer to do better.
|
||||
|
||||
### Create a labelled dataset
|
||||
|
||||
After you’ve solved any fundamental issues with your use case, you need to
|
||||
create a labeled dataset to define what problem you’re trying to solve. This
|
||||
step is extremely important, more than picking which model to use. You want it
|
||||
to be as representative as possible of your actual use case, since the model
|
||||
will only be effective at the task you teach it. It’s also worth investing in
|
||||
tools to make labeling the data as efficient and accurate as possible. For
|
||||
example, if you’re able to switch from having to click a button on a web
|
||||
interface to simple keyboard shortcuts, you may be able to speed up the
|
||||
generation process a lot. You should also start by doing the initial labeling
|
||||
yourself, so you can learn about the difficulties and likely errors, and
|
||||
possibly change your labeling or data capture process to avoid them. Once you
|
||||
and your team are able to consistently label examples (that is, once you
|
||||
generally agree on the same labels for most examples), you can then try and
|
||||
capture your knowledge in a manual and teach external raters how to run the same
|
||||
process.
|
||||
|
||||
### Pick an effective model
|
||||
|
||||
The next step is to pick an effective model to use. You might be able to avoid
|
||||
training a model from scratch if someone else has already implemented a model
|
||||
similar to what you need; we have a repository of models implemented in
|
||||
TensorFlow [on GitHub](https://github.com/tensorflow/models) that you can look
|
||||
through. Lean towards the simplest model you can find, and try to get started as
|
||||
soon as you have even a small amount of labelled data, since you’ll get the best
|
||||
results when you’re able to iterate quickly. The shorter the time it takes to
|
||||
try training a model and running it in its real application, the better overall
|
||||
results you’ll see. It’s common for an algorithm to get great training accuracy
|
||||
numbers but then fail to be useful within a real application because there’s a
|
||||
mismatch between the dataset and real usage. Prototype end-to-end usage as soon
|
||||
as possible to create a consistent user experience.
|
@ -1,124 +0,0 @@
|
||||
# Building TensorFlow on iOS
|
||||
|
||||
Warning: We expect to deprecate TensorFlow Mobile in early 2019
|
||||
|
||||
<div class="caution">
|
||||
<p>
|
||||
<a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
|
||||
working hard to close the feature gap between TensorFlow Mobile and
|
||||
TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
|
||||
will give ample notice to our users when we get to that point and will
|
||||
provide help and support to ensure easy migrations.
|
||||
</p>
|
||||
<p>
|
||||
In the meantime, please use TensorFlow Lite. If you have a feature request,
|
||||
such as a missing op, please post to our <a
|
||||
href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
## Using CocoaPods
|
||||
|
||||
The simplest way to get started with TensorFlow on iOS is using the CocoaPods
|
||||
package management system. You can add the `TensorFlow-experimental` pod to your
|
||||
Podfile, which installs a universal binary framework. This makes it easy to get
|
||||
started but has the disadvantage of being hard to customize, which is important
|
||||
in case you want to shrink your binary size. If you do need the ability to
|
||||
customize your libraries, see later sections on how to do that.
|
||||
|
||||
## Creating your own app
|
||||
|
||||
If you'd like to add TensorFlow capabilities to your own app, do the following:
|
||||
|
||||
- Create your own app or load your already-created app in Xcode.
|
||||
|
||||
- Add a file named Podfile at the project root directory with the following content:
|
||||
|
||||
target 'YourProjectName'
|
||||
pod 'TensorFlow-experimental'
|
||||
|
||||
- Run `pod install` to download and install the `TensorFlow-experimental` pod.
|
||||
|
||||
- Open `YourProjectName.xcworkspace` and add your code.
|
||||
|
||||
- In your app's **Build Settings**, make sure to add `$(inherited)` to the
|
||||
  **Other Linker Flags** and **Header Search Paths** sections.
|
||||
|
||||
## Running the Samples
|
||||
|
||||
You'll need Xcode 7.3 or later to run our iOS samples.
|
||||
|
||||
There are currently three examples: simple, benchmark, and camera. For now, you
|
||||
can download the sample code by cloning the main tensorflow repository (we are
|
||||
planning to make the samples available as a separate repository later).
|
||||
|
||||
From the root of the tensorflow folder, download [Inception
|
||||
v1](https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip),
|
||||
and extract the label and graph files into the data folders inside both the
|
||||
simple and camera examples using these steps:
|
||||
|
||||
mkdir -p ~/graphs
|
||||
curl -o ~/graphs/inception5h.zip \
|
||||
https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip \
|
||||
&& unzip ~/graphs/inception5h.zip -d ~/graphs/inception5h
|
||||
cp ~/graphs/inception5h/* tensorflow/examples/ios/benchmark/data/
|
||||
cp ~/graphs/inception5h/* tensorflow/examples/ios/camera/data/
|
||||
cp ~/graphs/inception5h/* tensorflow/examples/ios/simple/data/
|
||||
|
||||
Change into one of the sample directories, download the
|
||||
[Tensorflow-experimental](https://cocoapods.org/pods/TensorFlow-experimental)
|
||||
pod, and open the Xcode workspace. Note that installing the pod can take a long
|
||||
time since it is big (~450MB). If you want to run the simple example, then:
|
||||
|
||||
cd tensorflow/examples/ios/simple
|
||||
pod install
|
||||
open tf_simple_example.xcworkspace # note .xcworkspace, not .xcodeproj
|
||||
# this is created by pod install
|
||||
|
||||
Run the simple app in the Xcode simulator. You should see a single-screen app
|
||||
with a **Run Model** button. Tap that, and you should see some debug output
|
||||
appear below indicating that the example Grace Hopper image in the data directory
|
||||
has been analyzed, with a military uniform recognized.
|
||||
|
||||
Run the other samples using the same process. The camera example requires a real
|
||||
device connected. Once you build and run that, you should get a live camera view
|
||||
that you can point at objects to get real-time recognition results.
|
||||
|
||||
### iOS Example details
|
||||
|
||||
There are three demo applications for iOS, all defined in Xcode projects inside
|
||||
[tensorflow/examples/ios](https://www.tensorflow.org/code/tensorflow/examples/ios/).
|
||||
|
||||
- **Simple**: This is a minimal example showing how to load and run a TensorFlow
|
||||
model in as few lines as possible. It just consists of a single view with a
|
||||
button that executes the model loading and inference when it's pressed.
|
||||
|
||||
- **Camera**: This is very similar to the Android TF Classify demo. It loads
|
||||
Inception v3 and outputs its best label estimate for what’s in the live camera
|
||||
view. As with the Android version, you can train your own custom model using
|
||||
TensorFlow for Poets and drop it into this example with minimal code changes.
|
||||
|
||||
- **Benchmark**: This is quite close to Simple, but it runs the graph repeatedly and
|
||||
outputs similar statistics to the benchmark tool on Android.
|
||||
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
- Make sure you use the TensorFlow-experimental pod (and not TensorFlow).
|
||||
|
||||
- The TensorFlow-experimental pod is currently about 450MB. The reason it is so
|
||||
big is because we are bundling multiple platforms, and the pod includes all
|
||||
TensorFlow functionality (e.g. operations). The final app size after build is
|
||||
substantially smaller though (~25MB). Working with the complete pod is
|
||||
convenient during development, but see the section below on how you can build your
|
||||
own custom TensorFlow library to reduce the size.
|
||||
|
||||
## Building the TensorFlow iOS libraries from source
|
||||
|
||||
While CocoaPods is the quickest and easiest way of getting started, you sometimes
|
||||
need more flexibility to determine which parts of TensorFlow your app should be
|
||||
shipped with. For such cases, you can build the iOS libraries from the
|
||||
sources. [This
|
||||
guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/ios#building-the-tensorflow-ios-libraries-from-source)
|
||||
contains detailed instructions on how to do that.
|
||||
|
@ -1,270 +0,0 @@
|
||||
# Integrating TensorFlow libraries
|
||||
|
||||
Warning: We expect to deprecate TensorFlow Mobile in early 2019
|
||||
|
||||
<div class="caution">
|
||||
<p>
|
||||
<a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
|
||||
working hard to close the feature gap between TensorFlow Mobile and
|
||||
TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
|
||||
will give ample notice to our users when we get to that point and will
|
||||
provide help and support to ensure easy migrations.
|
||||
</p>
|
||||
<p>
|
||||
In the meantime, please use TensorFlow Lite. If you have a feature request,
|
||||
such as a missing op, please post to our <a
|
||||
href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
Once you have made some progress on a model that addresses the problem you’re
|
||||
trying to solve, it’s important to test it out inside your application
|
||||
immediately. There are often unexpected differences between your training data
|
||||
and what users actually encounter in the real world, and getting a clear picture
|
||||
of the gap as soon as possible improves the product experience.
|
||||
|
||||
This page talks about how to integrate the TensorFlow libraries into your own
|
||||
mobile applications, once you have already successfully built and deployed the
|
||||
TensorFlow mobile demo apps.
|
||||
|
||||
## Linking the library
|
||||
|
||||
After you've managed to build the examples, you'll probably want to call
|
||||
TensorFlow from one of your existing applications. The very easiest way to do
|
||||
this is to use the Pod installation steps described in
|
||||
<a href="./ios_build.md">Building TensorFlow on iOS</a>, but if you want to build
|
||||
TensorFlow from source (for example to customize which operators are included)
|
||||
you'll need to break out TensorFlow as a framework, include the right header
|
||||
files, and link against the built libraries and dependencies.
|
||||
|
||||
### Android
|
||||
|
||||
For Android, you just need to link in a Java library contained in a JAR file
|
||||
called `libandroid_tensorflow_inference_java.jar`. There are three ways to
|
||||
include this functionality in your program:
|
||||
|
||||
1. Include the jcenter AAR which contains it, as in this
|
||||
[example app](https://github.com/googlecodelabs/tensorflow-for-poets-2/blob/master/android/tfmobile/build.gradle#L59-L65)
|
||||
|
||||
2. Download the nightly precompiled version from
|
||||
[ci.tensorflow.org](http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/).
|
||||
|
||||
3. Build the JAR file yourself using the instructions [in our Android GitHub repo](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/android)
|
||||
|
||||
### iOS
|
||||
|
||||
Pulling in the TensorFlow libraries on iOS is a little more complicated. Here is
|
||||
a checklist of what you’ll need to do to your iOS app:
|
||||
|
||||
- Link against tensorflow/contrib/makefile/gen/lib/libtensorflow-core.a, usually
|
||||
by adding `-L/your/path/tensorflow/contrib/makefile/gen/lib/` and
|
||||
`-ltensorflow-core` to your linker flags.
|
||||
|
||||
- Link against the generated protobuf libraries by adding
|
||||
`-L/your/path/tensorflow/contrib/makefile/gen/protobuf_ios/lib` and
|
||||
`-lprotobuf` and `-lprotobuf-lite` to your command line.
|
||||
|
||||
- For the include paths, you need the root of your TensorFlow source folder as
|
||||
the first entry, followed by
|
||||
`tensorflow/contrib/makefile/downloads/protobuf/src`,
|
||||
`tensorflow/contrib/makefile/downloads`,
|
||||
`tensorflow/contrib/makefile/downloads/eigen`, and
|
||||
`tensorflow/contrib/makefile/gen/proto`.
|
||||
|
||||
- Make sure your binary is built with `-force_load` (or the equivalent on your
|
||||
platform), aimed at the TensorFlow library to ensure that it’s linked
|
||||
correctly. More detail on why this is necessary can be found in the next
|
||||
section, [Global constructor magic](#global_constructor_magic). On Linux-like
|
||||
platforms, you’ll need different flags, more like
|
||||
`-Wl,--allow-multiple-definition -Wl,--whole-archive`.
|
||||
|
||||
You’ll also need to link in the Accelerator framework, since this is used to
|
||||
speed up some of the operations.
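Putting the checklist together, the extra pieces of the link step end up looking roughly like this (a hypothetical sketch only; substitute the real paths from your TensorFlow checkout):

```
# Illustrative link invocation assembled from the flags described above.
clang++ YourApp.o \
  -L/your/path/tensorflow/contrib/makefile/gen/lib/ -ltensorflow-core \
  -L/your/path/tensorflow/contrib/makefile/gen/protobuf_ios/lib -lprotobuf -lprotobuf-lite \
  -force_load /your/path/tensorflow/contrib/makefile/gen/lib/libtensorflow-core.a \
  -framework Accelerate \
  -o YourApp
```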
|
||||
|
||||
## Global constructor magic
|
||||
|
||||
One of the subtlest problems you may run up against is the “No session factory
|
||||
registered for the given session options” error when trying to call TensorFlow
|
||||
from your own application. To understand why this is happening and how to fix
|
||||
it, you need to know a bit about the architecture of TensorFlow.
|
||||
|
||||
The framework is designed to be very modular, with a thin core and a large
|
||||
number of specific objects that are independent and can be mixed and matched as
|
||||
needed. To enable this, the coding pattern in C++ had to let modules easily
|
||||
notify the framework about the services they offer, without requiring a central
|
||||
list that has to be updated separately from each implementation. It also had to
|
||||
allow separate libraries to add their own implementations without needing a
|
||||
recompile of the core.
|
||||
|
||||
To achieve this capability, TensorFlow uses a registration pattern in a lot of
|
||||
places. In the code, it looks like this:
|
||||
|
||||
```
|
||||
class MulKernel : public OpKernel {
|
||||
Status Compute(OpKernelContext* context) { … }
|
||||
};
|
||||
REGISTER_KERNEL(MulKernel, "Mul");
|
||||
```
|
||||
|
||||
This would be in a standalone `.cc` file linked into your application, either
|
||||
as part of the main set of kernels or as a separate custom library. The magic
|
||||
part is that the `REGISTER_KERNEL()` macro is able to inform the core of
|
||||
TensorFlow that it has an implementation of the Mul operation, so that it can be
|
||||
called in any graphs that require it.
|
||||
|
||||
From a programming point of view, this setup is very convenient. The
|
||||
implementation and registration code live in the same file, and adding new
|
||||
implementations is as simple as compiling and linking it in. The difficult part
|
||||
comes from the way that the `REGISTER_KERNEL()` macro is implemented. C++
|
||||
doesn’t offer a good mechanism for doing this sort of registration, so we have
|
||||
to resort to some tricky code. Under the hood, the macro is implemented so that
|
||||
it produces something like this:
|
||||
|
||||
```
|
||||
class RegisterMul {
|
||||
public:
|
||||
RegisterMul() {
|
||||
global_kernel_registry()->Register("Mul", [](){
|
||||
return new MulKernel();
|
||||
});
|
||||
}
|
||||
};
|
||||
RegisterMul g_register_mul;
|
||||
```
|
||||
|
||||
This sets up a class `RegisterMul` with a constructor that tells the global
|
||||
kernel registry what function to call when somebody asks it how to create a
|
||||
“Mul” kernel. Then there’s a global object of that class, and so the constructor
|
||||
should be called at the start of any program.
|
||||
|
||||
While this may sound sensible, the unfortunate part is that the global object
|
||||
that’s defined is not used by any other code, so linkers not designed with this
|
||||
in mind will decide that it can be deleted. As a result, the constructor is
|
||||
never called, and the class is never registered. All sorts of modules use this
|
||||
pattern in TensorFlow, and it happens that `Session` implementations are the
|
||||
first to be looked for when the code is run, which is why it shows up as the
|
||||
characteristic error when this problem occurs.
|
||||
|
||||
The solution is to force the linker to not strip any code from the library, even
|
||||
if it believes it’s unused. On iOS, this step can be accomplished with the
|
||||
`-force_load` flag, specifying a library path, and on Linux you need
|
||||
`--whole-archive`. These persuade the linker to not be as aggressive about
|
||||
stripping, and should retain the globals.
|
||||
|
||||
The actual implementation of the various `REGISTER_*` macros is a bit more
|
||||
complicated in practice, but they all suffer the same underlying problem. If
|
||||
you’re interested in how they work, [op_kernel.h](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/op_kernel.h#L1091)
|
||||
is a good place to start investigating.
|
||||
|
||||
## Protobuf problems
|
||||
|
||||
TensorFlow relies on
|
||||
the [Protocol Buffer](https://developers.google.com/protocol-buffers/) library,
|
||||
commonly known as protobuf. This library takes definitions of data structures
|
||||
and produces serialization and access code for them in a variety of
|
||||
languages. The tricky part is that this generated code needs to be linked
|
||||
against shared libraries for the exact same version of the framework that was
|
||||
used for the generator. This can be an issue when `protoc`, the tool used to
|
||||
generate the code, is from a different version of protobuf than the libraries in
|
||||
the standard linking and include paths. For example, you might be using a copy
|
||||
of `protoc` that was built locally in `~/projects/protobuf-3.0.1.a`, but you have
|
||||
libraries installed at `/usr/local/lib` and `/usr/local/include` that are from
|
||||
3.0.0.
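A quick way to spot this kind of mismatch (a sketch, not an official workflow) is to compare the version reported by the generator with whatever is installed on the default search paths:

```
# Version of the protoc that will generate the code.
protoc --version
# Protobuf libraries and headers the linker and compiler will find by default.
ls /usr/local/lib | grep -i protobuf
ls /usr/local/include/google/protobuf | head
```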
|
||||
|
||||
The symptoms of this issue are errors during the compilation or linking phases
|
||||
with protobufs. Usually, the build tools take care of this, but if you’re using
|
||||
the makefile, make sure you’re building the protobuf library locally and using
|
||||
it, as shown in [this Makefile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/makefile/Makefile#L18).
|
||||
|
||||
Another situation that can cause problems is when protobuf headers and source
|
||||
files need to be generated as part of the build process. This process makes
|
||||
building more complex, since the first phase has to be a pass over the protobuf
|
||||
definitions to create all the needed code files, and only after that can you go
|
||||
ahead and do a build of the library code.
|
||||
|
||||
### Multiple versions of protobufs in the same app
|
||||
|
||||
Protobufs generate headers that are needed as part of the C++ interface to the
|
||||
overall TensorFlow library. This complicates using the library as a standalone
|
||||
framework.
|
||||
|
||||
If your application is already using version 1 of the protocol buffers library,
|
||||
you may have trouble integrating TensorFlow because it requires version 2. If
|
||||
you just try to link both versions into the same binary, you’ll see linking
|
||||
errors because some of the symbols clash. To solve this particular problem, we
|
||||
have an experimental script at [rename_protobuf.sh](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/makefile/rename_protobuf.sh).
|
||||
|
||||
You need to run this as part of the makefile build, after you’ve downloaded all
|
||||
the dependencies:
|
||||
|
||||
```
|
||||
tensorflow/contrib/makefile/download_dependencies.sh
|
||||
tensorflow/contrib/makefile/rename_protobuf.sh
|
||||
```
|
||||
|
||||
## Calling the TensorFlow API

Once you have the framework available, you then need to call into it. The usual
pattern is that you first load your model, which represents a preset set of
numeric computations, and then you run inputs through that model (for example,
images from a camera) and receive outputs (for example, predicted labels).

On Android, we provide the Java Inference Library that is focused on just this
use case, while on iOS and Raspberry Pi you call directly into the C++ API.

### Android

Here’s what a typical Inference Library sequence looks like on Android:

```
// Load the model from disk.
TensorFlowInferenceInterface inferenceInterface =
    new TensorFlowInferenceInterface(assetManager, modelFilename);

// Copy the input data into TensorFlow.
inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);

// Run the inference call.
inferenceInterface.run(outputNames, logStats);

// Copy the output Tensor back into the output array.
inferenceInterface.fetch(outputName, outputs);
```

You can find the source of this code in the [Android examples](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java#L107).

### iOS and Raspberry Pi

Here’s the equivalent code for iOS and Raspberry Pi:

```
// Load the model.
PortableReadFileToProto(file_path, &tensorflow_graph);

// Create a session from the model.
tensorflow::Status s = session->Create(tensorflow_graph);
if (!s.ok()) {
  LOG(FATAL) << "Could not create TensorFlow Graph: " << s;
}

// Run the model.
std::string input_layer = "input";
std::string output_layer = "output";
std::vector<tensorflow::Tensor> outputs;
tensorflow::Status run_status = session->Run({{input_layer, image_tensor}},
                                             {output_layer}, {}, &outputs);
if (!run_status.ok()) {
  LOG(FATAL) << "Running model failed: " << run_status;
}

// Access the output data.
tensorflow::Tensor* output = &outputs[0];
```

This is all based on the
[iOS sample code](https://www.tensorflow.org/code/tensorflow/examples/ios/simple/RunModelViewController.mm),
but there’s nothing iOS-specific; the same code should be usable on any platform
that supports C++.

You can also find specific examples for Raspberry Pi
[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/pi_examples/label_image/label_image.cc).
# Optimizing for mobile

Warning: We expect to deprecate TensorFlow Mobile in early 2019

<div class="caution">
  <p>
    <a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
    working hard to close the feature gap between TensorFlow Mobile and
    TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
    will give ample notice to our users when we get to that point and will
    provide help and support to ensure easy migrations.
  </p>
  <p>
    In the meantime, please use TensorFlow Lite. If you have a feature request,
    such as a missing op, please post to our <a
    href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
  </p>
</div>

There are some special issues that you have to deal with when you’re trying to
ship on mobile or embedded devices, and you’ll need to think about these as
you’re developing your model.

These issues are:

- Model and Binary Size
- App speed and model loading speed
- Performance and threading

We'll discuss a few of these below.

## What are the minimum device requirements for TensorFlow?

You need at least one megabyte of program memory and several megabytes of RAM to
run the base TensorFlow runtime, so it’s not suitable for DSPs or
microcontrollers. Other than those, the biggest constraint is usually the
calculation speed of the device, and whether you can run the model you need for
your application with a low enough latency. You can use the benchmarking tools
in [How to Profile your Model](#how_to_profile_your_model) to get an idea of how
many FLOPs are required for a model, and then use that to make rule-of-thumb
estimates of how fast they will run on different devices. For example, a modern
smartphone might be able to run 10 GFLOPs per second, so the best you could hope
for from a 5 GFLOP model is two frames per second, though you may do worse
depending on what the exact computation patterns are.

This model dependence means that it’s possible to run TensorFlow even on very
old or constrained phones, as long as you optimize your network to fit within
the latency budget and possibly within limited RAM too. For memory usage, you
mostly need to make sure that the intermediate buffers that TensorFlow creates
aren’t too large, which you can examine in the benchmark output too.

## Speed

One of the highest priorities of most model deployments is figuring out how to
run the inference fast enough to give a good user experience. The first place to
start is by looking at the total number of floating point operations that are
required to execute the graph. You can get a very rough estimate of this by
using the `benchmark_model` tool:

    bazel build -c opt tensorflow/tools/benchmark:benchmark_model && \
    bazel-bin/tensorflow/tools/benchmark/benchmark_model \
    --graph=/tmp/inception_graph.pb --input_layer="Mul:0" \
    --input_layer_shape="1,299,299,3" --input_layer_type="float" \
    --output_layer="softmax:0" --show_run_order=false --show_time=false \
    --show_memory=false --show_summary=true --show_flops=true --logtostderr

This should show you an estimate of how many operations are needed to run the
graph. You can then use that information to figure out how feasible your model
is to run on the devices you’re targeting. For example, a high-end phone from
2016 might be able to do 20 billion FLOPs per second, so the best speed you
could hope for from a model that requires 10 billion FLOPs is around 500ms. On a
device like the Raspberry Pi 3 that can do about 5 billion FLOPs per second, you
may only get one inference every two seconds.

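As a back-of-the-envelope calculation, you can turn a FLOP count and a device
throughput estimate into an expected latency. A minimal sketch, using just the
illustrative figures from above (not measurements from any real device):

```
#include <cstdio>

int main() {
  // Illustrative figures only: a 10 billion FLOP model on a device that
  // sustains roughly 20 billion FLOPs per second.
  const double model_flops = 10e9;
  const double device_flops_per_second = 20e9;

  const double seconds_per_inference = model_flops / device_flops_per_second;
  std::printf("Estimated latency: %.0f ms (%.1f inferences/second)\n",
              seconds_per_inference * 1000.0, 1.0 / seconds_per_inference);
  return 0;
}
```

This prints an estimated latency of 500 ms, or two inferences per second, which
is only an upper bound on what you can expect in practice.
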
Having this estimate helps you plan for what you’ll be able to realistically
achieve on a device. If the model is using too many ops, then there are a lot of
opportunities to optimize the architecture to reduce that number.

Advanced techniques include [SqueezeNet](https://arxiv.org/abs/1602.07360)
and [MobileNet](https://arxiv.org/abs/1704.04861), which are architectures
designed to produce models for mobile -- lean and fast but with a small accuracy
cost. You can also just look at alternative models, even older ones, which may
be smaller. For example, Inception v1 only has around 7 million parameters,
compared to Inception v3’s 24 million, and requires only 3 billion FLOPs rather
than 9 billion for v3.

## Model Size

Models that run on a device need to be stored somewhere on the device, and very
large neural networks can be hundreds of megabytes. Most users are reluctant to
download very large app bundles from app stores, so you want to make your model
as small as possible. Furthermore, smaller neural networks can be paged in and
out of a mobile device's memory faster.

To understand how large your network will be on disk, start by looking at the
size on disk of your `GraphDef` file after you’ve run `freeze_graph` and
`strip_unused_nodes` on it (see <a href="./prepare_models.md">Preparing models</a> for
more details on these tools), since then it should only contain
inference-related nodes. To double-check that your results are as expected, run
the `summarize_graph` tool to see how many parameters are in constants:

    bazel build tensorflow/tools/graph_transforms:summarize_graph && \
    bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
    --in_graph=/tmp/tensorflow_inception_graph.pb

That command should give you output that looks something like this:

    No inputs spotted.
    Found 1 possible outputs: (name=softmax, op=Softmax)
    Found 23885411 (23.89M) const parameters, 0 (0) variable parameters,
    and 99 control_edges
    Op types used: 489 Const, 99 CheckNumerics, 99 Identity, 94
    BatchNormWithGlobalNormalization, 94 Conv2D, 94 Relu, 11 Concat, 9 AvgPool,
    5 MaxPool, 1 Sub, 1 Softmax, 1 ResizeBilinear, 1 Reshape, 1 Mul, 1 MatMul,
    1 ExpandDims, 1 DecodeJpeg, 1 Cast, 1 BiasAdd

The important part for our current purposes is the number of const
parameters. In most models these will be stored as 32-bit floats to start, so if
you multiply the number of const parameters by four, you should get something
that’s close to the size of the file on disk. You can often get away with only
eight bits per parameter with very little loss of accuracy in the final result,
so if your file size is too large you can try using
<a href="https://www.tensorflow.org/performance/quantization">quantize_weights</a>
to transform the parameters down.

    bazel build tensorflow/tools/graph_transforms:transform_graph && \
    bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
    --in_graph=/tmp/tensorflow_inception_optimized.pb \
    --out_graph=/tmp/tensorflow_inception_quantized.pb \
    --inputs='Mul:0' --outputs='softmax:0' --transforms='quantize_weights'

If you look at the resulting file size, you should see that it’s about a quarter
of the original at 23MB.

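You can sanity-check those numbers directly from the `summarize_graph` output
above. A small sketch of the estimate, using the 23,885,411 const parameters
reported for Inception v3:

```
#include <cstdio>

int main() {
  const double kConstParameters = 23885411;  // From summarize_graph above.

  // Stored as 32-bit floats: roughly four bytes per parameter.
  const double float_bytes = kConstParameters * 4;
  // After quantize_weights: roughly one byte per parameter.
  const double quantized_bytes = kConstParameters * 1;

  std::printf("~%.0f MB as float, ~%.0f MB quantized\n",
              float_bytes / (1024 * 1024), quantized_bytes / (1024 * 1024));
  return 0;
}
```

This prints roughly 91 MB for the float version and 23 MB after quantization,
which matches the “about a quarter of the original” observation above.
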
Another transform is `round_weights`, which doesn't make the file smaller, but it
makes the file compressible to about the same size as when `quantize_weights` is
used. This is particularly useful for mobile development, taking advantage of
the fact that app bundles are compressed before they’re downloaded by consumers.

The original file does not compress well with standard algorithms, because the
bit patterns of even very similar numbers can be very different. The
`round_weights` transform keeps the weight parameters stored as floats, but
rounds them to a set number of step values. This means there are a lot more
repeated byte patterns in the stored model, and so compression can often bring
the size down dramatically, in many cases to near the size it would be if they
were stored as eight bit.

Another advantage of `round_weights` is that the framework doesn’t have to
allocate a temporary buffer to unpack the parameters into, as we have to when
we just use `quantize_weights`. This saves a little bit of latency (though the
results should be cached so it’s only costly on the first run) and makes it
possible to use memory mapping, as described later.

## Binary Size

One of the biggest differences between mobile and server development is the
importance of binary size. On desktop machines it’s not unusual to have
executables that are hundreds of megabytes on disk, but for mobile and embedded
apps it’s vital to keep the binary as small as possible so that user downloads
are easy. As mentioned above, TensorFlow only includes a subset of op
implementations by default, but this still results in a 12 MB final
executable. To reduce this, you can set up the library to only include the
implementations of the ops that you actually need, based on automatically
analyzing your model. To use it:

- Run `tools/print_required_ops/print_selective_registration_header.py` on your
  model to produce a header file that only enables the ops it uses.

- Place the `ops_to_register.h` file somewhere that the compiler can find
  it. This can be in the root of your TensorFlow source folder.

- Build TensorFlow with `SELECTIVE_REGISTRATION` defined, for example by passing
  in `--copts="-DSELECTIVE_REGISTRATION"` to your Bazel build command.

This process recompiles the library so that only the needed ops and types are
included, which can dramatically reduce the executable size. For example, with
Inception v3, the new size is only 1.5MB.

## How to Profile your Model

Once you have an idea of what your device's peak performance range is, it’s
worth looking at its actual current performance. Using a standalone TensorFlow
benchmark, rather than running it inside a larger app, helps isolate just the
TensorFlow contribution to the latency. The
[tensorflow/tools/benchmark](https://www.tensorflow.org/code/tensorflow/tools/benchmark/) tool
is designed to help you do this. To run it on Inception v3 on your desktop
machine, build and run the benchmark like this:

    bazel build -c opt tensorflow/tools/benchmark:benchmark_model && \
    bazel-bin/tensorflow/tools/benchmark/benchmark_model \
    --graph=/tmp/tensorflow_inception_graph.pb --input_layer="Mul" \
    --input_layer_shape="1,299,299,3" --input_layer_type="float" \
    --output_layer="softmax:0" --show_run_order=false --show_time=false \
    --show_memory=false --show_summary=true --show_flops=true --logtostderr

You should see output that looks something like this:

<pre>
============================== Top by Computation Time ==============================
[node type] [start] [first] [avg ms] [%] [cdf%] [mem KB] [Name]
Conv2D 22.859 14.212 13.700 4.972% 4.972% 3871.488 conv_4/Conv2D
Conv2D 8.116 8.964 11.315 4.106% 9.078% 5531.904 conv_2/Conv2D
Conv2D 62.066 16.504 7.274 2.640% 11.717% 443.904 mixed_3/conv/Conv2D
Conv2D 2.530 6.226 4.939 1.792% 13.510% 2765.952 conv_1/Conv2D
Conv2D 55.585 4.605 4.665 1.693% 15.203% 313.600 mixed_2/tower/conv_1/Conv2D
Conv2D 127.114 5.469 4.630 1.680% 16.883% 81.920 mixed_10/conv/Conv2D
Conv2D 47.391 6.994 4.588 1.665% 18.548% 313.600 mixed_1/tower/conv_1/Conv2D
Conv2D 39.463 7.878 4.336 1.574% 20.122% 313.600 mixed/tower/conv_1/Conv2D
Conv2D 127.113 4.192 3.894 1.413% 21.535% 114.688 mixed_10/tower_1/conv/Conv2D
Conv2D 70.188 5.205 3.626 1.316% 22.850% 221.952 mixed_4/conv/Conv2D

============================== Summary by node type ==============================
[Node type] [count] [avg ms] [avg %] [cdf %] [mem KB]
Conv2D 94 244.899 88.952% 88.952% 35869.953
BiasAdd 95 9.664 3.510% 92.462% 35873.984
AvgPool 9 7.990 2.902% 95.364% 7493.504
Relu 94 5.727 2.080% 97.444% 35869.953
MaxPool 5 3.485 1.266% 98.710% 3358.848
Const 192 1.727 0.627% 99.337% 0.000
Concat 11 1.081 0.393% 99.730% 9892.096
MatMul 1 0.665 0.242% 99.971% 4.032
Softmax 1 0.040 0.015% 99.986% 4.032
<> 1 0.032 0.012% 99.997% 0.000
Reshape 1 0.007 0.003% 100.000% 0.000

Timings (microseconds): count=50 first=330849 curr=274803 min=232354 max=415352 avg=275563 std=44193
Memory (bytes): count=50 curr=128366400(all same)
514 nodes defined 504 nodes observed
</pre>

This is the summary view, which is enabled by the `--show_summary` flag. To
interpret it, the first table is a list of the nodes that took the most time, in
order by how long they took. From left to right, the columns are:

- Node type, what kind of operation this was.

- Start time of the op, showing where it falls in the sequence of operations.

- First time in milliseconds. This is how long the operation took on the first
  run of the benchmark, since by default 20 runs are executed to get more
  reliable statistics. The first time is useful to spot which ops are doing
  expensive calculations on the first run, and then caching the results.

- Average time for the operation across all runs, in milliseconds.

- What percentage of the total time for one run the op took. This is useful to
  understand where the hotspots are.

- The cumulative total time of this and the previous ops in the table. This is
  handy for understanding what the distribution of work is across the layers, to
  see if just a few of the nodes are taking up most of the time.

- The amount of memory consumed by outputs of this type of op.

- Name of the node.

The second table is similar, but instead of breaking down the timings by
particular named nodes, it groups them by the kind of op. This is very useful to
understand which op implementations you might want to optimize or eliminate from
your graph. The table is arranged with the most costly operations at the start,
and only shows the top ten entries, with a placeholder for other nodes. The
columns from left to right are:

- Type of the nodes being analyzed.

- Accumulated average time taken by all nodes of this type, in milliseconds.

- What percentage of the total time was taken by this type of operation.

- Cumulative time taken by this and op types higher in the table, so you can
  understand the distribution of the workload.

- How much memory the outputs of this op type took up.

Both of these tables are set up so that you can easily copy and paste their
results into spreadsheet documents, since they are output with tabs as
separators between the columns. The summary by node type can be the most useful
when looking for optimization opportunities, since it’s a pointer to the code
that’s taking the most time. In this case, you can see that the Conv2D ops are
almost 90% of the execution time. This is a sign that the graph is pretty
optimal, since convolutions and matrix multiplies are expected to be the bulk of
a neural network’s computing workload.

As a rule of thumb, it’s more worrying if you see a lot of other operations
taking up more than a small fraction of the time. For neural networks, the ops
that don’t involve large matrix multiplications should usually be dwarfed by the
ones that do, so if you see a lot of time going into those it’s a sign that
either your network is non-optimally constructed, or the code implementing those
ops is not as optimized as it could
be. [Performance bugs](https://github.com/tensorflow/tensorflow/issues) or
patches are always welcome if you do encounter this situation, especially if
they include an attached model exhibiting this behavior and the command line
used to run the benchmark tool on it.

The run above was on your desktop, but the tool also works on Android, which is
where it’s most useful for mobile development. Here’s an example command line to
run it on a 64-bit ARM device:

    bazel build -c opt --config=android_arm64 \
    tensorflow/tools/benchmark:benchmark_model
    adb push bazel-bin/tensorflow/tools/benchmark/benchmark_model /data/local/tmp
    adb push /tmp/tensorflow_inception_graph.pb /data/local/tmp/
    adb shell '/data/local/tmp/benchmark_model \
    --graph=/data/local/tmp/tensorflow_inception_graph.pb --input_layer="Mul" \
    --input_layer_shape="1,299,299,3" --input_layer_type="float" \
    --output_layer="softmax:0" --show_run_order=false --show_time=false \
    --show_memory=false --show_summary=true'

You can interpret the results in exactly the same way as the desktop version
above. If you have any trouble figuring out what the right input and output
names and types are, take a look at the
<a href="./prepare_models">Preparing models</a>
page for details about detecting these for your model, and look at the
`summarize_graph` tool which may give you helpful information.

There isn’t good support for command line tools on iOS, so instead there’s a
separate example at
[tensorflow/examples/ios/benchmark](https://www.tensorflow.org/code/tensorflow/examples/ios/benchmark) that
packages the same functionality inside a standalone app. This outputs the
statistics to both the screen of the device and the debug log. If you want
on-screen statistics for the Android example apps, you can turn them on by
pressing the volume-up button.

## Profiling within your own app

The output you see from the benchmark tool is generated from modules that are
included as part of the standard TensorFlow runtime, which means you have access
to them within your own applications too. You can see an example of how to do
that [here](https://www.tensorflow.org/code/tensorflow/examples/ios/benchmark/BenchmarkViewController.mm?l=139).

The basic steps are:

1. Create a StatSummarizer object:

        tensorflow::StatSummarizer stat_summarizer(tensorflow_graph);

2. Set up the options:

        tensorflow::RunOptions run_options;
        run_options.set_trace_level(tensorflow::RunOptions::FULL_TRACE);
        tensorflow::RunMetadata run_metadata;

3. Run the graph:

        run_status = session->Run(run_options, inputs, output_layer_names, {},
                                  output_layers, &run_metadata);

4. Calculate the results and print them out:

        assert(run_metadata.has_step_stats());
        const tensorflow::StepStats& step_stats = run_metadata.step_stats();
        stat_summarizer.ProcessStepStats(step_stats);
        stat_summarizer.PrintStepStats();

## Visualizing Models

The most effective way to speed up your code is by altering your model so it
does less work. To do that, you need to understand what your model is doing, and
visualizing it is a good first step. To get a high-level overview of your graph,
use [TensorBoard](https://github.com/tensorflow/tensorboard).

## Threading

The desktop version of TensorFlow has a sophisticated threading model, and will
try to run multiple operations in parallel if it can. In our terminology this is
called “inter-op parallelism” (though to avoid confusion with “intra-op”, you
could think of it as “between-op” instead), and can be set by specifying
`inter_op_parallelism_threads` in the session options.

By default, mobile devices run operations serially; that is,
`inter_op_parallelism_threads` is set to 1. Mobile processors usually have few
cores and a small cache, so running multiple operations accessing disjoint parts
of memory usually doesn’t help performance. “Intra-op parallelism” (or
“within-op”) can be very helpful though, especially for computation-bound
operations like convolutions where different threads can feed off the same small
set of memory.

On mobile, how many threads an op will use is set to the number of cores by
default, or 2 when the number of cores can't be determined. You can override the
default number of threads that ops are using by setting
`intra_op_parallelism_threads` in the session options. It’s a good idea to
reduce the default if your app has its own threads doing heavy processing, so
that they don’t interfere with each other.

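From C++, both knobs live on the `ConfigProto` inside `SessionOptions`. A
minimal sketch; the thread counts below are just example values, not
recommendations for any particular device:

```
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"

// Example values only: run ops one at a time, with up to two threads per op.
tensorflow::SessionOptions options;
options.config.set_inter_op_parallelism_threads(1);
options.config.set_intra_op_parallelism_threads(2);

tensorflow::Session* session = nullptr;
tensorflow::Status status = tensorflow::NewSession(options, &session);
```
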
To see more details on session options, look at [ConfigProto](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto).

## Retrain with mobile data

The biggest cause of accuracy problems when running models on mobile apps is
unrepresentative training data. For example, most of the ImageNet photos are
well-framed so that the object is in the center of the picture, well-lit, and
shot with a normal lens. Photos from mobile devices are often poorly framed,
badly lit, and can have fisheye distortions, especially selfies.

The solution is to expand your training set with data actually captured from
your application. This step can involve extra work, since you’ll have to label
the examples yourself, but even if you just use it to expand your original
training data, it can improve the training set dramatically. Improving the
training set by doing this, and by fixing other quality issues like duplicates
or badly labeled examples, is the single best way to improve accuracy. It’s
usually a bigger help than altering your model architecture or using different
techniques.

## Reducing model loading time and/or memory footprint

Most operating systems allow you to load a file using memory mapping, rather
than going through the usual I/O APIs. Instead of allocating an area of memory
on the heap and then copying bytes from disk into it, you simply tell the
operating system to make the entire contents of a file appear directly in
memory. This has several advantages:

* Speeds loading
* Reduces paging (increases performance)
* Does not count towards RAM budget for your app

TensorFlow has support for memory mapping the weights that form the bulk of most
model files. Because of limitations in the `ProtoBuf` serialization format, we
have to make a few changes to our model loading and processing code. The
way memory mapping works is that we have a single file where the first part is a
normal `GraphDef` serialized into the protocol buffer wire format, but then the
weights are appended in a form that can be directly mapped.

To create this file, run the
`tensorflow/contrib/util:convert_graphdef_memmapped_format` tool. This takes in
a `GraphDef` file that’s been run through `freeze_graph` and converts it to the
format that has the weights appended at the end. Since that file’s no longer a
standard `GraphDef` protobuf, you then need to make some changes to the loading
code. You can see an example of this in the
[iOS Camera demo app](https://www.tensorflow.org/code/tensorflow/examples/ios/camera/tensorflow_utils.mm?l=147),
in the `LoadMemoryMappedModel()` function.

The same code (with the Objective C calls for getting the filenames substituted)
can be used on other platforms too. Because we’re using memory mapping, we need
to start by creating a special TensorFlow environment object that’s set up with
the file we’ll be using:

    std::unique_ptr<tensorflow::MemmappedEnv> memmapped_env;
    memmapped_env.reset(
        new tensorflow::MemmappedEnv(tensorflow::Env::Default()));
    tensorflow::Status mmap_status =
        (memmapped_env.get())->InitializeFromFile(file_path);

You then need to pass in this environment to subsequent calls, like this one for
loading the graph:

    tensorflow::GraphDef tensorflow_graph;
    tensorflow::Status load_graph_status = ReadBinaryProto(
        memmapped_env.get(),
        tensorflow::MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
        &tensorflow_graph);

You also need to create the session with a pointer to the environment you’ve
created:

    tensorflow::SessionOptions options;
    options.config.mutable_graph_options()
        ->mutable_optimizer_options()
        ->set_opt_level(::tensorflow::OptimizerOptions::L0);
    options.env = memmapped_env.get();

    tensorflow::Session* session_pointer = nullptr;
    tensorflow::Status session_status =
        tensorflow::NewSession(options, &session_pointer);

One thing to notice here is that we’re also disabling automatic optimizations,
since in some cases these will fold constant sub-trees, and so create copies of
tensor values that we don’t want and use up more RAM.

Once you’ve gone through these steps, you can use the session and graph as
normal, and you should see a reduction in loading time and memory usage.

## Protecting model files from easy copying

By default, your models will be stored in the standard serialized protobuf
format on disk. In theory this means that anybody can copy your model, which you
may not want. However, in practice, most models are so application-specific and
obfuscated by optimizations that the risk is similar to that of competitors
disassembling and reusing your code, but if you do want to make it tougher for
casual users to access your files it is possible to take some basic steps.

Most of our examples use the
[ReadBinaryProto()](https://www.tensorflow.org/code/tensorflow/core/platform/env.cc?q=core/platform/env.cc&l=409) convenience
call to load a `GraphDef` from disk. This does require an unencrypted protobuf on
disk. Luckily though, the implementation of the call is pretty straightforward
and it should be easy to write an equivalent that can decrypt in memory. Here's
some code that shows how you can read and decrypt a protobuf using your own
decryption routine:

    Status ReadEncryptedProto(Env* env, const string& fname,
                              ::tensorflow::protobuf::MessageLite* proto) {
      string data;
      TF_RETURN_IF_ERROR(ReadFileToString(env, fname, &data));

      DecryptData(&data);  // Your own function here.

      if (!proto->ParseFromString(data)) {
        return errors::DataLoss("Can't parse ", fname, " as binary proto");
      }
      return Status::OK();
    }

To use this you’d need to define the `DecryptData()` function yourself. It could
be as simple as something like:

    void DecryptData(string* data) {
      for (size_t i = 0; i < data->size(); ++i) {
        (*data)[i] = (*data)[i] ^ 0x23;
      }
    }

You may want something more complex, but exactly what you’ll need is outside the
current scope here.
# Preparing models for mobile deployment

Warning: We expect to deprecate TensorFlow Mobile in early 2019

<div class="caution">
  <p>
    <a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
    working hard to close the feature gap between TensorFlow Mobile and
    TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
    will give ample notice to our users when we get to that point and will
    provide help and support to ensure easy migrations.
  </p>
  <p>
    In the meantime, please use TensorFlow Lite. If you have a feature request,
    such as a missing op, please post to our <a
    href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
  </p>
</div>

The requirements for storing model information during training are very
different from when you want to release it as part of a mobile app. This section
covers the tools involved in converting from a training model to something
releasable in production.

## What is up with all the different saved file formats?

You may find yourself getting very confused by all the different ways that
TensorFlow can save out graphs. To help, here’s a rundown of some of the
different components, and what they are used for. The objects are mostly defined
and serialized as protocol buffers:

- [NodeDef](https://www.tensorflow.org/code/tensorflow/core/framework/node_def.proto):
  Defines a single operation in a model. It has a unique name, a list of the
  names of other nodes it pulls inputs from, the operation type it implements
  (for example `Add`, or `Mul`), and any attributes that are needed to control
  that operation. This is the basic unit of computation for TensorFlow, and all
  work is done by iterating through a network of these nodes, applying each one
  in turn. One particular operation type that’s worth knowing about is `Const`,
  since this holds information about a constant. This may be a single, scalar
  number or string, but it can also hold an entire multi-dimensional tensor
  array. The values for a `Const` are stored inside the `NodeDef`, and so large
  constants can take up a lot of room when serialized.

- [Checkpoint](https://www.tensorflow.org/code/tensorflow/core/util/tensor_bundle/tensor_bundle.h). Another
  way of storing values for a model is by using `Variable` ops. Unlike `Const`
  ops, these don’t store their content as part of the `NodeDef`, so they take up
  very little space within the `GraphDef` file. Instead their values are held in
  RAM while a computation is running, and then saved out to disk as checkpoint
  files periodically. This typically happens as a neural network is being
  trained and weights are updated, so it’s a time-critical operation, and it may
  happen in a distributed fashion across many workers, so the file format has to
  be both fast and flexible. They are stored as multiple checkpoint files,
  together with metadata files that describe what’s contained within the
  checkpoints. When you’re referring to a checkpoint in the API (for example
  when passing a filename in as a command line argument), you’ll use the common
  prefix for a set of related files. If you had these files:

        /tmp/model/model-chkpt-1000.data-00000-of-00002
        /tmp/model/model-chkpt-1000.data-00001-of-00002
        /tmp/model/model-chkpt-1000.index
        /tmp/model/model-chkpt-1000.meta

  You would refer to them as `/tmp/model/model-chkpt-1000`.

- [GraphDef](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto):
  Has a list of `NodeDefs`, which together define the computational graph to
  execute. During training, some of these nodes will be `Variables`, and so if
  you want to have a complete graph you can run, including the weights, you’ll
  need to call a restore operation to pull those values from
  checkpoints. Because checkpoint loading has to be flexible to deal with all of
  the training requirements, this can be tricky to implement on mobile and
  embedded devices, especially those with no proper file system available like
  iOS. This is where the
  [`freeze_graph.py`](https://www.tensorflow.org/code/tensorflow/python/tools/freeze_graph.py) script
  comes in handy. As mentioned above, `Const` ops store their values as part of
  the `NodeDef`, so if all the `Variable` weights are converted to `Const` nodes,
  then we only need a single `GraphDef` file to hold the model architecture and
  the weights. Freezing the graph handles the process of loading the
  checkpoints, and then converts all Variables to Consts. You can then load the
  resulting file in a single call, without having to restore variable values
  from checkpoints. One thing to watch out for with `GraphDef` files is that
  sometimes they’re stored in text format for easy inspection. These versions
  usually have a ‘.pbtxt’ filename suffix, whereas the binary files end with
  ‘.pb’.

- [FunctionDefLibrary](https://www.tensorflow.org/code/tensorflow/core/framework/function.proto):
  This appears in `GraphDef`, and is effectively a set of sub-graphs, each with
  information about their input and output nodes. Each sub-graph can then be
  used as an op in the main graph, allowing easy instantiation of different
  nodes, in a similar way to how functions encapsulate code in other languages.

- [MetaGraphDef](https://www.tensorflow.org/code/tensorflow/core/protobuf/meta_graph.proto):
  A plain `GraphDef` only has information about the network of computations, but
  doesn’t have any extra information about the model or how it can be
  used. `MetaGraphDef` contains a `GraphDef` defining the computation part of
  the model, but also includes information like ‘signatures’, which are
  suggestions about which inputs and outputs you may want to call the model
  with, data on how and where any checkpoint files are saved, and convenience
  tags for grouping ops together for ease of use.

- [SavedModel](https://www.tensorflow.org/code/tensorflow/core/protobuf/saved_model.proto):
  It’s common to want to have different versions of a graph that rely on a
  common set of variable checkpoints. For example, you might need a GPU and a
  CPU version of the same graph, but keep the same weights for both. You might
  also need some extra files (like label names) as part of your model. The
  [SavedModel](https://www.tensorflow.org/code/tensorflow/python/saved_model/README.md) format
  addresses these needs by letting you save multiple versions of the same graph
  without duplicating variables, and also storing asset files in the same
  bundle. Under the hood, it uses `MetaGraphDef` and checkpoint files, along
  with extra metadata files. It’s the format that you’ll want to use if you’re
  deploying a web API using TensorFlow Serving, for example.

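If you need to open one of these bundles from C++, the loader in
`tensorflow/cc/saved_model` can be used for that. A minimal sketch; the export
directory below is just an example path:

```
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/tag_constants.h"

// Example path only; point this at your own SavedModel export directory.
tensorflow::SavedModelBundle bundle;
tensorflow::Status load_status = tensorflow::LoadSavedModel(
    tensorflow::SessionOptions(), tensorflow::RunOptions(),
    "/tmp/my_saved_model", {tensorflow::kSavedModelTagServe}, &bundle);
if (load_status.ok()) {
  // bundle.session can now be used like any other tensorflow::Session.
}
```
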
## How do you get a model you can use on mobile?

In most situations, training a model with TensorFlow will give you a folder
containing a `GraphDef` file (usually ending with the `.pb` or `.pbtxt` extension) and
a set of checkpoint files. What you need for mobile or embedded deployment is a
single `GraphDef` file that’s been ‘frozen’, or had its variables converted into
inline constants so everything’s in one file. To handle the conversion, you’ll
need the `freeze_graph.py` script that’s held in
[`tensorflow/python/tools/freeze_graph.py`](https://www.tensorflow.org/code/tensorflow/python/tools/freeze_graph.py). You’ll run it like this:

    bazel build tensorflow/python/tools:freeze_graph
    bazel-bin/tensorflow/python/tools/freeze_graph \
    --input_graph=/tmp/model/my_graph.pb \
    --input_checkpoint=/tmp/model/model.ckpt-1000 \
    --output_graph=/tmp/frozen_graph.pb \
    --output_node_names=output_node

The `input_graph` argument should point to the `GraphDef` file that holds your
model architecture. It’s possible that your `GraphDef` has been stored in a text
format on disk, in which case it’s likely to end in `.pbtxt` instead of `.pb`,
and you should add an extra `--input_binary=false` flag to the command.

The `input_checkpoint` should be the most recent saved checkpoint. As mentioned
in the checkpoint section, you need to give the common prefix to the set of
checkpoints here, rather than a full filename.

`output_graph` defines where the resulting frozen `GraphDef` will be
saved. Because it’s likely to contain a lot of weight values that take up a
large amount of space in text format, it’s always saved as a binary protobuf.

`output_node_names` is a list of the names of the nodes that you want to extract
the results of your graph from. This is needed because the freezing process
needs to understand which parts of the graph are actually needed, and which are
artifacts of the training process, like summarization ops. Only ops that
contribute to calculating the given output nodes will be kept. If you know how
your graph is going to be used, these should just be the names of the nodes you
pass into `Session::Run()` as your fetch targets. The easiest way to find the
node names is to inspect the Node objects while building your graph in Python.
Inspecting your graph in TensorBoard is another simple way. You can get some
suggestions on likely outputs by running the [`summarize_graph` tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms/README.md#inspecting-graphs).

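If all you have is a `GraphDef` on disk, you can also dump the node names from
C++ rather than Python. A minimal sketch; the file path below is just the
example path used earlier:

```
#include <iostream>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/platform/env.h"

// A minimal sketch: read a binary GraphDef and print every node's name and
// op type, which is often enough to spot likely input and output nodes.
tensorflow::GraphDef graph_def;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
                                        "/tmp/model/my_graph.pb", &graph_def));
for (const tensorflow::NodeDef& node : graph_def.node()) {
  std::cout << node.name() << " (" << node.op() << ")\n";
}
```
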
Because the output format for TensorFlow has changed over time, there are a
variety of other less commonly used flags available too, like `input_saver`, but
hopefully you shouldn’t need these on graphs trained with modern versions of the
framework.

## Using the Graph Transform Tool

A lot of the things you need to do to efficiently run a model on device are
available through the [Graph Transform
Tool](https://www.tensorflow.org/code/tensorflow/tools/graph_transforms/README.md). This
command-line tool takes an input `GraphDef` file, applies the set of rewriting
rules you request, and then writes out the result as a `GraphDef`. See the
documentation for more information on how to build and run this tool.

### Removing training-only nodes

TensorFlow `GraphDefs` produced by the training code contain all of the
computation that’s needed for back-propagation and updates of weights, as well
as the queuing and decoding of inputs, and the saving out of checkpoints. All of
these nodes are no longer needed during inference, and some of the operations
like checkpoint saving aren’t even supported on mobile platforms. To create a
model file that you can load on devices you need to delete those unneeded
operations by running the `strip_unused_nodes` rule in the Graph Transform Tool.

The trickiest part of this process is figuring out the names of the nodes you
want to use as inputs and outputs during inference. You'll need these anyway
once you start to run inference, but you also need them here so that the
transform can calculate which nodes are not needed on the inference-only
path. These may not be obvious from the training code. The easiest way to
determine the node name is to explore the graph with TensorBoard.

Remember that mobile applications typically gather their data from sensors and
have it as arrays in memory, whereas training typically involves loading and
decoding representations of the data stored on disk. In the case of Inception v3
for example, there’s a `DecodeJpeg` op at the start of the graph that’s designed
to take JPEG-encoded data from a file retrieved from disk and turn it into an
arbitrary-sized image. After that there’s a `BilinearResize` op to scale it to
the expected size, followed by a couple of other ops that convert the byte data
into float and scale the value magnitudes in the way the rest of the graph
expects. A typical mobile app will skip most of these steps because it’s getting
its input directly from a live camera, so the input node you will actually
supply will be the output of the `Mul` node in this case.

<img src="../images/inception_input.png" width="300">

You’ll need to do a similar process of inspection to figure out the correct
output nodes.

If you’ve just been given a frozen `GraphDef` file, and are not sure about the
contents, try using the `summarize_graph` tool to print out information
about the inputs and outputs it finds from the graph structure. Here’s an
example with the original Inception v3 file:

    bazel run tensorflow/tools/graph_transforms:summarize_graph --
    --in_graph=tensorflow_inception_graph.pb

Once you have an idea of what the input and output nodes are, you can feed them
into the graph transform tool as the `--input_names` and `--output_names`
arguments, and call the `strip_unused_nodes` transform, like this:

    bazel run tensorflow/tools/graph_transforms:transform_graph --
    --in_graph=tensorflow_inception_graph.pb
    --out_graph=optimized_inception_graph.pb --inputs='Mul' --outputs='softmax'
    --transforms='
      strip_unused_nodes(type=float, shape="1,299,299,3")
      fold_constants(ignore_errors=true)
      fold_batch_norms
      fold_old_batch_norms'

One thing to look out for here is that you need to specify the size and type
that you want your inputs to be. This is because any values that you’re going to
be passing in as inputs to inference need to be fed to special `Placeholder` op
nodes, and the transform may need to create them if they don’t already exist. In
the case of Inception v3 for example, a `Placeholder` node replaces the old
`Mul` node that used to output the resized and rescaled image array, since we’re
going to be doing that processing ourselves before we call TensorFlow. It keeps
the original name though, which is why we always feed in inputs to `Mul` when we
run a session with our modified Inception graph.

After you’ve run this process, you’ll have a graph that only contains the actual
nodes you need to run your prediction process. This is the point where it
becomes useful to run metrics on the graph, so it’s worth running
`summarize_graph` again to understand what’s in your model.

## What ops should you include on mobile?

There are hundreds of operations available in TensorFlow, and each one has
multiple implementations for different data types. On mobile platforms, the size
of the executable binary that’s produced after compilation is important, because
app download bundles need to be as small as possible for the best user
experience. If all of the ops and data types are compiled into the TensorFlow
library then the total size of the compiled library can be tens of megabytes, so
by default only a subset of ops and data types are included.

That means that if you load a model file that’s been trained on a desktop
machine, you may see the error “No OpKernel was registered to support Op” when
you load it on mobile. The first thing to try is to make sure you’ve stripped
out any training-only nodes, since the error will occur at load time even if the
op is never executed. If you’re still hitting the same problem once that’s done,
you’ll need to look at adding the op to your built library.

The criteria for including ops and types fall into several categories:

- Are they only useful in back-propagation, for gradients? Since mobile is
  focused on inference, we don’t include these.

- Are they useful mainly for other training needs, such as checkpoint saving?
  These we leave out.

- Do they rely on frameworks that aren’t always available on mobile, such as
  libjpeg? To avoid extra dependencies we don’t include ops like `DecodeJpeg`.

- Are there types that aren’t commonly used? We don’t include boolean variants
  of ops for example, since we don’t see much use of them in typical inference
  graphs.

These ops are trimmed by default to optimize for inference on mobile, but it is
possible to alter some build files to change the default. After altering the
build files, you will need to recompile TensorFlow. See below for more details
on how to do this, and also see <a href="./optimizing.md">optimizing binary size</a>
for more on reducing your binary size.

### Locate the implementation

Operations are broken into two parts. The first is the op definition, which
declares the signature of the operation, which inputs, outputs, and attributes
it has. These take up very little space, and so all are included by default. The
implementations of the op computations are done in kernels, which live in the
`tensorflow/core/kernels` folder. You need to compile the C++ file containing
the kernel implementation of the op you need into the library. To figure out
which file that is, you can search for the operation name in the source
files.

[Here’s an example search in GitHub](https://github.com/search?utf8=%E2%9C%93&q=repo%3Atensorflow%2Ftensorflow+extension%3Acc+path%3Atensorflow%2Fcore%2Fkernels+REGISTER+Mul&type=Code&ref=searchresults).

You’ll see that this search is looking for the `Mul` op implementation, and it
finds it in `tensorflow/core/kernels/cwise_op_mul_1.cc`. You need to look for
macros beginning with `REGISTER`, with the op name you care about as one of the
string arguments.

In this case, the implementations are actually broken up across multiple `.cc`
files, so you’d need to include all of them in your build. If you’re more
comfortable using the command line for code search, here’s a grep command that
also locates the right files if you run it from the root of your TensorFlow
repository:

`grep 'REGISTER.*"Mul"' tensorflow/core/kernels/*.cc`

### Add the implementation to the build

If you’re using Bazel, and building for Android, you’ll want to add the files
you’ve found to the
[`android_extended_ops_group1`](https://www.tensorflow.org/code/tensorflow/core/kernels/BUILD#L3565) or
[`android_extended_ops_group2`](https://www.tensorflow.org/code/tensorflow/core/kernels/BUILD#L3632) targets. You
may also need to include any .cc files they depend on in there. If the build
complains about missing header files, add the .h’s that are needed into the
[`android_extended_ops`](https://www.tensorflow.org/code/tensorflow/core/kernels/BUILD#L3525) target.

If you’re using a makefile targeting iOS, Raspberry Pi, etc., go to
[`tensorflow/contrib/makefile/tf_op_files.txt`](https://www.tensorflow.org/code/tensorflow/contrib/makefile/tf_op_files.txt) and
add the right implementation files there.