Merge pull request #26326 from tensorflow/2.0-ff

2.0 ff: Move r2.0 branch ahead to pick up test and build fixes.
Gunhan Gulsoy 2019-03-04 10:26:28 -08:00 committed by GitHub
commit bdecee4c43
140 changed files with 3549 additions and 3303 deletions

View File

@ -39,14 +39,19 @@ filegroup(
"python_api.h",
"*test*",
],
),
) + [
"//tensorflow/cc:srcs",
"//tensorflow/core/distributed_runtime:server_lib.h",
],
visibility = ["//visibility:public"],
)
tf_cuda_library(
name = "c_api_internal",
srcs = ["c_api.h"],
hdrs = ["c_api_internal.h"],
hdrs = [
"c_api.h",
"c_api_internal.h",
],
visibility = [
"//tensorflow:internal",
"//tensorflow/c:__subpackages__",
@ -68,7 +73,9 @@ tf_cuda_library(
tf_cuda_library(
name = "c_api",
hdrs = ["c_api.h"],
hdrs = [
"c_api.h",
],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
@ -89,9 +96,7 @@ tf_cuda_library(
"c_api.cc",
"c_api_function.cc",
],
hdrs = [
"c_api.h",
],
hdrs = ["c_api.h"],
copts = tf_copts(),
visibility = ["//tensorflow/c:__subpackages__"],
deps = [":c_api_internal"] + select({

View File

@ -8,6 +8,19 @@ package(
licenses(["notice"]) # Apache 2.0
filegroup(
name = "srcs",
srcs = [
"framework/gradients.h",
"framework/ops.h",
"framework/scope.h",
"framework/scope_internal.h",
"ops/array_ops.h",
"ops/while_loop.h",
"//tensorflow/cc/saved_model:loader.h",
],
)
load(
"//tensorflow:tensorflow.bzl",
"cc_library_with_android_deps",
@ -606,16 +619,13 @@ tf_gen_op_wrappers_cc(
visibility = ["//tensorflow:internal"],
)
cc_library_with_android_deps(
cc_library(
name = "cc_op_gen_main",
srcs = [
"framework/cc_op_gen.cc",
"framework/cc_op_gen.h",
"framework/cc_op_gen_main.cc",
],
android_deps = [
"//tensorflow/core:android_tensorflow_lib",
],
copts = tf_copts(),
data = [
"//tensorflow/core/api_def:base_api_def",

View File

@ -1432,6 +1432,30 @@ Status CheckInputsWeights(
return Status::OK();
}
Status AllowDataTypes(const OpConverterParams& params,
const std::set<DataType>& allowed_dtypes) {
const auto& node_def = params.node_def;
TFAttrs attrs(params.node_def);
if (attrs.count("T")) {
const auto op_dtype = attrs.get<DataType>("T");
if (!allowed_dtypes.count(op_dtype)) {
// Build string list of allowed types.
std::ostringstream ss;
for (auto it = allowed_dtypes.begin(); it != allowed_dtypes.end(); ++it) {
if (it != allowed_dtypes.begin()) ss << ", ";
ss << DataTypeString(*it);
}
return errors::Unimplemented("Data type ", DataTypeString(op_dtype),
" is not supported for ", node_def.op(),
", must be one of [", ss.str(), "], at ",
node_def.name());
}
}
// If there is no T attribute, we can't determine the type of the op. We will
// allow it to convert for now.
return Status::OK();
}
TRT_ShapedWeights ConvertFP32ToFP16(TrtWeightStore* store,
const TRT_ShapedWeights& weights_src) {
auto dtype_new = DataType::DT_HALF;
@ -1734,6 +1758,8 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
CheckInputsWeights(*params, {{"input", false}, {"filter", true}}));
tensor = inputs.at(0).tensor();
}
TF_RETURN_IF_ERROR(
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
TRT_ShapedWeights weights_rsck = inputs.at(1).weights();
if (weights_rsck.shape_.nbDims != 4) {
return errors::InvalidArgument("Conv2D expects kernel of dimension 4, at " +
@ -1996,6 +2022,8 @@ Status ConvertTranspose(OpConverterParams* params) {
const auto& inputs = params->inputs;
TF_RETURN_IF_ERROR(
CheckInputsWeights(*params, {{"x", false}, {"perm", true}}));
TF_RETURN_IF_ERROR(AllowDataTypes(
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
// Get the permutation from weights.
TRT_ShapedWeights weights = inputs.at(1).weights();
const int* weights_ptr =
@ -2030,6 +2058,8 @@ Status ConvertReshape(OpConverterParams* params) {
const auto& node_def = params->node_def;
TF_RETURN_IF_ERROR(
CheckInputsWeights(*params, {{"tensor", false}, {"shape", true}}));
TF_RETURN_IF_ERROR(AllowDataTypes(
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
TRT_TensorOrWeights input_tensor = inputs.at(0);
TRT_ShapedWeights weights = inputs.at(1).weights();
if (weights.count() == 0) {
@ -2127,6 +2157,8 @@ Status ConvertExpandDims(OpConverterParams* params) {
const auto& node_def = params->node_def;
TF_RETURN_IF_ERROR(
CheckInputsWeights(*params, {{"input", false}, {"axis", true}}));
TF_RETURN_IF_ERROR(AllowDataTypes(
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
// Get input shape as vector.
TRT_TensorOrWeights input_tensor = inputs.at(0);
const nvinfer1::Dims dims = input_tensor.GetTrtDims();
@ -2177,6 +2209,8 @@ Status ConvertSqueeze(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
TF_RETURN_IF_ERROR(AllowDataTypes(
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
// Get input shape.
TRT_TensorOrWeights input_tensor = inputs.at(0);
const nvinfer1::Dims dims = input_tensor.GetTrtDims();
@ -2439,6 +2473,8 @@ Status ConvertSlice(OpConverterParams* params) {
const auto& node_def = params->node_def;
TF_RETURN_IF_ERROR(CheckInputsWeights(
*params, {{"input", false}, {"begin", true}, {"size", true}}));
TF_RETURN_IF_ERROR(AllowDataTypes(
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
std::vector<int> begin = inputs.at(1).weights().ToVector<int>();
std::vector<int> size = inputs.at(2).weights().ToVector<int>();
// Get input dims.
@ -2483,6 +2519,8 @@ Status ConvertStridedSlice(OpConverterParams* params) {
TF_RETURN_IF_ERROR(CheckInputsWeights(
*params,
{{"input", false}, {"begin", true}, {"end", true}, {"strides", true}}));
TF_RETURN_IF_ERROR(AllowDataTypes(
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
// Get input dims.
nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
std::vector<int> input_dims(dims.d, dims.d + dims.nbDims);
@ -2578,6 +2616,8 @@ Status ConvertPool(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
TF_RETURN_IF_ERROR(
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
nvinfer1::PoolingType type;
if (node_def.op() == "MaxPool") {
type = nvinfer1::PoolingType::kMAX;
@ -2669,15 +2709,10 @@ Status ConvertPool(OpConverterParams* params) {
Status ConvertLeakyRelu(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
if (inputs.size() != 1) {
return errors::InvalidArgument(node_def.op(), " expects one input, at ",
node_def.name());
}
if (!inputs.at(0).is_tensor()) {
return errors::Unimplemented(node_def.op(),
" is only implemented for tensors, at ",
node_def.name());
}
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
TF_RETURN_IF_ERROR(
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
TFAttrs attrs(node_def);
const float alpha = attrs.get<float>("alpha");
if (alpha < 0.0f || alpha > 1.0f) {
@ -2719,6 +2754,8 @@ Status ConvertActivation(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
TF_RETURN_IF_ERROR(
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
static const std::unordered_map<string, nvinfer1::ActivationType> ops{
{"Relu", nvinfer1::ActivationType::kRELU},
{"Sigmoid", nvinfer1::ActivationType::kSIGMOID},
@ -2815,6 +2852,8 @@ Status ConvertRelu6(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
TF_RETURN_IF_ERROR(
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
if (params->validation_only) return Status::OK();
// ***************************************************************************
// TensorRT does not implement Relu6 natively. This function converts Relu6 op
@ -2864,18 +2903,14 @@ Status ConvertBiasAdd(OpConverterParams* params) {
const auto& node_def = params->node_def;
TF_RETURN_IF_ERROR(
CheckInputsWeights(*params, {{"value", false}, {"bias", true}}));
TFAttrs attrs(node_def);
DataType tf_dtype = attrs.get<DataType>("T");
if (tf_dtype != DataType::DT_FLOAT && tf_dtype != DataType::DT_HALF) {
return errors::Unimplemented("Data type is not supported, for node ",
node_def.name(), " got ",
DataTypeString(tf_dtype));
}
TF_RETURN_IF_ERROR(
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
if (params->validation_only) return Status::OK();
nvinfer1::ITensor* tensor =
const_cast<nvinfer1::ITensor*>(inputs.at(0).tensor());
const nvinfer1::Dims original_dims = tensor->getDimensions();
TFAttrs attrs(node_def);
const string data_format = attrs.get<string>("data_format");
const int channel_index =
(data_format == "NHWC" ? original_dims.nbDims - 1 : 0);
@ -3092,6 +3127,8 @@ Status ConvertBinary(OpConverterParams* params) {
" inputs but expected 2, at ",
node_def.name());
}
TF_RETURN_IF_ERROR(
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
// Constant folding should have been done by TensorFlow
if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) {
@ -3133,6 +3170,8 @@ Status ConvertRsqrt(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}}));
TF_RETURN_IF_ERROR(
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
if (params->validation_only) return Status::OK();
// TODO(tmorris): params->converter is null during validation. Allow
@ -3199,6 +3238,8 @@ Status ConvertUnary(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}}));
TF_RETURN_IF_ERROR(
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
auto op_pair = UnaryOperationMap()->find(node_def.op());
if (op_pair == UnaryOperationMap()->end()) {
return errors::Unimplemented("Unary op: ", node_def.op(),
@ -3236,27 +3277,21 @@ Status ConvertSquare(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}}));
TF_RETURN_IF_ERROR(
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
if (params->validation_only) return Status::OK();
// Constant 2 with same rank as input
nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
for (int i = 0; i < dims.nbDims; i++) {
dims.d[i] = 1;
}
TRT_ShapedWeights weights =
params->weight_store->GetTempWeights(DataType::DT_FLOAT, dims);
auto weights_ptr =
static_cast<float*>(const_cast<void*>(weights.GetValues()));
weights_ptr[0] = 2.f;
nvinfer1::ITensor* const2_tensor =
params->converter->CreateConstantLayer(weights, dims);
TFTRT_RETURN_ERROR_IF_NULLPTR(const2_tensor, node_def.name());
const nvinfer1::ITensor* const2_tensor = nullptr;
TF_RETURN_IF_ERROR(CreateBroadcastableScalarConstant(
params, 2.0f, inputs.at(0).GetTrtDims(), &const2_tensor));
// ElementWise Pow Operation
nvinfer1::IElementWiseLayer* layer =
params->converter->network()->addElementWise(
*const_cast<nvinfer1::ITensor*>(inputs.at(0).tensor()),
*const2_tensor, nvinfer1::ElementWiseOperation::kPOW);
*const_cast<nvinfer1::ITensor*>(const2_tensor),
nvinfer1::ElementWiseOperation::kPOW);
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
@ -3269,6 +3304,8 @@ Status ConvertReduce(OpConverterParams* params) {
const auto& node_def = params->node_def;
TF_RETURN_IF_ERROR(
CheckInputsWeights(*params, {{"input", false}, {"axis", true}}));
TF_RETURN_IF_ERROR(
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
TRT_ShapedWeights index_list = inputs.at(1).weights();
@ -3330,6 +3367,8 @@ Status ConvertPad(OpConverterParams* params) {
const auto& node_def = params->node_def;
TF_RETURN_IF_ERROR(
CheckInputsWeights(*params, {{"tensor", false}, {"paddings", true}}));
TF_RETURN_IF_ERROR(
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
// Implement tensor binaryOp weight [channel wise] for now;
const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
@ -3432,6 +3471,10 @@ Status ConvertPad(OpConverterParams* params) {
Status ConvertConcat(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
// TODO(tmorris): There is a bug with Concat and INT32 in TRT - it is supposed
// to be supported.
TF_RETURN_IF_ERROR(
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
// not including the last input (axis) here
int input_size = static_cast<int>(inputs.size()) - 1;
@ -3514,6 +3557,8 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) {
{"offset", true},
{"mean", true},
{"variance", true}}));
TF_RETURN_IF_ERROR(
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
TFAttrs attrs(node_def);
float epsilon = attrs.get<float>("epsilon");
auto data_format = attrs.get<string>("data_format");
@ -3645,6 +3690,8 @@ Status ConvertGather(OpConverterParams* params) {
const auto& node_def = params->node_def;
TF_RETURN_IF_ERROR(CheckInputsWeights(
*params, {{"params", false}, {"indices", false}, {"axis", true}}));
TF_RETURN_IF_ERROR(AllowDataTypes(
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
absl::Span<const int> axis = inputs.at(2).weights().GetSpan<int>();
if (axis.size() != 1) {
return errors::InvalidArgument("Axis for GatherV2 must be a scalar, at ",
@ -3714,14 +3761,10 @@ Status ConvertMatMul(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"a", false}, {"b", true}}));
TF_RETURN_IF_ERROR(
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
TFAttrs attrs(node_def);
DataType tf_dtype = attrs.get<DataType>("T");
if (tf_dtype != DataType::DT_FLOAT && tf_dtype != DataType::DT_HALF) {
return errors::Unimplemented("Data type is not supported, for node ",
node_def.name(), " got ",
DataTypeString(tf_dtype));
}
bool transpose_a = attrs.get<bool>("transpose_a");
bool transpose_b = attrs.get<bool>("transpose_b");
@ -3742,6 +3785,8 @@ Status ConvertBatchMatMul(OpConverterParams* params) {
// TODO(tmorris): Enable once false is updated to mean either tensor or weight
// TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y",
// false}}));
TF_RETURN_IF_ERROR(
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
if (inputs.size() != 2) {
return errors::InvalidArgument(node_def.op(), " got ", inputs.size(),
" inputs but expected 2, at ",
@ -3752,14 +3797,6 @@ Status ConvertBatchMatMul(OpConverterParams* params) {
"All inputs are weights, but Grappler is expected to fold them.");
}
TFAttrs attrs(node_def);
const DataType tf_dtype = attrs.get<DataType>("T");
if (tf_dtype != DataType::DT_FLOAT && tf_dtype != DataType::DT_HALF) {
return errors::Unimplemented("data type is not supported, for node ",
node_def.name(),
" got " + DataTypeString(tf_dtype));
}
const bool transpose_a = attrs.get<bool>("adj_x");
const bool transpose_b = attrs.get<bool>("adj_y");
const auto dims = inputs.at(0).GetTrtDims();
@ -3815,6 +3852,8 @@ Status ConvertSoftmax(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"logits", false}}));
TF_RETURN_IF_ERROR(
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
int nbDims = tensor->getDimensions().nbDims;
@ -3840,15 +3879,11 @@ Status ConvertSoftmax(OpConverterParams* params) {
Status ConvertTopK(OpConverterParams* params) {
const auto& inputs = params->inputs;
if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
!inputs.at(1).is_weights()) {
return errors::InvalidArgument("Input expects tensor and weights, at ",
params->node_def.name());
}
const auto& node_def = params->node_def;
TF_RETURN_IF_ERROR(
CheckInputsWeights(*params, {{"input", false}, {"k", true}}));
TF_RETURN_IF_ERROR(AllowDataTypes(
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
const int num_dims = tensor->getDimensions().nbDims;
if (num_dims == 0) {

View File

@ -1571,9 +1571,9 @@ TEST_F(OpConverterTest, ConvertMatMul) {
NodeDef node_def = get_matmul_nodedef(DT_INT32, false, false);
AddTestTensor("input", {2}, /*batch_size=*/1, nvinfer1::DataType::kINT32);
AddTestWeights<int32>("weights", {2, 1}, {3, 5});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
"Data type is not supported, for node my_matmul got int32");
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
"Data type int32 is not supported for MatMul, "
"must be one of [float, half], at my_matmul");
}
// transpose_a is set.
for (bool transpose_b : {false, true}) {
@ -3328,8 +3328,9 @@ TEST_F(OpConverterTest, ConvertTopK) {
{
// Input list is empty, should fail.
NodeDef node_def = MakeNodeDef("my_topk", "TopKV2", {});
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
"Input expects tensor and weights, at my_topk");
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"TopKV2 got 0 inputs but expected 2, at my_topk");
}
for (const auto dtype : {DT_FLOAT, DT_INT32}) {
@ -3346,8 +3347,8 @@ TEST_F(OpConverterTest, ConvertTopK) {
/*trt_dtype=*/TfDataTypeToTrt(dtype));
AddTestTensor("weights", {2});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"Input expects tensor and weights, at my_topk");
node_def, error::UNIMPLEMENTED,
"The input \"k\" for TopKV2 must be a constant, at my_topk");
}
{
// Ok.

View File

@ -79,7 +79,10 @@ static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
XLA_MAKE_BINARY(DivNoNan,
DivNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper));
// Implementation of FloorDiv. Pseudo-code:
// Implementation of FloorDiv.
//
// For floating-point values, simply returns floor(x / y). For integers, does:
//
// if ((x < 0) != (y < 0)) {
// T abs_x = std::abs(x);
// T abs_y = std::abs(y);
@ -90,6 +93,9 @@ XLA_MAKE_BINARY(DivNoNan,
static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
xla::XlaOp y, const BCast& broadcast_helper) {
std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
if (DataTypeIsFloating(dtype)) {
return xla::Floor(xla::Div(x, y));
}
if (DataTypeIsUnsigned(dtype)) {
return xla::Div(x, y);
}
@ -99,11 +105,7 @@ static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
auto abs_x = xla::Abs(x);
auto abs_y = xla::Abs(y);
auto t = xla::Neg(xla::Sub(xla::Add(abs_x, abs_y), one));
auto result = xla::Select(different_sign, xla::Div(t, abs_y), xla::Div(x, y));
if (DataTypeIsFloating(dtype)) {
result = xla::Floor(result);
}
return result;
return xla::Select(different_sign, xla::Div(t, abs_y), xla::Div(x, y));
}
XLA_MAKE_BINARY(FloorDiv,
FloorDivImpl(b, input_type(0), lhs, rhs, broadcast_helper));
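
The integer branch above builds floor division on top of C++'s truncating division: when the operand signs differ, it divides -(|x| + |y| - 1) by |y|, which pushes the truncated quotient down to the floor. A minimal standalone sketch of the same arithmetic (the FloorDivSketch name is hypothetical; plain built-in integers stand in for the XLA ops):

#include <cstdint>
#include <cstdlib>

// Floor division mirroring FloorDivImpl's integer path (illustrative only;
// assumes y != 0 and that |x| + |y| does not overflow).
int64_t FloorDivSketch(int64_t x, int64_t y) {
  if ((x < 0) != (y < 0)) {
    const int64_t abs_x = std::llabs(x);
    const int64_t abs_y = std::llabs(y);
    const int64_t t = -(abs_x + abs_y - 1);
    return t / abs_y;  // e.g. x = -7, y = 2: t = -8, -8 / 2 = -4 == floor(-3.5)
  }
  return x / y;  // same signs: truncation already equals the floor
}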

View File

@ -43,7 +43,9 @@ namespace {
// A * H or H * A zeros out trailing part of some row or column of A.
//
// [x0, ..., x_{k-1}, xk, x_{k+1}, ..., x_{n-1}] * H
// = [x0, ..., x_{k-1}, vnorm, 0, ..., 0]
// = [x0, ..., x_{k-1}, xnorm, 0, ..., 0]
//
// Here xnorm = norm([x_k, x_{k+1}, ..., x_{n - 1}])
struct HouseHolderResult {
XlaOp v;
XlaOp beta;
@ -82,7 +84,7 @@ struct FrobeniusNorms {
//
// H = I - beta * [1, v]' * [1, v]
//
// H * x = [..., sigma, 0, ..., 0]
// H * x = [..., xnorm, 0, ..., 0]
// ..., j, j + 1, ..., n
//
// def house(x, j, eps):
@ -161,21 +163,10 @@ StatusOr<HouseHolderResult> HouseRow(XlaOp a, XlaOp i, XlaOp j, XlaOp eps,
HouseHolderResult result;
result.v = v;
result.beta = beta;
a = Sub(a, Mul(beta, BatchDot(BatchDot(a, TransposeInMinorDims(v), precision),
result.a =
Sub(a, Mul(beta, BatchDot(BatchDot(a, TransposeInMinorDims(v), precision),
v, precision)));
auto xnorm =
Sqrt(Reduce(Square(Select(Ge(idx, j), x, zeros)), ScalarLike(x, 0.0),
CreateScalarAddComputation(x_shape.element_type(), builder),
{num_dims - 1}));
xnorm = BroadcastInDim(xnorm, x_shape.dimensions(), broadcast_dims);
x = Select(Lt(idx, j), x, zeros);
x = Select(Eq(idx, j), xnorm, x);
result.a = DynamicUpdateSliceInMinorDims(a, x, {i, zero});
return result;
}
@ -184,7 +175,7 @@ StatusOr<HouseHolderResult> HouseRow(XlaOp a, XlaOp i, XlaOp j, XlaOp eps,
//
// H = I - beta * [1; v] * [1; v]', then,
//
// H * A[i:, j] = [sigma, 0, 0, ..., 0]
// H * A[i:, j] = [xnorm, 0, 0, ..., 0]
//
StatusOr<HouseHolderResult> HouseCol(XlaOp a, XlaOp i, XlaOp j, XlaOp eps,
PrecisionConfig::Precision precision) {
@ -239,21 +230,9 @@ StatusOr<HouseHolderResult> HouseCol(XlaOp a, XlaOp i, XlaOp j, XlaOp eps,
HouseHolderResult result;
result.v = v;
result.beta = beta;
a = Sub(a,
Mul(beta, BatchDot(v, BatchDot(TransposeInMinorDims(v), a, precision),
precision)));
auto xnorm =
Sqrt(Reduce(Square(Select(Ge(idx, i), x, zeros)), ScalarLike(x, 0.0),
CreateScalarAddComputation(x_shape.element_type(), builder),
{num_dims - 2}));
xnorm = BroadcastInDim(xnorm, x_shape.dimensions(), broadcast_dims);
x = Select(Lt(idx, i), x, zeros);
x = Select(Eq(idx, i), xnorm, x);
result.a = DynamicUpdateSliceInMinorDims(a, x, {zero, j});
result.a = Sub(
a, Mul(beta, BatchDot(v, BatchDot(TransposeInMinorDims(v), a, precision),
precision)));
return result;
}
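
Both HouseRow and HouseCol build the reflector the comments describe: a vector v and a scalar beta such that (I - beta * v * v') maps x to [xnorm, 0, ..., 0] past the pivot. A self-contained sketch of that construction in the standard textbook form (the HouseholderVector name is hypothetical; plain double vectors stand in for the batched XlaOps, and the eps cutoff is omitted):

#include <cmath>
#include <vector>

// Builds v and beta so that (I - beta * v * v^T) * x == [alpha, 0, ..., 0],
// where |alpha| == ||x||.
void HouseholderVector(const std::vector<double>& x, std::vector<double>* v,
                       double* beta) {
  double norm = 0.0;
  for (double xi : x) norm += xi * xi;
  norm = std::sqrt(norm);
  // Pick alpha with the sign opposite to x[0] to avoid cancellation in v[0].
  const double alpha = (x[0] >= 0.0) ? -norm : norm;
  *v = x;
  (*v)[0] -= alpha;
  double vtv = 0.0;
  for (double vi : *v) vtv += vi * vi;
  *beta = (vtv > 0.0) ? 2.0 / vtv : 0.0;
}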
@ -774,7 +753,7 @@ StatusOr<SVDResult> SortBySingularValuesAndPostProcessing(SVDResult result) {
auto zero = Zero(builder, S32);
// As m >= n, only first m columns vectors are needed to be permuted, and the
// rest of n - m vectors are appended after the sorting is done.
// rest of m - n vectors are appended after the sorting is done.
XlaOp sort_u_result =
Sort({-d, DynamicSliceInMinorDims(result.u, {zero, zero}, {m, n})},
CreateScalarLtComputation(
@ -799,7 +778,7 @@ StatusOr<SVDResult> SortBySingularValuesAndPostProcessing(SVDResult result) {
{num_dims - 2})),
broadcast_dims);
// Append the rest of n - m vectors.
// Append the rest of m - n vectors.
result.u =
ConcatInDim(builder,
{GetTupleElement(sort_u_result, 1),

View File

@ -131,6 +131,24 @@ class SVDTest : public ClientLibraryTestBase {
Array3D<float> batch_3d_4x5_;
};
XLA_TEST_F(SVDTest, Simple2D) {
XlaBuilder builder(TestName());
Array2D<float> simple_2d_4x4_ = Array2D<float>{
{4, 6, 8, 10},
{6, 45, 54, 63},
{8, 54, 146, 166},
{10, 63, 166, 310},
};
XlaOp a;
auto a_data = CreateR2Parameter<float>(simple_2d_4x4_, 0, "a", &builder, &a);
auto result = SVD(a, 100, 1e-6);
ComputeMatmulUDVT(result, &builder);
ComputeAndCompareR2<float>(&builder, simple_2d_4x4_, {a_data.get()},
ErrorSpec(1e-3, 1e-3));
}
XLA_TEST_F(SVDTest, Test_VWVt_EQ_A_2x4x5) {
XlaBuilder builder(TestName());

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_PROTOBUF_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_PROTOBUF_UTIL_H_
#include "google/protobuf/duration.pb.h"
#include "absl/time/time.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@ -44,6 +45,20 @@ Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
// dirpath along as-is.
void RegisterDirectoryExpander(const std::function<string(string)>& expander);
// Converts an absl::Duration to a google::protobuf::Duration.
inline google::protobuf::Duration ToDurationProto(absl::Duration duration) {
google::protobuf::Duration proto;
proto.set_seconds(absl::IDivDuration(duration, absl::Seconds(1), &duration));
proto.set_nanos(
absl::IDivDuration(duration, absl::Nanoseconds(1), &duration));
return proto;
}
// Converts a google::protobuf::Duration to an absl::Duration.
inline absl::Duration FromDurationProto(google::protobuf::Duration proto) {
return absl::Seconds(proto.seconds()) + absl::Nanoseconds(proto.nanos());
}
} // namespace protobuf_util
} // namespace xla
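
The two inline helpers split an absl::Duration into whole seconds plus a nanosecond remainder and reassemble it on the way back. A short usage sketch (the DurationRoundTrip function is hypothetical, and the include path is assumed to be tensorflow/compiler/xla/protobuf_util.h based on the header guard):

#include "absl/time/time.h"
#include "google/protobuf/duration.pb.h"

#include "tensorflow/compiler/xla/protobuf_util.h"

void DurationRoundTrip() {
  // 1.5 seconds becomes {seconds: 1, nanos: 500000000} and converts back to
  // the same absl::Duration.
  const absl::Duration d = absl::Milliseconds(1500);
  const google::protobuf::Duration proto =
      xla::protobuf_util::ToDurationProto(d);
  const absl::Duration back = xla::protobuf_util::FromDurationProto(proto);
  (void)back;  // back == d
}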

File diff suppressed because it is too large.

View File

@ -440,14 +440,15 @@ cc_library(
srcs = ["cudnn_conv_algorithm_picker.cc"],
hdrs = ["cudnn_conv_algorithm_picker.h"],
deps = [
":autotuning_proto",
":backend_configs",
":buffer_comparator",
":cudnn_conv_runner",
":gpu_autotuning_proto",
":gpu_executable",
":ir_emission_utils",
":scratch_allocator",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:hlo",
@ -455,9 +456,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
"//tensorflow/core:logger",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/util/proto:proto_utils",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/time",
@ -777,7 +776,6 @@ cc_library(
hdrs = ["gpu_transfer_manager.h"],
deps = [
":gpu_compiler",
":infeed_manager",
":outfeed_manager",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
@ -790,6 +788,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:generic_transfer_manager",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/compiler/xla/service/gpu:infeed_manager",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/memory",
@ -1138,8 +1137,8 @@ tf_cc_test(
srcs = ["cudnn_fused_conv_rewriter_test.cc"],
tags = tf_cuda_tests_tags(),
deps = [
":ir_emission_utils",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
"//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:test",
@ -1184,11 +1183,10 @@ tf_cc_test(
)
xla_proto_library(
name = "gpu_autotuning_proto",
srcs = ["gpu_autotuning.proto"],
name = "autotuning_proto",
srcs = ["autotuning.proto"],
deps = [
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/core:protos_all_cc",
],
)

View File

@ -1,14 +1,15 @@
// This file defines protos that store the results of autotuning various
// This file defines protos that store the results of autotuning XLA:GPU
// operations.
//
// They are in proto format because we want to log them structured. They offer
// tremendous statistical, testing, and debugging value.
syntax = "proto3";
package tensorflow;
package xla.gpu;
import "google/protobuf/any.proto";
import "google/protobuf/duration.proto";
import "tensorflow/compiler/xla/xla_data.proto";
import "tensorflow/compiler/xla/service/hlo.proto";
message CudnnVersion {
int32 major = 1;
@ -62,12 +63,19 @@ message AutotuneResult {
}
}
message AutotuningLog {
google.protobuf.Any instr = 1;
message AutotuneLog {
message Instruction {
xla.HloInstructionProto instruction = 1;
repeated xla.ShapeProto operand_shapes = 2;
}
oneof instr_oneof {
Instruction instr = 1;
}
// Records all auto-tuning results per algorithm.
repeated AutotuneResult results = 2;
repeated AutotuneResult results = 3;
CudnnVersion cudnn_version = 3;
ComputeCapability compute_capability = 4;
CudnnVersion cudnn_version = 4;
ComputeCapability compute_capability = 5;
}

View File

@ -14,23 +14,21 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h"
#include "google/protobuf/any.pb.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/time/time.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_autotuning.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/scratch_allocator.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logger.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/proto/proto_utils.h"
namespace xla {
namespace gpu {
@ -39,7 +37,6 @@ namespace {
using absl::optional;
using se::DeviceMemoryBase;
using se::dnn::AlgorithmDesc;
using tensorflow::AutotuneResult;
std::vector<AlgorithmDesc> GetAlgorithms(CudnnConvKind kind,
se::StreamExecutor* stream_exec) {
@ -97,8 +94,8 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) {
return tensorflow::mutex_lock{it->second};
}
tensorflow::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) {
tensorflow::CudnnVersion cudnn_version;
xla::gpu::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) {
xla::gpu::CudnnVersion cudnn_version;
if (auto* dnn = stream_executor->AsDnn()) {
StatusOr<se::dnn::VersionInfo> version_or = dnn->GetVersion();
if (version_or.ok()) {
@ -111,9 +108,9 @@ tensorflow::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) {
return cudnn_version;
}
tensorflow::ComputeCapability GetComputeCapability(
xla::gpu::ComputeCapability GetComputeCapability(
se::StreamExecutor* stream_executor) {
tensorflow::ComputeCapability cc;
xla::gpu::ComputeCapability cc;
int cc_major, cc_minor;
stream_executor->GetDeviceDescription().cuda_compute_capability(&cc_major,
&cc_minor);
@ -246,23 +243,25 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithm(
RunCudnnConv(instr, absl::MakeSpan(operand_buffers), result_buffer,
&scratch_allocator, &stream, options);
if (!launch_status.ok()) {
continue;
}
if (!profile_result.is_valid()) {
continue;
}
profile_results.emplace_back();
AutotuneResult& result = profile_results.back();
result.mutable_conv()->set_algorithm(alg.algo_id());
result.mutable_conv()->set_tensor_ops_enabled(alg.tensor_ops_enabled());
if (!launch_status.ok()) {
result.set_error_string(launch_status.error_message());
continue;
}
if (!profile_result.is_valid()) {
result.set_error_string("Invalid profile result");
continue;
}
int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes();
result.mutable_success()->set_scratch_bytes(scratch_bytes_used);
*result.mutable_success()->mutable_run_time() =
tensorflow::proto_utils::ToDurationProto(
protobuf_util::ToDurationProto(
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
const bool crash_on_checking_failure =
@ -309,14 +308,10 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithm(
// Log the autotuning result.
{
tensorflow::AutotuningLog log;
{
ConvInstructionLog instr_log;
*instr_log.mutable_instruction() = instr->ToProto();
for (const auto* op : instr->operands()) {
*instr_log.add_operand_shapes() = op->shape().ToProto();
}
log.mutable_instr()->PackFrom(instr_log);
AutotuneLog log;
*log.mutable_instr()->mutable_instruction() = instr->ToProto();
for (const auto* op : instr->operands()) {
*log.mutable_instr()->add_operand_shapes() = op->shape().ToProto();
}
for (const auto& profile : profile_results) {
*log.add_results() = profile;
@ -335,14 +330,13 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithm(
// The successful one should have a smaller key, since we are doing
// min_element. If they are both unsuccessful, keep the earlier one in
// the vector by comparing pointers.
return std::make_tuple(!lhs.has_success(),
tensorflow::proto_utils::FromDurationProto(
lhs.success().run_time()),
&lhs) <
std::make_tuple(!rhs.has_success(),
tensorflow::proto_utils::FromDurationProto(
rhs.success().run_time()),
&rhs);
return std::make_tuple(
!lhs.has_success(),
protobuf_util::FromDurationProto(lhs.success().run_time()),
&lhs) < std::make_tuple(!rhs.has_success(),
protobuf_util::FromDurationProto(
rhs.success().run_time()),
&rhs);
});
if (best_result != profile_results_end && best_result->has_success()) {

View File

@ -20,12 +20,12 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/gpu/autotuning.pb.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
namespace xla {
namespace gpu {
@ -50,7 +50,7 @@ class CudnnConvAlgorithmPicker : public HloModulePass {
private:
StatusOr<bool> RunOnComputation(HloComputation* computation);
StatusOr<bool> RunOnInstruction(HloInstruction* instr);
StatusOr<tensorflow::AutotuneResult> PickBestAlgorithm(
StatusOr<AutotuneResult> PickBestAlgorithm(
const HloCustomCallInstruction* instr);
se::StreamExecutor* stream_exec_; // never null

View File

@ -1,13 +0,0 @@
// This is used for convolution logging. Also see
// tensorflow/core/protobuf/autotuing.h
syntax = "proto3";
package xla.gpu;
import "tensorflow/compiler/xla/service/hlo.proto";
import "tensorflow/compiler/xla/xla_data.proto";
message ConvInstructionLog {
xla.HloInstructionProto instruction = 1;
repeated xla.ShapeProto operand_shapes = 2;
}

View File

@ -1742,7 +1742,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
dnums.DebugString());
}
if (kernel_output_features % feature_group_count > 0) {
// A depthwise/grouped filter has the shape
// [space0, .. spaceN, GROUP_SIZE, NUM_OUTPUT_FEATURES]. When
// [space0, .. spaceN, GROUP_SIZE] is convolved with the input, a shape
// [space0, .. spaceN, feature_group_count] is formed. Therefore, the output
// feature count (which is equal to kernel output features) has to be a
// multiple of feature_group_count.
return InvalidArgument(
"Expected output feature dimension (value %d) to be divisible by "
"feature_group_count (value %d); "

View File

@ -629,6 +629,9 @@ BENCHMARK_NAME := $(BINDIR)benchmark
CORE_CC_ALL_SRCS := \
$(ABSL_CC_SRCS) \
tensorflow/c/c_api.cc \
tensorflow/c/kernels.cc \
tensorflow/c/tf_status_helper.cc \
$(wildcard tensorflow/core/*.cc) \
$(wildcard tensorflow/core/common_runtime/*.cc) \
$(wildcard tensorflow/core/framework/*.cc) \

View File

@ -13,11 +13,9 @@
# limitations under the License.
# ==============================================================================
"""Tests for contrib.seq2seq.python.ops.attention_wrapper."""
# pylint: disable=unused-import,g-bad-import-order
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: enable=unused-import
import collections
import functools
@ -30,6 +28,7 @@ from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.layers import core as layers_core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
@ -66,6 +65,7 @@ def get_result_summary(x):
return x
@test_util.run_v1_only
class AttentionWrapperTest(test.TestCase):
def assertAllCloseOrEqual(self, x, y, **kwargs):

View File

@ -30,7 +30,6 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.keras import initializers
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.util import nest
@ -305,7 +304,10 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
attention_layer_size = attention_layer_size[0]
if attention_layer is not None:
attention_layer = attention_layer[0]
cell = rnn_cell.LSTMCell(cell_depth, initializer="ones")
cell = keras.layers.LSTMCell(cell_depth,
recurrent_activation="sigmoid",
kernel_initializer="ones",
recurrent_initializer="ones")
cell = wrapper.AttentionWrapper(
cell,
attention_mechanisms if is_multi else attention_mechanisms[0],
@ -321,7 +323,7 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
sampler = sampler_py.TrainingSampler()
my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler)
initial_state = cell.zero_state(
initial_state = cell.get_initial_state(
dtype=dtypes.float32, batch_size=batch_size)
final_outputs, final_state, _ = my_decoder(
decoder_inputs,
@ -330,7 +332,6 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple)
expected_time = (
expected_final_state.time if context.executing_eagerly() else None)
@ -342,9 +343,9 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
self.assertEqual((batch_size, attention_depth),
tuple(final_state.attention.get_shape().as_list()))
self.assertEqual((batch_size, cell_depth),
tuple(final_state.cell_state.c.get_shape().as_list()))
tuple(final_state.cell_state[0].get_shape().as_list()))
self.assertEqual((batch_size, cell_depth),
tuple(final_state.cell_state.h.get_shape().as_list()))
tuple(final_state.cell_state[1].get_shape().as_list()))
if alignment_history:
if is_multi:
@ -395,8 +396,9 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
expected_final_alignment_history,
final_alignment_history_info)
@parameterized.parameters([np.float16, np.float32, np.float64])
def _testBahdanauNormalizedDType(self, dtype):
# TODO(b/126893309): reenable np.float16 once the bug is fixed.
@parameterized.parameters([np.float32, np.float64])
def testBahdanauNormalizedDType(self, dtype):
encoder_outputs = self.encoder_outputs.astype(dtype)
decoder_inputs = self.decoder_inputs.astype(dtype)
attention_mechanism = wrapper.BahdanauAttentionV2(
@ -405,7 +407,7 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
memory_sequence_length=self.encoder_sequence_length,
normalize=True,
dtype=dtype)
cell = rnn_cell.LSTMCell(self.units)
cell = keras.layers.LSTMCell(self.units, recurrent_activation="sigmoid")
cell = wrapper.AttentionWrapper(cell, attention_mechanism)
sampler = sampler_py.TrainingSampler()
@ -418,9 +420,9 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
self.assertEqual(final_outputs.rnn_output.dtype, dtype)
self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple)
@parameterized.parameters([np.float16, np.float32, np.float64])
# TODO(b/126893309): reenable np.float16 once the bug is fixed.
@parameterized.parameters([np.float32, np.float64])
def testLuongScaledDType(self, dtype):
# Test case for GitHub issue 18099
encoder_outputs = self.encoder_outputs.astype(dtype)
@ -432,7 +434,7 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
scale=True,
dtype=dtype,
)
cell = rnn_cell.LSTMCell(self.units)
cell = keras.layers.LSTMCell(self.units, recurrent_activation="sigmoid")
cell = wrapper.AttentionWrapper(cell, attention_mechanism)
sampler = sampler_py.TrainingSampler()
@ -445,7 +447,6 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput)
self.assertEqual(final_outputs.rnn_output.dtype, dtype)
self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple)
def testBahdanauNotNormalized(self):
create_attention_mechanism = wrapper.BahdanauAttentionV2
@ -455,11 +456,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
shape=(5, 3, 6), dtype=np.dtype(np.float32), mean=4.8290324),
sample_id=ResultSummary(shape=(5, 3), dtype=np.dtype(np.int32), mean=0))
expected_final_state = wrapper.AttentionWrapperState(
cell_state=rnn_cell.LSTMStateTuple(
c=ResultSummary(
shape=(5, 9), dtype=np.dtype(np.float32), mean=1.6432636),
h=ResultSummary(
shape=(5, 9), dtype=np.dtype(np.float32), mean=0.75866824)),
cell_state=[
ResultSummary(
shape=(5, 9), dtype=np.dtype(np.float32), mean=0.75866824),
ResultSummary(
shape=(5, 9), dtype=np.dtype(np.float32), mean=1.6432636)],
attention=ResultSummary(
shape=(5, 6), dtype=np.dtype(np.float32), mean=6.7445569),
time=3,
@ -490,11 +491,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
sample_id=ResultSummary(
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
expected_final_state = wrapper.AttentionWrapperState(
cell_state=rnn_cell.LSTMStateTuple(
c=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=1.4652209),
h=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.70997983)),
cell_state=[
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.70997983),
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=1.4652209)],
attention=ResultSummary(
shape=(5, 6), dtype=np.dtype("float32"), mean=6.3075728),
time=3,
@ -520,11 +521,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
sample_id=ResultSummary(
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
expected_final_state = wrapper.AttentionWrapperState(
cell_state=rnn_cell.LSTMStateTuple(
c=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.88403547),
h=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.37819088)),
cell_state=[
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.37819088),
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.88403547)],
attention=ResultSummary(
shape=(5, 6), dtype=np.dtype("float32"), mean=4.084631),
time=3,
@ -550,11 +551,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
sample_id=ResultSummary(
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
expected_final_state = wrapper.AttentionWrapperState(
cell_state=rnn_cell.LSTMStateTuple(
c=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.88403547),
h=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.37819088)),
cell_state=[
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.37819088),
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.88403547)],
attention=ResultSummary(
shape=(5, 6), dtype=np.dtype("float32"), mean=4.0846314),
time=3,
@ -581,11 +582,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
sample_id=ResultSummary(
shape=(5, 3), dtype=np.dtype("int32"), mean=3.86666666))
expected_final_state = wrapper.AttentionWrapperState(
cell_state=rnn_cell.LSTMStateTuple(
c=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=1.032002),
h=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.61177742)),
cell_state=[
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.61177742),
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=1.032002)],
attention=ResultSummary(
shape=(5, 10), dtype=np.dtype("float32"), mean=0.011346335),
time=3,
@ -613,11 +614,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
sample_id=ResultSummary(
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
expected_final_state = wrapper.AttentionWrapperState(
cell_state=rnn_cell.LSTMStateTuple(
c=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=1.6752492),
h=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.76052248)),
cell_state=[
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.76052248),
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=1.6752492)],
attention=ResultSummary(
shape=(5, 6), dtype=np.dtype("float32"), mean=8.361186),
time=3,
@ -648,11 +649,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
sample_id=ResultSummary(
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
expected_final_state = wrapper.AttentionWrapperState(
cell_state=rnn_cell.LSTMStateTuple(
c=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=1.6005473),
h=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.77863038)),
cell_state=[
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.77863038),
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=1.6005473)],
attention=ResultSummary(
shape=(5, 6), dtype=np.dtype("float32"), mean=7.3326721),
time=3,
@ -682,11 +683,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
sample_id=ResultSummary(
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
expected_final_state = wrapper.AttentionWrapperState(
cell_state=rnn_cell.LSTMStateTuple(
c=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=1.072384),
h=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.50331038)),
cell_state=[
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.50331038),
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=1.072384)],
attention=ResultSummary(
shape=(5, 6), dtype=np.dtype("float32"), mean=5.3079605),
time=3,
@ -716,11 +717,11 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase):
sample_id=ResultSummary(
shape=(5, 3), dtype=np.dtype("int32"), mean=0.0))
expected_final_state = wrapper.AttentionWrapperState(
cell_state=rnn_cell.LSTMStateTuple(
c=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=1.072384),
h=ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.50331038)),
cell_state=[
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=0.50331038),
ResultSummary(
shape=(5, 9), dtype=np.dtype("float32"), mean=1.072384)],
attention=ResultSummary(
shape=(5, 6), dtype=np.dtype("float32"), mean=5.3079605),
time=3,

View File

@ -13,31 +13,30 @@
# limitations under the License.
# ==============================================================================
"""Tests for contrib.seq2seq.python.seq2seq.basic_decoder."""
# pylint: disable=unused-import,g-bad-import-order
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: enable=unused-import
import numpy as np
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.layers import core as layers_core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
# pylint: enable=g-import-not-at-top
@test_util.run_v1_only
class BasicDecoderTest(test.TestCase):
def _testStepWithTrainingHelper(self, use_output_layer):

View File

@ -187,14 +187,23 @@ class TestArrayShapeChecks(test.TestCase):
shape=dynamic_shape)
batch_size = array_ops.constant(batch_size)
check_op = beam_search_decoder._check_batch_beam(t, batch_size, beam_width) # pylint: disable=protected-access
with self.cached_session() as sess:
if is_valid:
sess.run(check_op)
def _test_body():
# pylint: disable=protected-access
if context.executing_eagerly():
beam_search_decoder._check_batch_beam(t, batch_size, beam_width)
else:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(check_op)
with self.cached_session():
check_op = beam_search_decoder._check_batch_beam(
t, batch_size, beam_width)
self.evaluate(check_op)
# pylint: enable=protected-access
if is_valid:
_test_body()
else:
with self.assertRaises(errors.InvalidArgumentError):
_test_body()
def test_array_shape_dynamic_checks(self):
self._test_array_shape_dynamic_checks(
@ -463,6 +472,7 @@ class TestLargeBeamStep(test.TestCase):
self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0], [0, 0, 0]])
@test_util.run_v1_only
class BeamSearchDecoderTest(test.TestCase):
def _testDynamicDecodeRNN(self, time_major, has_attention,

View File

@ -49,8 +49,8 @@ class GatherTreeTest(test.TestCase):
parent_ids=parent_ids,
max_sequence_lengths=max_sequence_lengths,
end_token=end_token)
with self.session(use_gpu=True):
self.assertAllEqual(expected_result, beams.eval())
with self.cached_session(use_gpu=True):
self.assertAllEqual(expected_result, self.evaluate(beams))
def testBadParentValuesOnCPU(self):
# (batch_size = 1, max_time = 4, beams = 3)
@ -62,15 +62,14 @@ class GatherTreeTest(test.TestCase):
[[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
max_sequence_lengths = [3]
with ops.device("/cpu:0"):
beams = beam_search_ops.gather_tree(
step_ids=step_ids,
parent_ids=parent_ids,
max_sequence_lengths=max_sequence_lengths,
end_token=end_token)
with self.cached_session():
with self.assertRaisesOpError(
r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"):
_ = beams.eval()
beams = beam_search_ops.gather_tree(
step_ids=step_ids,
parent_ids=parent_ids,
max_sequence_lengths=max_sequence_lengths,
end_token=end_token)
self.evaluate(beams)
def testBadParentValuesOnGPU(self):
# Only want to run this test on CUDA devices, as gather_tree is not
@ -93,8 +92,7 @@ class GatherTreeTest(test.TestCase):
parent_ids=parent_ids,
max_sequence_lengths=max_sequence_lengths,
end_token=end_token)
with self.session(use_gpu=True):
self.assertAllEqual(expected_result, beams.eval())
self.assertAllEqual(expected_result, self.evaluate(beams))
def testGatherTreeBatch(self):
batch_size = 10
@ -103,7 +101,7 @@ class GatherTreeTest(test.TestCase):
max_sequence_lengths = [0, 1, 2, 4, 7, 8, 9, 10, 11, 0]
end_token = 5
with self.session(use_gpu=True):
with self.cached_session(use_gpu=True):
step_ids = np.random.randint(
0, high=end_token + 1, size=(max_time, batch_size, beam_width))
parent_ids = np.random.randint(
@ -116,7 +114,7 @@ class GatherTreeTest(test.TestCase):
end_token=end_token)
self.assertEqual((max_time, batch_size, beam_width), beams.shape)
beams_value = beams.eval()
beams_value = self.evaluate(beams)
for b in range(batch_size):
# Past max_sequence_lengths[b], we emit all end tokens.
b_value = beams_value[max_sequence_lengths[b]:, b, :]

View File

@ -13,26 +13,25 @@
# limitations under the License.
# ==============================================================================
"""Tests for contrib.seq2seq.python.seq2seq.decoder."""
# pylint: disable=unused-import,g-bad-import-order
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: enable=unused-import
import numpy as np
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variables
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import test
# pylint: enable=g-import-not-at-top
@test_util.run_v1_only
class DynamicDecodeRNNTest(test.TestCase):
def _testDynamicDecodeRNN(self, time_major, maximum_iterations=None):

View File

@ -31,7 +31,7 @@ from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class LossTest(test.TestCase):
def setUp(self):
def config_default_values(self):
self.batch_size = 2
self.sequence_length = 3
self.number_of_classes = 5
@ -56,7 +56,8 @@ class LossTest(test.TestCase):
self.expected_loss = 1.60944
def testSequenceLoss(self):
with self.test_session(use_gpu=True):
self.config_default_values()
with self.cached_session(use_gpu=True):
average_loss_per_example = loss.sequence_loss(
self.logits, self.targets, self.weights,
average_across_timesteps=True,
@ -90,7 +91,8 @@ class LossTest(test.TestCase):
self.assertAllClose(compare_total, res)
def testSequenceLossClass(self):
with self.test_session(use_gpu=True):
self.config_default_values()
with self.cached_session(use_gpu=True):
seq_loss = loss.SequenceLoss(average_across_timesteps=True,
average_across_batch=True,
sum_over_timesteps=False,
@ -132,7 +134,8 @@ class LossTest(test.TestCase):
self.assertAllClose(compare_total, res)
def testSumReduction(self):
with self.test_session(use_gpu=True):
self.config_default_values()
with self.cached_session(use_gpu=True):
seq_loss = loss.SequenceLoss(average_across_timesteps=False,
average_across_batch=False,
sum_over_timesteps=True,
@ -174,6 +177,7 @@ class LossTest(test.TestCase):
self.assertAllClose(compare_total, res)
def testWeightedSumReduction(self):
self.config_default_values()
weights = [
constant_op.constant(1.0, shape=[self.batch_size])
for _ in range(self.sequence_length)
@ -181,7 +185,7 @@ class LossTest(test.TestCase):
# Make the last element in the sequence to have zero weights.
weights[-1] = constant_op.constant(0.0, shape=[self.batch_size])
self.weights = array_ops.stack(weights, axis=1)
with self.test_session(use_gpu=True):
with self.cached_session(use_gpu=True):
seq_loss = loss.SequenceLoss(average_across_timesteps=False,
average_across_batch=False,
sum_over_timesteps=True,
@ -225,12 +229,13 @@ class LossTest(test.TestCase):
self.assertAllClose(compare_total, res)
def testZeroWeights(self):
self.config_default_values()
weights = [
constant_op.constant(0.0, shape=[self.batch_size])
for _ in range(self.sequence_length)
]
weights = array_ops.stack(weights, axis=1)
with self.test_session(use_gpu=True):
with self.cached_session(use_gpu=True):
average_loss_per_example = loss.sequence_loss(
self.logits, self.targets, weights,
average_across_timesteps=True,

View File

@ -2347,7 +2347,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell):
if self._initial_cell_state is not None:
cell_state = self._initial_cell_state
else:
cell_state = self._cell.zero_state(batch_size, dtype)
cell_state = self._cell.get_initial_state(batch_size=batch_size,
dtype=dtype)
error_message = (
"When calling zero_state of AttentionWrapper %s: " % self._base_name +
"Non-matching batch sizes between the memory "

View File

@ -218,7 +218,7 @@ def _check_batch_beam(t, batch_size, beam_width):
"incompatible with the dynamic shape of %s elements. "
"Consider setting reorder_tensor_arrays to False to disable "
"TensorArray reordering during the beam search."
% (t.name))
% (t if context.executing_eagerly() else t.name))
rank = t.shape.ndims
shape = array_ops.shape(t)
if rank == 2:

View File

@ -233,7 +233,6 @@ CORE_PROTO_SRCS = COMMON_PROTO_SRCS + ERROR_CODES_PROTO_SRCS
ADDITIONAL_CORE_PROTO_SRCS = [
"example/example_parser_configuration.proto",
"protobuf/trackable_object_graph.proto",
"protobuf/autotuning.proto",
"protobuf/control_flow.proto",
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
# "protobuf/critical_section.proto",
@ -926,6 +925,7 @@ tf_cuda_library(
"framework/tensor_slice.h",
"framework/tensor_types.h",
"framework/tensor_util.h",
"framework/thread_factory.h",
"framework/tracking_allocator.h",
"framework/type_index.h",
"framework/type_traits.h",
@ -1671,6 +1671,7 @@ filegroup(
":protos_all_proto_text_srcs",
":error_codes_proto_text_srcs",
"//tensorflow/core/platform/default/build_config:android_srcs",
"//tensorflow/c:srcs",
] + glob(
[
"client/**/*.cc",

View File

@ -0,0 +1,32 @@
op {
graph_op_name: "ExperimentalAutoShardDataset"
visibility: HIDDEN
in_arg {
name: "input_dataset"
description: <<END
A variant tensor representing the input dataset.
END
}
in_arg {
name: "num_workers"
description: <<END
A scalar representing the number of workers to distribute this dataset across.
END
}
in_arg {
name: "index"
description: <<END
A scalar representing the index of the current worker out of num_workers.
END
}
summary: "Creates a dataset that shards the input dataset."
description: <<END
Creates a dataset that shards the input dataset by num_workers, returning a
sharded dataset for the index-th worker. This attempts to automatically shard
a dataset by examining the Dataset graph and inserting a shard op before the
inputs to a reader Dataset (e.g. CSVDataset, TFRecordDataset).
This dataset will throw a NotFound error if we cannot shard the dataset
automatically.
END
}
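For illustration only (not part of this change): a minimal sketch of how a graph builder might wire up this op with NodeDefBuilder. The input names follow the api_def above, the `output_types`/`output_shapes` attrs mirror what the AutoShardDatasetOp kernel reads later in this change, and the concrete node names and element types are assumptions.
#include <vector>
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
namespace {
// Sketch: shard the variant output of a hypothetical "reader_dataset" node
// across the workers described by two scalar int64 const nodes.
tensorflow::NodeDef MakeAutoShardNode() {
  tensorflow::NodeDef auto_shard;
  TF_CHECK_OK(
      tensorflow::NodeDefBuilder("auto_shard", "ExperimentalAutoShardDataset")
          .Input("reader_dataset", 0, tensorflow::DT_VARIANT)
          .Input("num_workers", 0, tensorflow::DT_INT64)
          .Input("index", 0, tensorflow::DT_INT64)
          .Attr("output_types",
                tensorflow::DataTypeVector{tensorflow::DT_STRING})
          .Attr("output_shapes",
                std::vector<tensorflow::TensorShape>{tensorflow::TensorShape({})})
          .Finalize(&auto_shard));
  return auto_shard;
}
}  // namespace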

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include <unordered_map>
#include "absl/memory/memory.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/dataset_stateful_op_whitelist.h"
@ -28,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/thread_factory.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
@ -287,7 +289,8 @@ class IteratorContext {
model(ctx->model()),
runner(*(ctx->runner())),
runner_threadpool_size(ctx->runner_threadpool_size()),
stats_aggregator(ctx->stats_aggregator()) {}
stats_aggregator(ctx->stats_aggregator()),
thread_factory(ctx->thread_factory()) {}
explicit Params(OpKernelContext* ctx)
: env(ctx->env()),
@ -338,6 +341,10 @@ class IteratorContext {
// The `StatsAggregator` object to record statistics about the iterator.
std::shared_ptr<StatsAggregator> stats_aggregator = nullptr;
// A `ThreadFactory` for creating threads used by iterators to perform
// blocking work.
std::shared_ptr<ThreadFactory> thread_factory = nullptr;
};
explicit IteratorContext(IteratorContext* ctx) : params_(Params{ctx}) {}
@ -374,6 +381,20 @@ class IteratorContext {
return &params_.runner;
}
const std::shared_ptr<ThreadFactory>& thread_factory() {
return params_.thread_factory;
}
std::unique_ptr<Thread> StartThread(const string& name,
std::function<void()> fn) {
if (params_.thread_factory) {
return params_.thread_factory->StartThread(name, std::move(fn));
} else {
return absl::WrapUnique(
Env::Default()->StartThread({}, name, std::move(fn)));
}
}
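// Illustrative note (not part of this change; names below are invented): an
// iterator that previously wrote
//   runner_thread_.reset(ctx->env()->StartThread(
//       {}, "tf_data_example", [this]() { RunnerLoop(); }));
// can now write
//   runner_thread_ = ctx->StartThread(
//       "tf_data_example", [this]() { RunnerLoop(); });
// so that the thread comes from the configured `ThreadFactory` when one is
// set. The kernel changes below apply exactly this pattern.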
int32 runner_threadpool_size() { return params_.runner_threadpool_size; }
std::shared_ptr<StatsAggregator> stats_aggregator() {

View File

@ -0,0 +1,42 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_THREAD_FACTORY_H_
#define TENSORFLOW_CORE_FRAMEWORK_THREAD_FACTORY_H_
#include <functional>
#include <memory>
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
class Thread;
// Virtual interface for an object that creates threads.
class ThreadFactory {
public:
virtual ~ThreadFactory() {}
// Runs `fn` asynchronously in a different thread. `fn` may block.
//
// NOTE: The caller is responsible for ensuring that this `ThreadFactory`
// outlives the returned `Thread`.
virtual std::unique_ptr<Thread> StartThread(const string& name,
std::function<void()> fn) = 0;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_THREAD_FACTORY_H_
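As a concrete illustration of this contract (not part of the change; the class name is invented), a factory that always creates a fresh physical thread by delegating to the process-wide Env could look roughly like this, mirroring the fallback path in IteratorContext::StartThread():
#include <functional>
#include <memory>
#include <utility>
#include "absl/memory/memory.h"
#include "tensorflow/core/framework/thread_factory.h"
#include "tensorflow/core/platform/env.h"
namespace tensorflow {
// Hypothetical example: create one new physical thread per StartThread() call.
class DefaultEnvThreadFactory : public ThreadFactory {
 public:
  std::unique_ptr<Thread> StartThread(const string& name,
                                      std::function<void()> fn) override {
    // Delegate to the process-wide Env with default thread options.
    return absl::WrapUnique(
        Env::Default()->StartThread({}, name, std::move(fn)));
  }
};
}  // namespace tensorflow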

View File

@ -979,6 +979,41 @@ class SymbolicShapeRefiner {
return true;
}
// Returns true if the annotated shape is compatible with the shape
// inference result. Examples:
// Inferred shape: ?, annotated shape: [10, 10] -> true;
// Inferred shape: [-1, 10], annotated shape: [10, 10] -> true;
// Inferred shape: [-1, 100], annotated shape: [10, 10] -> false;
// Inferred shape: [-1, 10, 10], annotated shape: [10, 10] -> false.
bool CompatibleShapes(ShapeHandle inferred_shape,
ShapeHandle annotated_shape) const {
if (inferred_shape.SameHandle(annotated_shape)) {
return true;
}
if (!InferenceContext::RankKnown(inferred_shape)) {
return true;
}
if (InferenceContext::Rank(inferred_shape) !=
InferenceContext::Rank(annotated_shape)) {
return false;
}
const int rank = InferenceContext::Rank(inferred_shape);
for (int i = 0; i < rank; ++i) {
if (!InferenceContext::DimKnownRank(inferred_shape, i)
.SameHandle(
InferenceContext::DimKnownRank(annotated_shape, i))) {
int64 val1 = InferenceContext::Value(
InferenceContext::DimKnownRank(inferred_shape, i));
int64 val2 = InferenceContext::Value(
InferenceContext::DimKnownRank(annotated_shape, i));
if (val1 >= 0 && val1 != val2) {
return false;
}
}
}
return true;
}
bool EquivalentShapesAndTypes(const std::vector<ShapeAndType>& st1,
const std::vector<ShapeAndType>& st2) const {
if (st1.size() != st2.size()) {
@ -1139,9 +1174,9 @@ class SymbolicShapeRefiner {
return true;
}
// Returns true if we want to update output values with running EvaluateNode()
// for this op, based on op type, data type, and size.
bool ShouldUpdateOutputValues(NodeContext* c, int64 max_size) {
// Returns true if we want to update output shapes and values by running
// EvaluateNode() for this op, based on op type, data type, and size.
bool ShouldUpdateOutputShapesAndValues(NodeContext* c, int64 max_size) {
InferenceContext* ic = c->inference_context.get();
// Due to the cost of running EvaluateNode(), we limit only to white listed
@ -1232,8 +1267,9 @@ class SymbolicShapeRefiner {
}
}
// Run a node to infer output values, and add it to the NodeContext.
Status UpdateOutputValues(const NodeDef& node, NodeContext* c) {
// Run a node to infer output shapes and values, and add it to the
// NodeContext.
Status UpdateOutputShapesAndValues(const NodeDef& node, NodeContext* c) {
InferenceContext* ic = c->inference_context.get();
// Input to EvaluateNode()
@ -1264,7 +1300,7 @@ class SymbolicShapeRefiner {
ic->MakeShapeFromTensorShape(t->shape(), &output_shape));
if (ic->FullyDefined(ic->output(k)) &&
!EquivalentShapes(ic->output(k), output_shape)) {
LOG(WARNING) << "UpdateOutputValues() -- node: " << node.name()
LOG(WARNING) << "UpdateOutputShapesAndValues() -- node: " << node.name()
<< ", inferred output shape "
<< "doesn't match for k=" << k << ": "
<< "ic->output(k): " << ic->DebugString(ic->output(k))
@ -1284,6 +1320,54 @@ class SymbolicShapeRefiner {
return Status::OK();
}
// Update output shapes with annotated information.
// Currently this only handles nodes with static shapes, i.e. shapes that do
// not change during execution.
// TODO(andiryxu): Use annotated shapes in Enter/Merge etc as well.
Status UpdateOutputShapesUsingAnnotatedInformation(const NodeDef& node,
NodeContext* c) const {
const auto& attr = node.attr();
if (attr.count(kOutputSame) == 0 || !attr.at(kOutputSame).b() ||
attr.count(kOutputShapes) == 0)
return Status::OK();
InferenceContext* ic = c->inference_context.get();
int output_size = attr.at(kOutputShapes).list().shape_size();
for (int i = 0; i < ic->num_outputs(); i++) {
// An annotated Switch node has only one annotated output shape. Propagate
// that shape to all the outputs.
int shape_index = IsSwitch(node) ? 0 : i;
if (shape_index >= output_size) {
LOG(WARNING)
<< "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
<< node.name() << ", inferred output shape size "
<< ic->num_outputs() << ", annotated output shape size "
<< output_size;
break;
}
const TensorShapeProto& shape =
attr.at(kOutputShapes).list().shape(shape_index);
ShapeHandle output_shape;
TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &output_shape));
// Only use the annotated shape if the inferred shape is not fully defined
// and is compatible with the annotated shape.
if (!ic->FullyDefined(ic->output(i)) &&
CompatibleShapes(ic->output(i), output_shape)) {
VLOG(3) << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
<< node.name() << ", inferred output shape " << i << ": "
<< "ic->output(i): " << ic->DebugString(ic->output(i))
<< ", annotated output shape: " << ic->DebugString(output_shape)
<< " -- " << node.ShortDebugString();
ic->set_output(i, output_shape);
}
}
return Status::OK();
}
Status MaybeUpdateNodeContextOutput(const NodeDef& node, const bool is_fed,
NodeContext* c) {
// Propagate tensors and shape tensors unless the node is fed.
@ -1476,16 +1560,19 @@ class SymbolicShapeRefiner {
}
if (aggressive_shape_inference_) {
// Update output shapes with annotated information. This is optional.
UpdateOutputShapesUsingAnnotatedInformation(node, c).IgnoreError();
// Update output tensor values using EvaluateNode() if we can.
// Due to the cost of EvaluateNode(), we run it only for certain op types
// (white listed) and small integer tensors.
const int max_element_size = 17; // Max up to 4x4 matrix or similar.
if (AllOutputValuesKnown(c) || !AllInputValuesKnown(c) ||
!ShouldUpdateOutputValues(c, max_element_size)) {
!ShouldUpdateOutputShapesAndValues(c, max_element_size)) {
return Status::OK();
}
UpdateOutputValues(node, c).IgnoreError(); // This is optional.
UpdateOutputShapesAndValues(node, c).IgnoreError(); // This is optional.
}
return Status::OK();
}
@ -1797,6 +1884,7 @@ Status GraphProperties::UpdateShapes(
// UpdateNode calls UpdateFunction if a function node is detected.
TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes));
}
return Status::OK();
}

View File

@ -27,6 +27,45 @@ namespace tensorflow {
namespace grappler {
// Optional attributes that carry information about node outputs.
// We use this side information, if provided, for static shape inference
// and VirtualScheduler scheduling.
// Switch op attribute: a vector of int that tells which branch of the
// Switch output is taken on each round of execution.
// Used for scheduling ops after Switch correctly (e.g., in a While loop).
ABSL_CONST_INIT const char kOutputSlots[] = "_output_slot_vector";
// Example:
// Assume a node has two outputs and iterated for three times. Then it has:
// _execution_count = 3
// _output_sizes_vector = [2, 2, 2]
// _output_dtype_vector.size = 6
// _output_shape_vector.size = 6
// If all the iterations have the same output shapes, then
// _execution_count = 3
// _same_output_for_iterations = true
// _output_sizes_vector = [2]
// _output_dtype_vector.size = 2
// _output_shape_vector.size = 2
// How many times this node has been executed.
ABSL_CONST_INIT const char kExecutionCount[] = "_execution_count";
// Records the output sizes for each round of execution.
ABSL_CONST_INIT const char kOutputSizes[] = "_output_sizes_vector";
// The node has been scheduled multiple times with outputs that have the same
// shape.
ABSL_CONST_INIT const char kOutputSame[] = "_same_output_for_iterations";
// Outputs DataType vector.
ABSL_CONST_INIT const char kOutputTypes[] = "_output_dtype_vector";
// Outputs TensorShapeProto vector.
ABSL_CONST_INIT const char kOutputShapes[] = "_output_shape_vector";
class SymbolicShapeRefiner;
class TopoQueue;
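For illustration (not part of this change): a producer of these annotations, e.g. a profiling pass, might attach them to a single-output node roughly as follows. The concrete values (3 iterations, one [5, 7] float output per iteration) are assumptions taken from the worked example above, and the include path of the header declaring these constants is assumed.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"  // assumed path
namespace tensorflow {
namespace grappler {
// Sketch: mark `node` as executed 3 times, producing the same single [5, 7]
// float output on every iteration.
void AnnotateExampleNode(NodeDef* node) {  // Hypothetical helper.
  (*node->mutable_attr())[kExecutionCount].set_i(3);
  (*node->mutable_attr())[kOutputSame].set_b(true);
  (*node->mutable_attr())[kOutputSizes].mutable_list()->add_i(1);
  (*node->mutable_attr())[kOutputTypes].mutable_list()->add_type(DT_FLOAT);
  TensorShape({5, 7}).AsProto(
      (*node->mutable_attr())[kOutputShapes].mutable_list()->add_shape());
}
}  // namespace grappler
}  // namespace tensorflow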

View File

@ -1793,6 +1793,103 @@ TEST_F(GraphPropertiesTest, ValuePropagationThroughArithmeticOps) {
ExpectTensorValues({20, 24}, c_plus_b_plus_2a_prop.value());
}
TEST_F(GraphPropertiesTest, ShapeAnnotation) {
GrapplerItem item;
TF_CHECK_OK(NodeDefBuilder("Input", "Placeholder")
.Attr("dtype", DT_FLOAT)
.Attr("shape", PartialTensorShape({-1, -1}))
.Finalize(item.graph.add_node()));
// Annotate shapes.
TF_CHECK_OK(NodeDefBuilder("Identity", "Identity")
.Attr("dtype", DT_FLOAT)
.Attr("_same_output_for_iterations", true)
.Attr("_output_shape_vector", {TensorShape({5, 7})})
.Input("Input", 0, DT_FLOAT)
.Finalize(item.graph.add_node()));
{
GraphProperties properties(item);
// Without aggressive_shape_inference, ignore annotated information.
TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/false));
const auto props = properties.GetOutputProperties("Identity");
EXPECT_EQ(1, props.size());
const OpInfo::TensorProperties& prop = props[0];
EXPECT_EQ(DT_FLOAT, prop.dtype());
EXPECT_EQ(2, prop.shape().dim_size());
// Get unknown shapes without using annotated information.
EXPECT_EQ("float: [-1,-1]", PropToString(prop));
}
{
GraphProperties properties(item);
// Use annotated information.
TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/true));
const auto props = properties.GetOutputProperties("Identity");
EXPECT_EQ(1, props.size());
const OpInfo::TensorProperties& prop = props[0];
EXPECT_EQ(DT_FLOAT, prop.dtype());
EXPECT_EQ(2, prop.shape().dim_size());
// Update output shape using annotated shapes.
EXPECT_EQ("float: [5,7]", PropToString(prop));
}
}
TEST_F(GraphPropertiesTest, ShapeAnnotationWithCompatibleShapes) {
GrapplerItem item;
TF_CHECK_OK(NodeDefBuilder("Input", "Placeholder")
.Attr("dtype", DT_FLOAT)
.Attr("shape", PartialTensorShape({-1, 100}))
.Finalize(item.graph.add_node()));
// Annotate shapes.
TF_CHECK_OK(NodeDefBuilder("Identity", "Identity")
.Attr("dtype", DT_FLOAT)
.Attr("_same_output_for_iterations", true)
.Attr("_output_shape_vector", {TensorShape({10, 100})})
.Input("Input", 0, DT_FLOAT)
.Finalize(item.graph.add_node()));
GraphProperties properties(item);
// Use annotated information.
TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/true));
const auto props = properties.GetOutputProperties("Identity");
EXPECT_EQ(1, props.size());
const OpInfo::TensorProperties& prop = props[0];
EXPECT_EQ(DT_FLOAT, prop.dtype());
EXPECT_EQ(2, prop.shape().dim_size());
// Compatible shapes. Update output shape using annotated shapes.
EXPECT_EQ("float: [10,100]", PropToString(prop));
}
TEST_F(GraphPropertiesTest, ShapeAnnotationWithIncompatibleShapes) {
GrapplerItem item;
TF_CHECK_OK(NodeDefBuilder("Input", "Placeholder")
.Attr("dtype", DT_FLOAT)
.Attr("shape", PartialTensorShape({-1, 100}))
.Finalize(item.graph.add_node()));
// Annotate shapes.
TF_CHECK_OK(NodeDefBuilder("Identity", "Identity")
.Attr("dtype", DT_FLOAT)
.Attr("_same_output_for_iterations", true)
.Attr("_output_shape_vector", {TensorShape({10, 10})})
.Input("Input", 0, DT_FLOAT)
.Finalize(item.graph.add_node()));
GraphProperties properties(item);
// Use annotated information.
TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/true));
const auto props = properties.GetOutputProperties("Identity");
EXPECT_EQ(1, props.size());
const OpInfo::TensorProperties& prop = props[0];
EXPECT_EQ(DT_FLOAT, prop.dtype());
EXPECT_EQ(2, prop.shape().dim_size());
// Incompatible shapes. Do not use annotated shapes.
EXPECT_EQ("float: [-1,100]", PropToString(prop));
}
} // namespace
} // namespace grappler
} // namespace tensorflow

View File

@ -36,12 +36,6 @@ namespace tensorflow {
namespace grappler {
namespace {
// Optional attribute name for Switch op as a vector of int that tells
// which branch the Switch output is taken on every round of execution.
// We use this side information, if provided, for scheduling ops after Switch
// correctly (e.g., While loop).
constexpr char kOutputSlots[] = "_output_slot_vector";
Costs CombineCosts(const Costs& left, const Costs& right) {
CHECK_NE(left.max_memory, kMemoryUnknown);
CHECK_NE(left.max_per_op_buffers, kMemoryUnknown);

View File

@ -29,6 +29,27 @@ cc_library(
],
)
cc_library(
name = "auto_shard",
srcs = ["auto_shard.cc"],
hdrs = ["auto_shard.h"],
deps = [
":graph_utils",
":optimizer_base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler/clusters:cluster",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/grappler/utils:functions",
] + tf_protos_all(),
alwayslink = 1,
)
cc_library(
name = "filter_fusion",
srcs = ["filter_fusion.cc"],

View File

@ -0,0 +1,300 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/auto_shard.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
namespace grappler {
namespace {
// clang-format off
constexpr char kShardDatasetOpName[] = "ShardDataset";
constexpr char kShuffleDatasetOpName[] = "ShuffleDataset";
constexpr std::array<const char*, 4> kReaderDatasetOps = {
"FixedLengthRecordDataset",
"FixedLengthRecordDatasetV2",
"TextLineDataset",
"TFRecordDataset"
};
constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
"ConcatenateDataset",
"ZipDataset"
};
constexpr std::array<const char*, 22> kPassThroughOps = {
"BatchDataset",
"BatchDatasetV2",
"ExperimentalMapAndBatchDataset",
"PaddedBatchDataset",
"PaddedBatchDatasetV2",
"CacheDataset",
"FilterDataset",
"FilterByLastComponentDataset",
"Identity",
"MapDataset",
"ModelDataset",
"OptimizeDataset",
"ParallelMapDataset",
"PrefetchDataset",
"ReduceDataset",
"RepeatDataset",
"ShardDataset",
"ShuffleAndRepeatDataset",
"ShuffleDataset",
"SkipDataset",
"TakeDataset",
"WindowDataset"
};
// TODO(frankchn): Process functions within kFuncDatasetOps as well.
constexpr std::array<const char*, 4> kFuncDatasetOps = {
"ExperimentalParallelInterleaveDataset",
"FlatMapDataset",
"InterleaveDataset",
"ParallelInterleaveDatasetV2"
};
constexpr std::array<const char*, 5> kUnshardableSourceDatasetOps = {
"GeneratorDataset",
"RangeDataset",
"SparseTensorsSliceDataset",
"TensorDataset",
"TensorSliceDataset",
};
// clang-format on
Status OptimizeGraph(const GrapplerItem& item, int64 num_workers, int64 index,
GraphDef* output);
template <std::size_t SIZE>
bool IsDatasetNodeOfType(const NodeDef& node,
const std::array<const char*, SIZE>& arr) {
for (const auto& dataset_op_name : arr) {
if (node.op() == dataset_op_name) return true;
}
return false;
}
Status AddShardNode(MutableGraphView* graph, const NodeDef& add_before,
int64 num_workers, int64 index) {
NodeDef new_node;
new_node.set_op(kShardDatasetOpName);
graph_utils::SetUniqueGraphNodeName(kShardDatasetOpName, graph->graph(),
&new_node);
// Construct argument nodes
NodeDef* num_shards_node =
graph_utils::AddScalarConstNode<int64>(num_workers, graph);
NodeDef* index_node = graph_utils::AddScalarConstNode<int64>(index, graph);
// Add inputs to new node
new_node.add_input(add_before.input(0));
new_node.add_input(num_shards_node->name());
new_node.add_input(index_node->name());
// Add shapes and other attributes
NodeDef* add_after = graph->GetNode(add_before.input(0));
graph_utils::CopyAttribute("output_shapes", *add_after, &new_node);
if (add_after->attr().find("Toutput_types") != add_after->attr().end()) {
(*(new_node.mutable_attr()))["output_types"] =
add_after->attr().at("Toutput_types");
} else {
graph_utils::CopyAttribute("output_types", *add_after, &new_node);
}
// Add new node into graph and update edges
NodeDef* new_node_graph = graph->AddNode(std::move(new_node));
TF_RETURN_IF_ERROR(
graph->UpdateFanouts(add_after->name(), new_node_graph->name()));
return Status::OK();
}
bool ReaderOpInFunction(const NodeDef& node,
const FunctionLibraryDefinition& flib) {
const FunctionDef* func = flib.Find(node.attr().at("f").func().name());
for (int i = 0; i < func->node_def_size(); i++) {
NodeDef node_in_func = func->node_def(i);
if (IsDatasetNodeOfType(node_in_func, kReaderDatasetOps) &&
node_in_func.input_size() > 0 &&
str_util::StartsWith(node_in_func.input(0), "args_0")) {
return true;
}
if (IsDatasetNodeOfType(func->node_def(i), kFuncDatasetOps) &&
ReaderOpInFunction(func->node_def(i), flib)) {
return true;
}
}
return false;
}
Status RemoveShuffleDataset(MutableGraphView* graph, const NodeDef& node,
absl::flat_hash_set<string>* nodes_to_delete) {
if (node.op() == kShuffleDatasetOpName) {
TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
nodes_to_delete->insert(node.name());
}
for (const auto& fanin : graph->GetFanins(node, true)) {
TF_RETURN_IF_ERROR(
RemoveShuffleDataset(graph, *fanin.node, nodes_to_delete));
}
// TODO(frankchn): Traverse functions too.
return Status::OK();
}
Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, int64 index,
FunctionLibraryDefinition* flib,
MutableGraphView* graph,
absl::flat_hash_set<string>* nodes_to_delete) {
if (IsDatasetNodeOfType(node, kUnshardableSourceDatasetOps)) {
return errors::NotFound("Found an unshardable source dataset: ",
node.DebugString());
}
if (IsDatasetNodeOfType(node, kMultipleInputsDatasetOps)) {
for (int i = 0; i < node.input_size(); ++i) {
const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, i);
TF_RETURN_IF_ERROR(RecursivelyHandleOp(*input_node, num_workers, index,
flib, graph, nodes_to_delete));
}
return Status::OK();
}
// This handles the case where a reader Dataset is contained within a
// FuncDataset (e.g. FlatMap, ParallelInterleave, etc...). For example:
//
// dataset = Dataset.list_files("/path/to/data")
// dataset = dataset.flat_map(core_readers.TFRecordDataset)
//
// where the list of files is passed in one-by-one as an argument to the
// function in flat_map.
if (IsDatasetNodeOfType(node, kFuncDatasetOps) &&
ReaderOpInFunction(node, *flib)) {
TF_RETURN_IF_ERROR(AddShardNode(graph, node, num_workers, index));
TF_RETURN_IF_ERROR(RemoveShuffleDataset(graph, node, nodes_to_delete));
return Status::OK();
}
if (IsDatasetNodeOfType(node, kReaderDatasetOps)) {
// We reached a reader dataset directly and we try to shard input 0.
TF_RETURN_IF_ERROR(AddShardNode(graph, node, num_workers, index));
TF_RETURN_IF_ERROR(RemoveShuffleDataset(graph, node, nodes_to_delete));
return Status::OK();
}
if (!IsDatasetNodeOfType(node, kPassThroughOps)) {
return errors::NotFound(
"Did not find a shardable source, walked to ",
"a node which is not a dataset: ", node.DebugString());
}
const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
return RecursivelyHandleOp(*input_node, num_workers, index, flib, graph,
nodes_to_delete);
}
Status OptimizeGraph(const GrapplerItem& item, int64 num_workers, int64 index,
GraphDef* output) {
*output = item.graph;
MutableGraphView graph(output);
FunctionLibraryDefinition flib(OpRegistry::Global(), item.graph.library());
NodeDef target_node;
absl::flat_hash_set<string> nodes_to_delete;
// The basic approach here is to walk the graph from sink to source, and find
// the latest occurrence of a ReaderDataset (e.g. CSVDataset, TFRecordDataset,
// etc.). We then add a ShardDataset op after that dataset to shard its
// outputs, in effect giving a piece to each worker. Finally, we remove
// occurrences of randomness before that point in the graph (e.g. things like
// ShuffleDataset) to ensure that `shard` returns a sensible result.
NodeDef sink_node;
TF_RETURN_IF_ERROR(graph_utils::FindSinkNode(item.graph, &sink_node));
TF_RETURN_IF_ERROR(RecursivelyHandleOp(sink_node, num_workers, index, &flib,
&graph, &nodes_to_delete));
TF_RETURN_IF_ERROR(graph.DeleteNodes(nodes_to_delete));
return Status::OK();
}
} // anonymous namespace
Status AutoShard::Init(
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) {
if (!config) return errors::InvalidArgument("RewriterConfig not found.");
if ((config->parameter_map().find("num_workers") ==
config->parameter_map().end())) {
return errors::InvalidArgument("num_workers parameter missing.");
}
if ((config->parameter_map().find("index") ==
config->parameter_map().end())) {
return errors::InvalidArgument("index parameter missing.");
}
num_workers_ = config->parameter_map().at("num_workers").i();
index_ = config->parameter_map().at("index").i();
if (num_workers_ < 1) {
return errors::InvalidArgument("num_workers should be >= 1, currently ",
num_workers_);
}
if (index_ < 0 || index_ >= num_workers_) {
return errors::InvalidArgument("index should be >= 0 and < ", num_workers_,
", currently ", index_);
}
return Status::OK();
}
Status AutoShard::OptimizeAndCollectStats(Cluster* /* cluster */,
const GrapplerItem& item,
GraphDef* output,
OptimizationStats* stats) {
*output = item.graph;
MutableGraphView graph(output);
TF_RETURN_IF_ERROR(OptimizeGraph(item, num_workers_, index_, output));
stats->num_changes++;
return Status::OK();
}
REGISTER_GRAPH_OPTIMIZER_AS(AutoShard, "tf_auto_shard");
} // namespace grappler
} // namespace tensorflow

View File

@ -0,0 +1,53 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_AUTO_SHARD_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_AUTO_SHARD_H_
#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h"
namespace tensorflow {
namespace grappler {
// AutoShard takes a Dataset graph and tries to insert a shard node
// automatically before a ReaderDataset (e.g. a CSVDataset or a TFRecordDataset)
// such that the dataset is sharded without any modifications to the original
// dataset-based input pipeline.
class AutoShard : public TFDataOptimizerBase {
public:
AutoShard() = default;
~AutoShard() override = default;
string name() const override { return "tf_auto_shard"; }
Status Init(
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override;
Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item,
GraphDef* output,
OptimizationStats* stats) override;
void Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimize_output, double result) override {}
private:
int64 num_workers_;
int64 index_;
};
} // namespace grappler
} // namespace tensorflow
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_AUTO_SHARD_H_

View File

@ -300,6 +300,40 @@ Status EnsureNodeNamesUnique(Graph* g) {
return Status::OK();
}
// Tries to find a sink node in the graph. A sink node is defined as a node
// that has at least one input and no outputs. If there are multiple such
// nodes, this might return any one of them. This is useful for identifying
// the final Dataset op in the graph; in some cases multiple Identity ops are
// added to the end, and this then returns the last Identity op.
Status FindSinkNode(const GraphDef& graph_def, NodeDef* sink_node) {
absl::flat_hash_map<string, int> all_node_names;
absl::flat_hash_map<string, int> node_input_map;
for (int i = 0; i < graph_def.node_size(); ++i) {
all_node_names.insert_or_assign(graph_def.node(i).name(), i);
node_input_map.insert_or_assign(graph_def.node(i).name(), 0);
}
// Counts how many nodes each node is an input to. Candidate sink nodes are
// those that are inputs to zero other nodes.
for (const NodeDef& node : graph_def.node()) {
for (const string& input_name : node.input()) {
node_input_map[input_name]++;
}
}
for (const auto& it : node_input_map) {
if (it.second == 0) {
const NodeDef& sink_graph_node = graph_def.node(all_node_names[it.first]);
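// Sometimes the search surfaces function Arg nodes, which have no inputs at
// all; such a node cannot be the pipeline's sink, so skip it.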
if (sink_graph_node.input_size() == 0) {
continue;
}
*sink_node = sink_graph_node;
return Status::OK();
}
}
return errors::InvalidArgument("Failed to find a sink node");
}
} // namespace graph_utils
} // namespace grappler
} // namespace tensorflow

View File

@ -144,6 +144,9 @@ void ConcatAttributeList(const string& attribute_name, const NodeDef& first,
// and renaming nodes does not mutate any edges.
Status EnsureNodeNamesUnique(Graph* g);
// Returns the sink node (i.e. last node) in the graph.
Status FindSinkNode(const GraphDef& graph_def, NodeDef* sink_node);
} // namespace graph_utils
} // namespace grappler
} // namespace tensorflow

View File

@ -270,6 +270,40 @@ TEST(GraphUtilsTest, EnsureNodeNamesUnique) {
EXPECT_NE(const_0->name(), const_2->name());
}
TEST(GraphUtilsTest, TestFindSinkNodeStandard) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
AddNode("node1", "Identity", {}, {}, &graph);
AddNode("node2", "Identity", {"node1"}, {}, &graph);
NodeDef* node3 = AddNode("node3", "Identity", {"node2"}, {}, &graph);
NodeDef sink_node;
TF_EXPECT_OK(FindSinkNode(graph_def, &sink_node));
EXPECT_EQ(sink_node.name(), node3->name());
}
TEST(GraphUtilsTest, TestFindSinkNodeNoSingleSink) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
AddNode("node1", "Identity", {}, {}, &graph);
AddNode("node2", "Identity", {}, {}, &graph);
NodeDef sink_node;
Status s = FindSinkNode(graph_def, &sink_node);
EXPECT_FALSE(s.ok());
}
TEST(GraphUtilsTest, TestFindSinkNodeGraphDefEmpty) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
NodeDef sink_node;
Status s = FindSinkNode(graph_def, &sink_node);
EXPECT_FALSE(s.ok());
}
} // namespace
} // namespace graph_utils
} // namespace grappler

View File

@ -172,39 +172,6 @@ Status MutateBatchSize(const NodeDef& node, int64 num_workers,
return Status::OK();
}
// There is one Sink node at least that is added to the end of the graph. We
// find that node and return it. It is possible that there are multiple
// Identity ops from the final Dataset op to that Sink node, but the recursive
// graph traversal handles that.
Status FindSinkNode(const GraphDef& graph_def, NodeDef* sink_node) {
absl::flat_hash_map<string, int> all_node_names;
absl::flat_hash_map<string, int> node_input_map;
for (int i = 0; i < graph_def.node_size(); ++i) {
all_node_names.insert_or_assign(graph_def.node(i).name(), i);
node_input_map.insert_or_assign(graph_def.node(i).name(), 0);
}
// Counts how many graph nodes is this node the input to. Candidate sink
// nodes are ones which are inputs into zero nodes.
for (const NodeDef& node : graph_def.node()) {
for (const string& input_name : node.input()) {
node_input_map[input_name]++;
}
}
for (const auto& it : node_input_map) {
if (it.second == 0) {
const NodeDef& sink_graph_node = graph_def.node(all_node_names[it.first]);
// Sometimes the searching surfaces Arg nodes in function cases that
// have no input. This check rejects those.
if (sink_graph_node.input_size() == 0) {
continue;
}
*sink_node = sink_graph_node;
return Status::OK();
}
}
return errors::InvalidArgument("Failed to find a sink node");
}
Status OptimizeGraph(const GrapplerItem& item, int64 num_workers,
GraphDef* output);
@ -282,7 +249,7 @@ Status OptimizeGraph(const GrapplerItem& item, int64 num_workers,
FunctionLibraryDefinition flib(OpRegistry::Global(), item.graph.library());
NodeDef sink_node;
TF_RETURN_IF_ERROR(FindSinkNode(item.graph, &sink_node));
TF_RETURN_IF_ERROR(graph_utils::FindSinkNode(item.graph, &sink_node));
TF_RETURN_IF_ERROR(
RecursivelyHandleOp(sink_node, num_workers, &flib, &graph));
*output->mutable_library() = flib.ToProto();

View File

@ -21,10 +21,6 @@ limitations under the License.
namespace tensorflow {
namespace {
constexpr float kLayerByLayerTreeWeight = 1.0;
} // namespace
// Constructor.
BoostedTreesEnsembleResource::BoostedTreesEnsembleResource()
: tree_ensemble_(

View File

@ -129,6 +129,29 @@ tf_cc_test(
],
)
cc_library(
name = "unbounded_thread_pool",
srcs = ["unbounded_thread_pool.cc"],
hdrs = ["unbounded_thread_pool.h"],
deps = [
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
],
)
tf_cc_test(
name = "unbounded_thread_pool_test",
srcs = ["unbounded_thread_pool_test.cc"],
deps = [
":unbounded_thread_pool",
"//tensorflow/core:lib_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "window_dataset",
srcs = ["window_dataset.cc"],
@ -595,6 +618,7 @@ tf_kernel_library(
deps = [
":dataset_utils",
":optional_ops",
":unbounded_thread_pool",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
@ -612,6 +636,7 @@ tf_kernel_library(
srcs = ["multi_device_iterator_ops.cc"],
deps = [
":dataset_utils",
":unbounded_thread_pool",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",

View File

@ -54,6 +54,21 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "auto_shard_dataset_op",
srcs = ["auto_shard_dataset_op.cc"],
deps = [
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler/optimizers/data:auto_shard",
"//tensorflow/core/kernels/data:graph_rewrite_dataset",
],
)
tf_kernel_library(
name = "group_by_reducer_dataset_op",
srcs = ["group_by_reducer_dataset_op.cc"],
@ -390,6 +405,7 @@ tf_kernel_library(
name = "dataset_kernels",
deps = [
":assert_next_dataset_op",
":auto_shard_dataset_op",
":choose_fastest_dataset_op",
":csv_dataset_op",
":dense_to_sparse_batch_dataset_op",

View File

@ -0,0 +1,118 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/kernels/data/graph_rewrite_dataset.h"
namespace tensorflow {
namespace data {
namespace {
constexpr char kOptimizerName[] = "tf_auto_shard";
class AutoShardDatasetOp : public UnaryDatasetOpKernel {
public:
explicit AutoShardDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx),
graph_def_version_(ctx->graph_def_version()) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
}
protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
int64 index;
int64 num_workers;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_workers", &num_workers));
OP_REQUIRES(
ctx, num_workers > 0,
errors::InvalidArgument("num_workers must be greater than zero."));
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "index", &index));
OP_REQUIRES(ctx, index >= 0 && index < num_workers,
errors::InvalidArgument("index must be between 0 and ",
num_workers - 1));
Dataset* dataset = new Dataset(ctx, input, num_workers, index,
output_types_, output_shapes_);
const Status s = dataset->Optimize(ctx);
if (s.ok()) {
*output = dataset;
} else {
dataset->Unref();
OP_REQUIRES_OK(ctx, s);
}
}
private:
class Dataset : public GraphRewriteDataset {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const int64 num_workers, const int64 index,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
: GraphRewriteDataset(ctx, input, output_types, output_shapes),
num_workers_(num_workers),
index_(index) {}
string DebugString() const override {
return "AutoShardDatasetOp::Dataset";
}
private:
bool ShouldOptimizeFunctions() override {
// We only want to optimize functions for some particular datasets like
// FlatMapDataset, InterleaveDataset etc. So we disable generalized
// function optimization and explicitly handle function modifications
// for those datasets in the rewrite.
return false;
}
RewriterConfig CreateGrapplerRewriteConfig() override {
RewriterConfig rewriter_config;
rewriter_config.set_fail_on_optimizer_errors(true);
rewriter_config.add_optimizers(kOptimizerName);
rewriter_config.set_meta_optimizer_iterations(
RewriterConfig_NumIterationsType_ONE);
auto custom_optimizer = rewriter_config.add_custom_optimizers();
custom_optimizer->set_name(kOptimizerName);
AttrValue num_workers_attr;
num_workers_attr.set_i(num_workers_);
(*custom_optimizer->mutable_parameter_map())["num_workers"] =
num_workers_attr;
AttrValue index_attr;
index_attr.set_i(index_);
(*custom_optimizer->mutable_parameter_map())["index"] = index_attr;
return rewriter_config;
}
const int64 num_workers_;
const int64 index_;
};
const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
};
REGISTER_KERNEL_BUILDER(Name("ExperimentalAutoShardDataset").Device(DEVICE_CPU),
AutoShardDatasetOp);
} // anonymous namespace
} // namespace data
} // namespace tensorflow

View File

@ -292,10 +292,10 @@ class ChooseFastestDatasetOp : public DatasetOpKernel {
for (size_t i = 0, num_inputs = dataset()->inputs_.size();
i < num_inputs; ++i) {
threads[i].result = absl::make_unique<InvocationResult>();
threads[i].thread.reset(ctx->env()->StartThread(
{}, strings::StrCat("tf_data_merge_", i),
threads[i].thread = ctx->StartThread(
strings::StrCat("tf_data_merge_", i),
std::bind(&ChooseFastestIterator::RunnerThread, this, ctx,
threads[i].result.get(), i)));
threads[i].result.get(), i));
}
return threads;
}

View File

@ -514,9 +514,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
runner_thread_.reset(ctx->env()->StartThread(
{}, "tf_data_map_and_batch",
std::bind(&Iterator::RunnerThread, this, ctx_copy)));
runner_thread_ = ctx->StartThread(
"tf_data_map_and_batch",
std::bind(&Iterator::RunnerThread, this, ctx_copy));
}
}

View File

@ -926,8 +926,8 @@ class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel {
if (!new_ctx) {
new_ctx = std::make_shared<IteratorContext>(*ctx);
}
workers_[i]->threads.emplace_back(ctx->env()->StartThread(
{}, strings::StrCat("tf_data_numa_map_and_batch_", i, "_", j),
workers_[i]->threads.emplace_back(ctx->StartThread(
strings::StrCat("tf_data_numa_map_and_batch_", i, "_", j),
[this, new_ctx, i, j]() { WorkerThread(new_ctx, i, j); }));
VLOG(3) << "Worker " << i << ", " << j << " successfully started.";
}
@ -936,9 +936,9 @@ class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel {
if (!new_ctx) {
new_ctx = std::make_shared<IteratorContext>(*ctx);
}
runner_thread_.reset(ctx->env()->StartThread(
{}, "tf_data_numa_map_and_batch",
[this, new_ctx] { RunnerThread(new_ctx); }));
runner_thread_ =
ctx->StartThread("tf_data_numa_map_and_batch",
[this, new_ctx] { RunnerThread(new_ctx); });
}
VLOG(3) << "All workers & runner thread started.";
return Status::OK();

View File

@ -493,8 +493,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
worker_threads_.reserve(dataset()->num_threads());
for (size_t i = 0; i < dataset()->num_threads(); ++i) {
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
worker_threads_.emplace_back(ctx->env()->StartThread(
{}, strings::StrCat("tf_data_parallel_interleave_worker_", i),
worker_threads_.emplace_back(ctx->StartThread(
strings::StrCat("tf_data_parallel_interleave_worker_", i),
[this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
}
}
@ -592,8 +592,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
workers_[i].SetInputs(s, std::move(args));
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
worker_threads_.emplace_back(ctx->env()->StartThread(
{}, strings::StrCat("tf_data_parallel_interleave_worker_", i),
worker_threads_.push_back(ctx->StartThread(
strings::StrCat("tf_data_parallel_interleave_worker_", i),
[this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
if (i < dataset()->cycle_length_) {
interleave_indices_.push_back(i);

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/iterator_ops.h"
#include <memory>
#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
@ -28,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/optional_ops.h"
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
@ -51,14 +53,15 @@ const char kIteratorVariantTypeName[] = "tensorflow::Iterator";
class IteratorResource : public ResourceBase {
public:
IteratorResource(const DataTypeVector& output_dtypes,
IteratorResource(Env* env, const DataTypeVector& output_dtypes,
const std::vector<PartialTensorShape>& output_shapes,
const int /*unused: graph_def_version*/,
std::unique_ptr<DeviceMgr> device_mgr,
std::unique_ptr<FunctionLibraryDefinition> flib_def,
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
FunctionLibraryRuntime* lib)
: device_mgr_(std::move(device_mgr)),
: unbounded_thread_pool_(env, "tf_data_iterator_resource"),
device_mgr_(std::move(device_mgr)),
iterator_state_(std::make_shared<State>(
std::move(flib_def), std::move(pflr), lib, nullptr /* iterator */)),
output_dtypes_(output_dtypes),
@ -77,6 +80,7 @@ class IteratorResource : public ResourceBase {
params.function_handle_cache =
captured_state->function_handle_cache.get();
params.resource_mgr = &captured_state->resource_mgr;
params.thread_factory = unbounded_thread_pool_.get_thread_factory();
return captured_state->iterator->GetNext(
IteratorContext(std::move(params)), out_tensors, end_of_sequence);
} else {
@ -163,6 +167,8 @@ class IteratorResource : public ResourceBase {
params.lib = new_state->lib;
params.function_handle_cache = new_state->function_handle_cache.get();
params.resource_mgr = &new_state->resource_mgr;
params.thread_factory = unbounded_thread_pool_.get_thread_factory();
TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)),
"Iterator", &new_state->iterator));
TF_RETURN_IF_ERROR(
@ -179,6 +185,7 @@ class IteratorResource : public ResourceBase {
params.allocator_getter = [device](AllocatorAttributes attrs) {
return device->GetAllocator(attrs);
};
params.thread_factory = unbounded_thread_pool_.get_thread_factory();
IteratorContext iter_ctx(std::move(params));
TF_RETURN_IF_ERROR(new_state->iterator->Restore(&iter_ctx, reader));
}
@ -233,6 +240,7 @@ class IteratorResource : public ResourceBase {
params.lib = new_state->lib;
params.function_handle_cache = new_state->function_handle_cache.get();
params.resource_mgr = &new_state->resource_mgr;
params.thread_factory = unbounded_thread_pool_.get_thread_factory();
TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)),
"Iterator", &iterator));
TF_RETURN_IF_ERROR(
@ -284,6 +292,7 @@ class IteratorResource : public ResourceBase {
std::unique_ptr<IteratorBase> iterator;
};
UnboundedThreadPool unbounded_thread_pool_;
mutex mu_;
const std::unique_ptr<DeviceMgr> device_mgr_ GUARDED_BY(mu_);
std::shared_ptr<State> iterator_state_ GUARDED_BY(mu_);
@ -432,14 +441,14 @@ void IteratorHandleOp::Compute(OpKernelContext* context) LOCKS_EXCLUDED(mu_) {
context,
mgr->LookupOrCreate<IteratorResource>(
cinfo_.container(), cinfo_.name(), &resource,
[lib, &device_mgr, &flib_def, &pflr, this](IteratorResource** ret)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
*ret = new IteratorResource(
output_dtypes_, output_shapes_, graph_def_version_,
std::move(device_mgr), std::move(flib_def),
std::move(pflr), lib);
return Status::OK();
}));
[context, lib, &device_mgr, &flib_def, &pflr,
this](IteratorResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
*ret = new IteratorResource(
context->env(), output_dtypes_, output_shapes_,
graph_def_version_, std::move(device_mgr),
std::move(flib_def), std::move(pflr), lib);
return Status::OK();
}));
Status s = VerifyResource(resource);
if (TF_PREDICT_FALSE(!s.ok())) {
@ -522,7 +531,7 @@ void AnonymousIteratorHandleOp::Compute(OpKernelContext* context) {
existing_resource->Unref();
}
IteratorResource* new_resource = new IteratorResource(
output_dtypes_, output_shapes_, graph_def_version_,
context->env(), output_dtypes_, output_shapes_, graph_def_version_,
std::move(device_mgr), std::move(flib_def), std::move(pflr), lib);
// Create the resource with our chosen name under the resource lookup
// mutex to avoid another kernel racily creating a resource with this
@ -837,11 +846,12 @@ class OneShotIteratorOp : public AsyncOpKernel {
TF_RETURN_IF_ERROR(
ctx->resource_manager()->LookupOrCreate<IteratorResource>(
cinfo->container(), cinfo->name(), iterator,
[lib, this, &flib_def, &pflr](IteratorResource** ret)
[ctx, lib, this, &flib_def, &pflr](IteratorResource** ret)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
*ret = new IteratorResource(
output_dtypes_, output_shapes_, graph_def_version_,
nullptr, std::move(flib_def), std::move(pflr), lib);
ctx->env(), output_dtypes_, output_shapes_,
graph_def_version_, nullptr, std::move(flib_def),
std::move(pflr), lib);
return Status::OK();
}));

View File

@ -140,9 +140,8 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
if (!optimize_thread_) {
std::shared_ptr<IteratorContext> new_ctx =
std::make_shared<IteratorContext>(*ctx);
optimize_thread_.reset(ctx->env()->StartThread(
{}, "tf_data_model",
[this, new_ctx]() { OptimizeThread(new_ctx); }));
optimize_thread_ = ctx->StartThread(
"tf_data_model", [this, new_ctx]() { OptimizeThread(new_ctx); });
}
return Status::OK();
}

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/random/random.h"
@ -42,14 +43,15 @@ using MultiDeviceIteratorCallback =
class MultiDeviceIterator : public ResourceBase {
public:
MultiDeviceIterator(
const DataTypeVector& output_types,
Env* env, const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
const std::vector<string>& devices,
std::unique_ptr<FunctionLibraryDefinition> flib_def,
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
FunctionLibraryRuntime* lib,
std::unique_ptr<FunctionHandleCache> function_handle_cache)
: output_types_(output_types),
: unbounded_thread_pool_(env, "tf_data_multi_device_iterator_resource"),
output_types_(output_types),
output_shapes_(output_shapes),
devices_(devices),
flib_def_(std::move(flib_def)),
@ -82,27 +84,25 @@ class MultiDeviceIterator : public ResourceBase {
*incarnation_id = incarnation_id_;
multi_device_buffer_ = absl::make_unique<MultiDeviceBuffer>(
devices_.size(), max_buffer_size, incarnation_id_, std::move(iterator));
devices_.size(), max_buffer_size, incarnation_id_, std::move(iterator),
this);
return Status::OK();
}
void GetNextFromShard(IteratorContext* ctx, int shard_num,
void GetNextFromShard(OpKernelContext* ctx, int shard_num,
int64 incarnation_id,
MultiDeviceIteratorCallback callback) {
if (ctx->lib() == lib_) {
tf_shared_lock l(mu_);
multi_device_buffer_->GetNextFromShard(ctx, shard_num, incarnation_id,
std::move(callback));
} else {
IteratorContext::Params params(ctx);
params.lib = lib_;
params.function_handle_cache = function_handle_cache_.get();
params.resource_mgr = &resource_mgr_;
IteratorContext iter_ctx(std::move(params));
tf_shared_lock l(mu_);
multi_device_buffer_->GetNextFromShard(
&iter_ctx, shard_num, incarnation_id, std::move(callback));
}
tf_shared_lock l(mu_);
IteratorContext::Params params(ctx);
params.function_library = lib_def_;
params.lib = lib_;
params.function_handle_cache = function_handle_cache_.get();
params.resource_mgr = &resource_mgr_;
params.thread_factory = unbounded_thread_pool_.get_thread_factory();
IteratorContext iter_ctx(std::move(params));
multi_device_buffer_->GetNextFromShard(&iter_ctx, shard_num, incarnation_id,
std::move(callback));
}
const DataTypeVector& output_types() const { return output_types_; }
@ -133,12 +133,14 @@ class MultiDeviceIterator : public ResourceBase {
class MultiDeviceBuffer {
public:
MultiDeviceBuffer(size_t size, int64 max_buffer_size, int64 incarnation_id,
std::unique_ptr<IteratorBase> host_iterator)
std::unique_ptr<IteratorBase> host_iterator,
MultiDeviceIterator* parent)
: buffer_(size),
size_(size),
max_buffer_size_(max_buffer_size),
incarnation_id_(incarnation_id),
host_iterator_(std::move(host_iterator)) {}
host_iterator_(std::move(host_iterator)),
parent_(parent) {}
~MultiDeviceBuffer() {
{
@ -217,10 +219,12 @@ class MultiDeviceIterator : public ResourceBase {
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!background_thread_) {
auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
background_thread_ = absl::WrapUnique<Thread>(ctx->env()->StartThread(
{}, "tf_data_multi_device_iterator",
std::bind(&MultiDeviceIterator::MultiDeviceBuffer::BackgroundThread,
this, std::move(ctx_copy))));
background_thread_ =
parent_->unbounded_thread_pool_.get_thread_factory()->StartThread(
"tf_data_multi_device_iterator",
std::bind(
&MultiDeviceIterator::MultiDeviceBuffer::BackgroundThread,
this, std::move(ctx_copy)));
}
}
@ -342,8 +346,10 @@ class MultiDeviceIterator : public ResourceBase {
const int64 max_buffer_size_;
const int64 incarnation_id_;
const std::unique_ptr<IteratorBase> host_iterator_;
MultiDeviceIterator* const parent_; // Not owned.
};
UnboundedThreadPool unbounded_thread_pool_;
mutex mu_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
@ -413,8 +419,9 @@ class MultiDeviceIteratorHandleOp : public OpKernel {
current_id_.fetch_add(1));
container_name = "AnonymousMultiDeviceIterator";
resource = new MultiDeviceIterator(
output_types_, output_shapes_, devices_, std::move(flib_def),
std::move(pflr), lib, std::move(function_handle_cache));
context->env(), output_types_, output_shapes_, devices_,
std::move(flib_def), std::move(pflr), lib,
std::move(function_handle_cache));
// NOTE: `mgr->Create()` transfers the one reference on `resource` to
// `mgr`.
OP_REQUIRES_OK(context, mgr->Create<MultiDeviceIterator>(
@ -425,11 +432,12 @@ class MultiDeviceIteratorHandleOp : public OpKernel {
OP_REQUIRES_OK(context,
mgr->LookupOrCreate<MultiDeviceIterator>(
container_name, unique_name, &resource,
[this, lib, &flib_def, &pflr,
[this, context, lib, &flib_def, &pflr,
&function_handle_cache](MultiDeviceIterator** ret)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
*ret = new MultiDeviceIterator(
output_types_, output_shapes_, devices_,
context->env(), output_types_,
output_shapes_, devices_,
std::move(flib_def), std::move(pflr),
lib, std::move(function_handle_cache));
return Status::OK();
@ -557,11 +565,8 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel {
},
std::placeholders::_1, std::move(done));
IteratorContext::Params params(ctx);
params.function_library = iterator->function_library();
IteratorContext iter_ctx(std::move(params));
iterator->GetNextFromShard(&iter_ctx, shard_num, incarnation_id,
callback);
iterator->GetNextFromShard(ctx, shard_num, incarnation_id,
std::move(callback));
iterator->Unref();
},
std::move(done)));

View File

@ -517,17 +517,15 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!current_elements_manager_) {
auto new_ctx = std::make_shared<IteratorContext>(*ctx);
current_elements_manager_ =
absl::WrapUnique<Thread>(ctx->env()->StartThread(
{}, "tf_data_parallel_interleave_current",
[this, new_ctx]() { CurrentElementsManager(new_ctx); }));
current_elements_manager_ = ctx->StartThread(
"tf_data_parallel_interleave_current",
[this, new_ctx]() { CurrentElementsManager(new_ctx); });
}
if (!future_elements_manager_) {
auto new_ctx = std::make_shared<IteratorContext>(*ctx);
future_elements_manager_ =
absl::WrapUnique<Thread>(ctx->env()->StartThread(
{}, "tf_data_parallel_interleave_future",
[this, new_ctx]() { FutureElementsManager(new_ctx); }));
future_elements_manager_ = ctx->StartThread(
"tf_data_parallel_interleave_future",
[this, new_ctx]() { FutureElementsManager(new_ctx); });
}
}

View File

@ -191,9 +191,9 @@ class ParallelMapIterator : public DatasetBaseIterator {
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
runner_thread_.reset(ctx->env()->StartThread(
{}, "tf_data_parallel_map",
std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy)));
runner_thread_ = ctx->StartThread(
"tf_data_parallel_map",
std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy));
}
}

View File

@ -269,9 +269,8 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
if (!prefetch_thread_) {
std::shared_ptr<IteratorContext> new_ctx =
std::make_shared<IteratorContext>(*ctx);
prefetch_thread_ = absl::WrapUnique<Thread>(ctx->env()->StartThread(
{}, "tf_data_prefetch",
[this, new_ctx]() { PrefetchThread(new_ctx); }));
prefetch_thread_ = ctx->StartThread(
"tf_data_prefetch", [this, new_ctx]() { PrefetchThread(new_ctx); });
}
return Status::OK();
}

View File

@ -0,0 +1,156 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
#include "absl/memory/memory.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
namespace data {
// A lightweight wrapper for creating logical threads in an
// `UnboundedThreadPool` that can be shared (e.g. in an `IteratorContext`).
class UnboundedThreadPool::LogicalThreadFactory : public ThreadFactory {
public:
explicit LogicalThreadFactory(UnboundedThreadPool* pool) : pool_(pool) {}
std::unique_ptr<Thread> StartThread(const string& name,
std::function<void()> fn) override {
return pool_->RunOnPooledThread(std::move(fn));
}
private:
UnboundedThreadPool* const pool_; // Not owned.
};
// A logical implementation of the `tensorflow::Thread` interface that uses
// physical threads in an `UnboundedThreadPool` to perform the work.
//
// NOTE: This object represents a logical thread of control that may be mapped
// onto the same physical thread as other work items that are submitted to the
// same `UnboundedThreadPool`.
class UnboundedThreadPool::LogicalThreadWrapper : public Thread {
public:
explicit LogicalThreadWrapper(std::shared_ptr<Notification> join_notification)
: join_notification_(std::move(join_notification)) {}
~LogicalThreadWrapper() override {
// NOTE: The `Thread` destructor is expected to "join" the created thread,
// but the physical thread may continue to execute after the work for this
// thread is complete. We simulate this by waiting on a notification that
// the `PooledThreadFunc` will notify when the thread's work function is
// complete.
join_notification_->WaitForNotification();
}
private:
std::shared_ptr<Notification> join_notification_;
};
UnboundedThreadPool::~UnboundedThreadPool() {
{
mutex_lock l(work_queue_mu_);
// Wake up all `PooledThreadFunc` threads and cause them to terminate before
// joining them when `thread_pool_` is cleared.
cancelled_ = true;
work_queue_cv_.notify_all();
if (!work_queue_.empty()) {
LOG(ERROR) << "UnboundedThreadPool named \"" << thread_name_ << "\" was "
<< "deleted with pending work in its queue. This may indicate "
<< "a potential use-after-free bug.";
}
}
{
mutex_lock l(thread_pool_mu_);
// Clear the list of pooled threads, which will eventually terminate due to
// the previous notification.
//
// NOTE: It is safe to do this while holding `thread_pool_mu_`, because
// no new work should be submitted via `RunOnPooledThread()` after the
// destructor starts.
thread_pool_.clear();
}
}
std::shared_ptr<ThreadFactory> UnboundedThreadPool::get_thread_factory() {
return std::make_shared<LogicalThreadFactory>(this);
}
size_t UnboundedThreadPool::size() {
tf_shared_lock l(thread_pool_mu_);
return thread_pool_.size();
}
std::unique_ptr<Thread> UnboundedThreadPool::RunOnPooledThread(
std::function<void()> fn) {
auto join_notification = std::make_shared<Notification>();
bool all_threads_busy;
{
// Enqueue a work item for the new thread's function, and wake up a
// cached thread to process it.
mutex_lock l(work_queue_mu_);
work_queue_.push_back({std::move(fn), join_notification});
work_queue_cv_.notify_one();
// NOTE: The queue may be non-empty, so we must account for queued work when
// considering how many threads are free.
all_threads_busy = work_queue_.size() > num_idle_threads_;
}
if (all_threads_busy) {
// Spawn a new physical thread to process the given function.
// NOTE: `PooledThreadFunc` will eventually increment `num_idle_threads_`
// at the beginning of its work loop.
Thread* new_thread = env_->StartThread(
{}, thread_name_,
std::bind(&UnboundedThreadPool::PooledThreadFunc, this));
mutex_lock l(thread_pool_mu_);
thread_pool_.emplace_back(new_thread);
}
return absl::make_unique<LogicalThreadWrapper>(std::move(join_notification));
}
void UnboundedThreadPool::PooledThreadFunc() {
while (true) {
WorkItem work_item;
{
mutex_lock l(work_queue_mu_);
++num_idle_threads_;
while (!cancelled_ && work_queue_.empty()) {
// Wait for a new work function to be submitted, or the pool to be
// destroyed.
work_queue_cv_.wait(l);
}
if (cancelled_) {
return;
}
work_item = std::move(work_queue_.front());
work_queue_.pop_front();
--num_idle_threads_;
}
work_item.work_function();
// Notify any thread that has "joined" the logical thread for this work item.
work_item.done_notification->Notify();
}
}
} // namespace data
} // namespace tensorflow

View File

@ -0,0 +1,77 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_UNBOUNDED_THREAD_POOL_H_
#define TENSORFLOW_CORE_KERNELS_DATA_UNBOUNDED_THREAD_POOL_H_
#include <deque>
#include <memory>
#include <vector>
#include "tensorflow/core/framework/thread_factory.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
namespace data {
// An `UnboundedThreadPool` provides a mechanism for temporally multiplexing a
// potentially large number of "logical" threads onto a smaller number of
// "physical" threads. The multiplexing is achieved by maintaining an internal
// pool of long-running "physical" threads that are used to execute the
// "logical" threads. Like a regular thread, a "logical" thread may block on
// other threads, and the size of the pool will increase to ensure that progress
// is made. This mechanism is recommended in situations where short-lived
// threads are created repeatedly, to avoid the overhead and memory
// fragmentation that can result from excessive thread creation.
class UnboundedThreadPool {
public:
UnboundedThreadPool(Env* env, const string& thread_name)
: env_(env), thread_name_(thread_name) {}
~UnboundedThreadPool();
// Returns an implementation of `ThreadFactory` that can be used to create
// logical threads in this pool.
std::shared_ptr<ThreadFactory> get_thread_factory();
// Returns the current number of threads in this pool.
size_t size();
private:
class LogicalThreadFactory;
class LogicalThreadWrapper;
struct WorkItem {
std::function<void()> work_function;
std::shared_ptr<Notification> done_notification;
};
std::unique_ptr<Thread> RunOnPooledThread(std::function<void()> fn);
void PooledThreadFunc();
Env* const env_; // Not owned.
const string thread_name_;
mutex work_queue_mu_;
condition_variable work_queue_cv_ GUARDED_BY(work_queue_mu_);
size_t num_idle_threads_ GUARDED_BY(work_queue_mu_) = 0;
bool cancelled_ GUARDED_BY(work_queue_mu_) = false;
std::deque<WorkItem> work_queue_ GUARDED_BY(work_queue_mu_);
mutex thread_pool_mu_;
std::vector<std::unique_ptr<Thread>> thread_pool_ GUARDED_BY(thread_pool_mu_);
};
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_UNBOUNDED_THREAD_POOL_H_
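
The class comment above describes how logical threads are multiplexed onto pooled physical threads. For orientation, here is a minimal usage sketch that relies only on the declarations in this header plus `Env::Default()`; wiring the factory into an `IteratorContext` is paraphrased intent, not quoted code.

```cpp
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"

void ExampleUsage() {
  tensorflow::data::UnboundedThreadPool pool(tensorflow::Env::Default(),
                                             "tf_data_example");
  // Hand the factory to whatever needs to create logical threads, e.g. an
  // IteratorContext in the kernels changed earlier in this commit.
  std::shared_ptr<tensorflow::ThreadFactory> factory =
      pool.get_thread_factory();
  // Each StartThread call returns a Thread whose destructor "joins" the
  // logical thread, exactly as exercised by the tests in the next file.
  std::unique_ptr<tensorflow::Thread> t =
      factory->StartThread("worker", []() { /* do some work */ });
  t.reset();  // Blocks until the work function has finished running.
}
```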

View File

@ -0,0 +1,143 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace data {
namespace {
TEST(UnboundedThreadPool, SingleThread) {
UnboundedThreadPool pool(Env::Default(), "test");
auto thread_factory = pool.get_thread_factory();
// Create a thread that updates a variable, and ensure that it runs to
// completion.
std::atomic<int> i(0);
auto thread = thread_factory->StartThread("", [&i]() { ++i; });
thread.reset();
EXPECT_GE(pool.size(), 1);
EXPECT_EQ(1, i);
}
TEST(UnboundedThreadPool, MultipleThreads) {
UnboundedThreadPool pool(Env::Default(), "test");
auto thread_factory = pool.get_thread_factory();
// Create ten threads that update a variable, and ensure that they all run
// to completion.
std::vector<std::unique_ptr<Thread>> threads;
const int kNumThreadsToCreate = 10;
std::atomic<int> i(0);
for (int j = 0; j < kNumThreadsToCreate; ++j) {
threads.push_back(thread_factory->StartThread("", [&i]() { ++i; }));
}
threads.clear();
EXPECT_GE(pool.size(), 1);
EXPECT_EQ(i, kNumThreadsToCreate);
}
TEST(UnboundedThreadPool, MultipleThreadsSleepingRandomly) {
UnboundedThreadPool pool(Env::Default(), "test");
auto thread_factory = pool.get_thread_factory();
// Create 1000 threads that sleep for a random period of time then update a
// variable, and ensure that they all run to completion.
std::vector<std::unique_ptr<Thread>> threads;
const int kNumThreadsToCreate = 1000;
std::atomic<int> i(0);
for (int j = 0; j < kNumThreadsToCreate; ++j) {
threads.push_back(thread_factory->StartThread("", [&i]() {
Env::Default()->SleepForMicroseconds(random::New64() % 10);
++i;
}));
}
threads.clear();
EXPECT_GE(pool.size(), 1);
EXPECT_EQ(i, kNumThreadsToCreate);
}
TEST(UnboundedThreadPool, ConcurrentThreadCreation) {
UnboundedThreadPool pool(Env::Default(), "test");
auto thread_factory = pool.get_thread_factory();
// Create ten threads that each create ten threads that update a variable, and
// ensure that they all run to completion.
std::vector<std::unique_ptr<Thread>> threads;
const int kNumThreadsToCreate = 10;
std::atomic<int> i(0);
for (int j = 0; j < kNumThreadsToCreate; ++j) {
threads.push_back(thread_factory->StartThread("", [&i, thread_factory]() {
std::vector<std::unique_ptr<Thread>> nested_threads;
for (int k = 0; k < kNumThreadsToCreate; ++k) {
nested_threads.push_back(
thread_factory->StartThread("", [&i]() { ++i; }));
}
nested_threads.clear();
}));
}
threads.clear();
EXPECT_GE(pool.size(), 1);
EXPECT_EQ(i, kNumThreadsToCreate * kNumThreadsToCreate);
}
TEST(UnboundedThreadPool, MultipleBlockingThreads) {
UnboundedThreadPool pool(Env::Default(), "test");
auto thread_factory = pool.get_thread_factory();
std::vector<std::unique_ptr<Thread>> threads;
// Create multiple waves (with increasing sizes) of threads that all block
// before returning, and
// ensure that we create the appropriate number of threads and terminate
// correctly.
std::vector<int> round_sizes = {5, 10, 15, 20};
for (const int round_size : round_sizes) {
Notification n;
BlockingCounter bc(round_size);
for (int j = 0; j < round_size; ++j) {
threads.push_back(thread_factory->StartThread("", [&bc, &n]() {
bc.DecrementCount();
// Block until `n` is notified, so that all threads in this round must have
// been created before the first one completes.
n.WaitForNotification();
}));
}
// Wait until all threads have started. Since the number of threads in each
// wave is increasing, we should have at least that number of threads in the
// pool.
bc.Wait();
// NOTE: There is a benign race between a new round starting and the
// physical threads from the previous round returning to the pool, so we may
// create more threads than the round_size.
EXPECT_GE(pool.size(), round_size);
n.Notify();
threads.clear();
}
}
} // namespace
} // namespace data
} // namespace tensorflow

View File

@ -133,6 +133,17 @@ int VarintLength(uint64_t v) {
return len;
}
const char* GetVarint32Ptr(const char* p, const char* limit, uint32* value) {
if (p < limit) {
uint32 result = *(reinterpret_cast<const unsigned char*>(p));
if ((result & 128) == 0) {
*value = result;
return p + 1;
}
}
return GetVarint32PtrFallback(p, limit, value);
}
const char* GetVarint32PtrFallback(const char* p, const char* limit,
uint32* value) {
uint32 result = 0;

View File

@ -55,18 +55,8 @@ extern const char* GetVarint64Ptr(const char* p, const char* limit, uint64* v);
// Internal routine for use by fallback path of GetVarint32Ptr
extern const char* GetVarint32PtrFallback(const char* p, const char* limit,
uint32* value);
inline const char* GetVarint32Ptr(const char* p, const char* limit,
uint32* value) {
if (p < limit) {
uint32 result = *(reinterpret_cast<const unsigned char*>(p));
if ((result & 128) == 0) {
*value = result;
return p + 1;
}
}
return GetVarint32PtrFallback(p, limit, value);
}
extern const char* GetVarint32Ptr(const char* p, const char* limit,
uint32* value);
extern char* EncodeVarint32(char* dst, uint32 v);
extern char* EncodeVarint64(char* dst, uint64 v);
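
These two hunks move the single-byte fast path of `GetVarint32Ptr` out of the header into `coding.cc`, leaving only an `extern` declaration. As a reminder of what the fast path does: a varint stores 7 payload bits per byte and uses the high bit as a continuation flag, so any value below 128 decodes from one byte without reaching `GetVarint32PtrFallback`. A small worked example, assuming only the declarations shown in this header:

```cpp
// 1   encodes as 0x01          (one byte, handled by the fast path)
// 300 encodes as 0xAC 0x02     (0x2C with the continuation bit set, then 0x02)
char buf[5];
char* end = EncodeVarint32(buf, 300);          // writes two bytes
uint32 value = 0;
const char* next = GetVarint32Ptr(buf, end, &value);
// value == 300 and next == buf + 2; decoding the value 1 instead would have
// returned from the (now out-of-line) fast path immediately.
```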

View File

@ -22251,6 +22251,89 @@ op {
}
is_stateful: true
}
op {
name: "EnqueueTPUEmbeddingSparseBatch"
input_arg {
name: "sample_indices"
type_attr: "T1"
number_attr: "N"
}
input_arg {
name: "embedding_indices"
type_attr: "T2"
number_attr: "N"
}
input_arg {
name: "aggregation_weights"
type_attr: "T3"
number_attr: "N"
}
input_arg {
name: "mode_override"
type: DT_STRING
}
attr {
name: "T1"
type: "type"
default_value {
type: DT_INT32
}
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "T2"
type: "type"
default_value {
type: DT_INT32
}
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "T3"
type: "type"
default_value {
type: DT_FLOAT
}
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
attr {
name: "N"
type: "int"
has_minimum: true
minimum: 1
}
attr {
name: "device_ordinal"
type: "int"
default_value {
i: -1
}
}
attr {
name: "combiners"
type: "list(string)"
default_value {
list {
}
}
}
is_stateful: true
}
op {
name: "EnqueueTPUEmbeddingSparseTensorBatch"
input_arg {
@ -22299,6 +22382,93 @@ op {
}
is_stateful: true
}
op {
name: "EnqueueTPUEmbeddingSparseTensorBatch"
input_arg {
name: "sample_indices"
type_attr: "T1"
number_attr: "N"
}
input_arg {
name: "embedding_indices"
type_attr: "T2"
number_attr: "N"
}
input_arg {
name: "aggregation_weights"
type_attr: "T3"
number_attr: "N"
}
input_arg {
name: "mode_override"
type: DT_STRING
}
attr {
name: "T1"
type: "type"
default_value {
type: DT_INT32
}
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "T2"
type: "type"
default_value {
type: DT_INT32
}
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "T3"
type: "type"
default_value {
type: DT_FLOAT
}
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
attr {
name: "N"
type: "int"
has_minimum: true
minimum: 1
}
attr {
name: "device_ordinal"
type: "int"
default_value {
i: -1
}
}
attr {
name: "combiners"
type: "list(string)"
default_value {
list {
}
}
}
attr {
name: "table_ids"
type: "list(int)"
}
is_stateful: true
}
op {
name: "EnsureShape"
input_arg {
@ -22814,6 +22984,37 @@ op {
minimum: 1
}
}
op {
name: "ExperimentalAutoShardDataset"
input_arg {
name: "input_dataset"
type: DT_VARIANT
}
input_arg {
name: "num_workers"
type: DT_INT64
}
input_arg {
name: "index"
type: DT_INT64
}
output_arg {
name: "handle"
type: DT_VARIANT
}
attr {
name: "output_types"
type: "list(type)"
has_minimum: true
minimum: 1
}
attr {
name: "output_shapes"
type: "list(shape)"
has_minimum: true
minimum: 1
}
}
op {
name: "ExperimentalBytesProducedStatsDataset"
input_arg {

View File

@ -17,6 +17,15 @@ limitations under the License.
namespace tensorflow {
REGISTER_OP("ExperimentalAutoShardDataset")
.Input("input_dataset: variant")
.Input("num_workers: int64")
.Input("index: int64")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalBytesProducedStatsDataset")
.Input("input_dataset: variant")
.Input("tag: string")

View File

@ -10421,23 +10421,62 @@ op {
name: "EnqueueTPUEmbeddingSparseBatch"
input_arg {
name: "sample_indices"
type: DT_INT32
type_attr: "T1"
number_attr: "N"
}
input_arg {
name: "embedding_indices"
type: DT_INT32
type_attr: "T2"
number_attr: "N"
}
input_arg {
name: "aggregation_weights"
type: DT_FLOAT
type_attr: "T3"
number_attr: "N"
}
input_arg {
name: "mode_override"
type: DT_STRING
}
attr {
name: "T1"
type: "type"
default_value {
type: DT_INT32
}
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "T2"
type: "type"
default_value {
type: DT_INT32
}
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "T3"
type: "type"
default_value {
type: DT_FLOAT
}
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
attr {
name: "N"
type: "int"
@ -10465,23 +10504,62 @@ op {
name: "EnqueueTPUEmbeddingSparseTensorBatch"
input_arg {
name: "sample_indices"
type: DT_INT32
type_attr: "T1"
number_attr: "N"
}
input_arg {
name: "embedding_indices"
type: DT_INT32
type_attr: "T2"
number_attr: "N"
}
input_arg {
name: "aggregation_weights"
type: DT_FLOAT
type_attr: "T3"
number_attr: "N"
}
input_arg {
name: "mode_override"
type: DT_STRING
}
attr {
name: "T1"
type: "type"
default_value {
type: DT_INT32
}
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "T2"
type: "type"
default_value {
type: DT_INT32
}
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "T3"
type: "type"
default_value {
type: DT_FLOAT
}
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
attr {
name: "N"
type: "int"
@ -10806,6 +10884,37 @@ op {
minimum: 1
}
}
op {
name: "ExperimentalAutoShardDataset"
input_arg {
name: "input_dataset"
type: DT_VARIANT
}
input_arg {
name: "num_workers"
type: DT_INT64
}
input_arg {
name: "index"
type: DT_INT64
}
output_arg {
name: "handle"
type: DT_VARIANT
}
attr {
name: "output_types"
type: "list(type)"
has_minimum: true
minimum: 1
}
attr {
name: "output_shapes"
type: "list(shape)"
has_minimum: true
minimum: 1
}
}
op {
name: "ExperimentalBytesProducedStatsDataset"
input_arg {

View File

@ -393,10 +393,13 @@ REGISTER_OP("EnqueueTPUEmbeddingIntegerBatch")
.SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("EnqueueTPUEmbeddingSparseBatch")
.Input("sample_indices: N * int32")
.Input("embedding_indices: N * int32")
.Input("aggregation_weights: N * float32")
.Input("sample_indices: N * T1")
.Input("embedding_indices: N * T2")
.Input("aggregation_weights: N * T3")
.Input("mode_override: string")
.Attr("T1: {int32,int64} = DT_INT32")
.Attr("T2: {int32,int64} = DT_INT32")
.Attr("T3: {float32,float64} = DT_FLOAT")
.Attr("N: int >= 1")
.Attr("device_ordinal: int = -1")
.Attr("combiners: list(string) = []")
@ -416,10 +419,13 @@ REGISTER_OP("EnqueueTPUEmbeddingSparseBatch")
});
REGISTER_OP("EnqueueTPUEmbeddingSparseTensorBatch")
.Input("sample_indices: N * int32")
.Input("embedding_indices: N * int32")
.Input("aggregation_weights: N * float32")
.Input("sample_indices: N * T1")
.Input("embedding_indices: N * T2")
.Input("aggregation_weights: N * T3")
.Input("mode_override: string")
.Attr("T1: {int32,int64} = DT_INT32")
.Attr("T2: {int32,int64} = DT_INT32")
.Attr("T3: {float32,float64} = DT_FLOAT")
.Attr("N: int >= 1")
.Attr("device_ordinal: int = -1")
.Attr("combiners: list(string) = []")

View File

@ -70,8 +70,6 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:platform_base",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
"@protobuf_archive//:protobuf_headers",
],
)

View File

@ -16,9 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_UTIL_PROTO_PROTO_UTILS_H_
#define TENSORFLOW_CORE_UTIL_PROTO_PROTO_UTILS_H_
#include "google/protobuf/duration.pb.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/protobuf.h"
@ -60,20 +58,6 @@ class StringErrorCollector : public protobuf::io::ErrorCollector {
const int index_offset_;
};
// Converts an absl::Duration to a google::protobuf::Duration.
inline google::protobuf::Duration ToDurationProto(absl::Duration duration) {
google::protobuf::Duration proto;
proto.set_seconds(absl::IDivDuration(duration, absl::Seconds(1), &duration));
proto.set_nanos(
absl::IDivDuration(duration, absl::Nanoseconds(1), &duration));
return proto;
}
// Converts a google::protobuf::Duration to an absl::Duration.
inline absl::Duration FromDurationProto(google::protobuf::Duration proto) {
return absl::Seconds(proto.seconds()) + absl::Nanoseconds(proto.nanos());
}
} // namespace proto_utils
} // namespace tensorflow
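
The two inline helpers removed here convert between `absl::Duration` and `google::protobuf::Duration` by splitting whole seconds from the nanosecond remainder. A short worked round trip, using the functions exactly as defined in the removed lines (the surrounding code is illustrative only):

```cpp
// 1500 ms splits into seconds = 1 and a 500 ms remainder.
google::protobuf::Duration proto = ToDurationProto(absl::Milliseconds(1500));
// proto.seconds() == 1, proto.nanos() == 500000000
absl::Duration d = FromDurationProto(proto);
// d == absl::Milliseconds(1500)
```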

View File

@ -7570,6 +7570,21 @@ func OptionalGetValue(scope *Scope, optional tf.Output, output_types []tf.DataTy
return components
}
// Returns true if and only if the given Optional variant has a value.
func OptionalHasValue(scope *Scope, optional tf.Output) (has_value tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
Type: "OptionalHasValue",
Input: []tf.Input{
optional,
},
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
// Outputs a tensor containing the reduction across all input tensors.
//
// Outputs a tensor containing the reduction across all input tensors passed to ops
@ -10562,6 +10577,38 @@ func FusedBatchNorm(scope *Scope, x tf.Output, scale tf.Output, offset tf.Output
return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
}
// Creates a dataset that shards the input dataset.
//
// Creates a dataset that shards the input dataset by num_workers, returning a
// sharded dataset for the index-th worker. This attempts to automatically shard
// a dataset by examining the Dataset graph and inserting a shard op before the
// inputs to a reader Dataset (e.g. CSVDataset, TFRecordDataset).
//
// This dataset will throw a NotFound error if we cannot shard the dataset
// automatically.
//
// Arguments:
// input_dataset: A variant tensor representing the input dataset.
// num_workers: A scalar representing the number of workers to distribute this dataset across.
// index: A scalar representing the index of the current worker out of num_workers.
//
//
func ExperimentalAutoShardDataset(scope *Scope, input_dataset tf.Output, num_workers tf.Output, index tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
if scope.Err() != nil {
return
}
attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
opspec := tf.OpSpec{
Type: "ExperimentalAutoShardDataset",
Input: []tf.Input{
input_dataset, num_workers, index,
},
Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
// RandomStandardNormalAttr is an optional argument to RandomStandardNormal.
type RandomStandardNormalAttr func(optionalAttr)
@ -38809,18 +38856,3 @@ func OptionalNone(scope *Scope) (optional tf.Output) {
op := scope.AddOperation(opspec)
return op.Output(0)
}
// Returns true if and only if the given Optional variant has a value.
func OptionalHasValue(scope *Scope, optional tf.Output) (has_value tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
Type: "OptionalHasValue",
Input: []tf.Input{
optional,
},
}
op := scope.AddOperation(opspec)
return op.Output(0)
}

View File

@ -33,13 +33,12 @@ tf_cuda_library(
"//tensorflow:android": [],
"//conditions:default": ["."],
}),
deps = [
"//tensorflow/c:c_api",
] + select({
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib",
],
"//conditions:default": [
"//tensorflow/c:c_api",
"//tensorflow/core:all_kernels",
"//tensorflow/core:direct_session",
"//tensorflow/core:ops",

View File

@ -33,7 +33,7 @@ GEMMLOWP_URL="https://github.com/google/gemmlowp/archive/719139ce755a0f31cbf1c37
FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/1f5eae5d6a135ff6811724f6c57f911d1f46bb15.tar.gz"
CMSIS_URL="https://github.com/ARM-software/CMSIS_5/archive/5.4.0.zip"
STM32_BARE_LIB_URL="https://github.com/google/stm32_bare_lib/archive/c07d611fb0af58450c5a3e0ab4d52b47f99bc82d.zip"
10_LIB_URL="https://github.com/sifive/freedom-e-sdk/archive/baeeb8fd497a99b3c141d7494309ec2e64f19bdf.zip"
SIFIVE_FE310_LIB_URL="https://github.com/sifive/freedom-e-sdk/archive/baeeb8fd497a99b3c141d7494309ec2e64f19bdf.zip"
RISCV_TOOLCHAIN_URL="https://static.dev.sifive.com/dev-tools/riscv64-unknown-elf-gcc-20181030-x86_64-linux-ubuntu14.tar.gz"
AM_SDK_URL="http://s3.asia.ambiqmicro.com/downloads/AmbiqSuite-Rel2.0.0.zip"
AP3_URL="https://github.com/AmbiqMicro/TFLiteMicro_Apollo3/archive/dfbcef9a57276c087d95aab7cb234f1d4c9eaaba.zip"
@ -101,7 +101,7 @@ patch_am_sdk() {
# Workaround for bug in 2.0.0 SDK, remove once that's fixed.
sed -i -e 's/#ifndef AM_HAL_GPIO_H/#ifdef __cplusplus\nextern "C" {\n#endif\n#ifndef AM_HAL_GPIO_H/g' ${am_dir}/mcu/apollo3/hal/am_hal_gpio.h
echo "Finished preparing Apollo3 files"
}
@ -109,8 +109,8 @@ patch_kissfft() {
sed -i -E "s@#ifdef FIXED_POINT@// Patched automatically by download_dependencies.sh so default is 16 bit.\n#ifndef FIXED_POINT\n#define FIXED_POINT (16)\n#endif\n// End patch.\n\n#ifdef FIXED_POINT@g" tensorflow/lite/experimental/micro/tools/make/downloads/kissfft/kiss_fft.h
sed -i -E "s@#define KISS_FFT_MALLOC malloc@#define KISS_FFT_MALLOC(X) (void*)(0) /* Patched. */@g" tensorflow/lite/experimental/micro/tools/make/downloads/kissfft/kiss_fft.h
sed -i -E "s@#define KISS_FFT_FREE free@#define KISS_FFT_FREE(X) /* Patched. */@g" tensorflow/lite/experimental/micro/tools/make/downloads/kissfft/kiss_fft.h
sed -i -E "s@(fprintf.*\);)@/* \1 */@g" tensorflow/lite/experimental/micro/tools/make/downloads/kissfft/tools/kiss_fftr.c
sed -i -E "s@(exit.*\);)@return; /* \1 */@g" tensorflow/lite/experimental/micro/tools/make/downloads/kissfft/tools/kiss_fftr.c
sed -ir -E "s@(fprintf.*\);)@/* \1 */@g" tensorflow/lite/experimental/micro/tools/make/downloads/kissfft/tools/kiss_fftr.c
sed -ir -E "s@(exit.*\);)@return; /* \1 */@g" tensorflow/lite/experimental/micro/tools/make/downloads/kissfft/tools/kiss_fftr.c
echo "Finished patching kissfft"
}

View File

@ -1,4 +1,4 @@
# TensorFlow Lite Objective-C API.
# TensorFlow Lite for Objective-C
package(default_visibility = ["//visibility:private"])
@ -83,11 +83,13 @@ ios_unit_test(
"notsan",
"nomsan",
],
deps = [":TensorFlowLiteTestsLib"],
deps = [
":TestsLib",
],
)
objc_library(
name = "TensorFlowLiteTestsLib",
name = "TestsLib",
testonly = 1,
srcs = glob([
"tests/*.m",

View File

@ -1,4 +1,4 @@
# TensorFlow Lite Objective-C Library
# TensorFlow Lite for Objective-C
[TensorFlow Lite](https://www.tensorflow.org/lite/) is TensorFlow's lightweight
solution for Objective-C developers. It enables low-latency inference of
@ -44,9 +44,11 @@ bazel test tensorflow/lite/experimental/objc:TensorFlowLiteTests
### Tulsi
Open the `TensorFlowLiteObjc.tulsiproj` using the Tulsi application on Mac or by
running the following command in Terminal from the root source directory:
Open the `TensorFlowLite.tulsiproj` using the
[TulsiApp](https://github.com/bazelbuild/tulsi) or by running the
[`generate_xcodeproj.sh`](https://github.com/bazelbuild/tulsi/blob/master/src/tools/generate_xcodeproj.sh)
script from the root `tensorflow` directory:
```shell
generate_xcodeproj.sh --genconfig tensorflow/lite/experimental/objc/TensorFlowLiteObjc.tulsiproj:TensorFlowLiteObjC --outputfolder ~/path/to/xcodeproj
generate_xcodeproj.sh --genconfig tensorflow/lite/experimental/objc/TensorFlowLite.tulsiproj:TensorFlowLite --outputfolder ~/path/to/generated/TensorFlowLite.xcodeproj
```

View File

@ -15,7 +15,7 @@
"//tensorflow/lite/experimental/objc:TensorFlowLite",
"//tensorflow/lite/experimental/objc:TensorFlowLiteTests",
],
"projectName" : "TensorFlowLiteObjC",
"projectName" : "TensorFlowLite",
"optionSet" : {
"LaunchActionPreActionScript" : {
"p" : "$(inherited)"

View File

@ -9,7 +9,7 @@
},
}
},
"projectName" : "TensorFlowLiteObjC",
"projectName" : "TensorFlowLite",
"packages" : [
"tensorflow/lite/experimental/objc"
],

View File

@ -1,4 +1,4 @@
# TensorFlow Lite for Swift.
# TensorFlow Lite for Swift
package(default_visibility = ["//visibility:private"])
@ -11,10 +11,6 @@ load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library")
MINIMUM_OS_VERSION = "9.0"
SWIFT_COPTS = [
"-wmo",
]
# Default tags for filtering targets. Targets in this file are restricted to Apple platforms.
DEFAULT_TAGS = [
"apple",
@ -24,7 +20,6 @@ DEFAULT_TAGS = [
swift_library(
name = "TensorFlowLite",
srcs = glob(["Sources/*.swift"]),
copts = SWIFT_COPTS,
module_name = "TensorFlowLite",
tags = DEFAULT_TAGS,
deps = [
@ -42,31 +37,11 @@ ios_unit_test(
"nomsan",
"notsan",
],
deps = [":TensorFlowLiteTestsLib"],
)
swift_library(
name = "TensorFlowLiteTestsLib",
testonly = 1,
srcs = glob(["Tests/*.swift"]),
copts = SWIFT_COPTS,
tags = DEFAULT_TAGS,
deps = [
":TensorFlowLite",
":TestResources",
":TestsLib",
],
)
objc_library(
name = "TestResources",
data = [
"//tensorflow/lite:testdata/add.bin",
"//tensorflow/lite:testdata/add_quantized.bin",
"//tensorflow/lite:testdata/multi_add.bin",
],
tags = DEFAULT_TAGS,
)
ios_application(
name = "TensorFlowLiteApp",
app_icons = glob(["TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Assets.xcassets/AppIcon.appiconset/**"]),
@ -82,25 +57,50 @@ ios_application(
"CoreGraphics",
],
tags = DEFAULT_TAGS,
deps = [":TensorFlowLiteAppLib"],
deps = [
":AppLib",
],
)
swift_library(
name = "TensorFlowLiteAppLib",
name = "TestsLib",
testonly = 1,
srcs = glob(["Tests/*.swift"]),
tags = DEFAULT_TAGS,
deps = [
":Resources",
":TensorFlowLite",
],
)
swift_library(
name = "AppLib",
srcs = glob(["TestApps/TensorFlowLiteApp/TensorFlowLiteApp/*.swift"]),
module_name = "TensorFlowLiteAppLib",
tags = DEFAULT_TAGS,
deps = [
":AppResources",
":TensorFlowLite",
":TensorFlowLiteAppResources",
],
)
objc_library(
name = "TensorFlowLiteAppResources",
name = "Resources",
data = [
"//tensorflow/lite:testdata/add.bin",
"//tensorflow/lite:testdata/add_quantized.bin",
"//tensorflow/lite:testdata/multi_add.bin",
],
tags = DEFAULT_TAGS,
)
objc_library(
name = "AppResources",
data = glob([
"TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Base.lproj/*.storyboard",
]),
tags = DEFAULT_TAGS,
deps = [":TestResources"],
deps = [
":Resources",
],
)

View File

@ -52,10 +52,12 @@ Note that `--swiftcopt=-enable-testing` is required for optimized builds (`-c op
### Tulsi
Open the `TensorFlowLite.tulsiproj` using the [TulsiApp](https://github.com/bazelbuild/tulsi) or by
running the [`generate_xcodeproj.sh`](https://github.com/bazelbuild/tulsi/blob/master/src/tools/generate_xcodeproj.sh)
script:
Open the `TensorFlowLite.tulsiproj` using the
[TulsiApp](https://github.com/bazelbuild/tulsi)
or by running the
[`generate_xcodeproj.sh`](https://github.com/bazelbuild/tulsi/blob/master/src/tools/generate_xcodeproj.sh)
script from the root `tensorflow` directory:
```shell
generate_xcodeproj.sh --genconfig tensorflow/lite/swift/TensorFlowLite.tulsiproj:TensorFlowLite --outputfolder ~/path/to/generated/TensorFlowLite.xcodeproj
generate_xcodeproj.sh --genconfig tensorflow/lite/experimental/swift/TensorFlowLite.tulsiproj:TensorFlowLite --outputfolder ~/path/to/generated/TensorFlowLite.xcodeproj
```

View File

@ -12,114 +12,108 @@ upper_tabs:
lower_tabs:
# Subsite tabs
other:
- name: Guide
- name: "Guide"
contents:
- title: Overview
path: /lite/overview
- title: Developer guide
path: /lite/devguide
- title: Android demo app
path: /lite/demo_android
- title: iOS demo app
path: /lite/demo_ios
- break: true
- title: TensorFlow Lite inference
path: /lite/inference
- title: Custom operators
path: /lite/custom_operators
- title: TensorFlow Lite ops versioning
path: /lite/ops_versioning
- title: TensorFlow Lite compatibility guide
path: /lite/tf_ops_compatibility
- title: TensorFlow Lite for iOS
path: /lite/ios
- title: TensorFlow Lite for Raspberry Pi
path: /lite/rpi
- title: "TensorFlow Lite guide"
path: /lite/guide
- heading: TF Lite converter
- title: Overview
- heading: "Get started"
- title: "Overview"
path: /lite/guide/get_started
- title: "Android quickstart"
path: /lite/guide/android
- title: "iOS quickstart"
path: /lite/guide/ios
- title: "TensorFlow Lite FAQ"
path: /lite/guide/faq
- heading: "Convert a model"
- title: "TensorFlow Lite converter"
path: /lite/convert/
- title: Python API guide
path: /lite/convert/python_api
- title: Command line examples
- title: "Command line examples"
path: /lite/convert/cmdline_examples
- title: Command line reference
- title: "Command line reference"
path: /lite/convert/cmdline_reference
- title: "Python API"
path: /lite/convert/python_api
- heading: Performance
- title: Best practices
- heading: "Inference"
- title: "Overview"
path: /lite/guide/inference
- title: "Custom operators"
path: /lite/guide/ops_custom
- title: "Operator versions"
path: /lite/guide/ops_version
- title: "Operator compatibility"
path: /lite/guide/ops_compatibility
- title: "Select operators from TensorFlow"
path: /lite/guide/ops_select
- title: "List of hosted models"
path: /lite/guide/hosted_models
- heading: "Performance"
- title: "Best practices"
path: /lite/performance/best_practices
- title: Benchmarks
- title: "Benchmarks"
path: /lite/performance/benchmarks
- title: Model optimization
- title: "Model optimization"
path: /lite/performance/model_optimization
- title: Post-training quantization
- title: "Post-training quantization"
path: /lite/performance/post_training_quantization
- title: Post-training quantization example
- title: "Post-training quantization example"
path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tutorials/post_training_quant.ipynb
status: external
- title: GPU delegate
- title: "GPU delegate"
path: /lite/performance/gpu
- title: Advanced GPU
- title: "Advanced GPU"
path: /lite/performance/gpu_advanced
- title: TF Mobile
style: accordion
status: deprecated
section:
- title: Overview
path: /lite/tfmobile/
- title: Building TensorFlow on Android
path: /lite/tfmobile/android_build
- title: Building TensorFlow on iOS
path: /lite/tfmobile/ios_build
- title: Integrating TensorFlow libraries
path: /lite/tfmobile/linking_libs
- title: Preparing models for mobile deployment
path: /lite/tfmobile/prepare_models
- title: Optimizing for mobile
path: /lite/tfmobile/optimizing
- heading: "Build TensorFlow Lite"
- title: "Build for iOS"
path: /lite/guide/build_ios
- title: "Build for ARM64"
path: /lite/guide/build_arm64
- title: "Build for Raspberry Pi"
path: /lite/guide/build_rpi
- name: Examples
- name: "Examples"
contents:
- title: Examples
- title: "Examples"
path: /lite/examples
- name: Models
- name: "Models"
contents:
- title: Overview
- title: "Overview"
path: /lite/models/
- title: Hosted models
path: /lite/models/hosted
- title: Image classification
- title: "Image classification"
section:
- title: Overview
- title: "Overview"
path: /lite/models/image_classification/overview
- title: Android
- title: "Android"
path: /lite/models/image_classification/android
- title: iOS
- title: "iOS"
path: /lite/models/image_classification/ios
- title: Object detection
- title: "Object detection"
section:
- title: Overview
- title: "Overview"
path: /lite/models/object_detection/overview
- title: Pose estimation
- title: "Pose estimation"
section:
- title: Overview
- title: "Overview"
path: /lite/models/pose_estimation/overview
- title: Segmentation
- title: "Segmentation"
section:
- title: Overview
- title: "Overview"
path: /lite/models/segmentation/overview
- title: Smart reply
- title: "Smart reply"
section:
- title: Overview
- title: "Overview"
path: /lite/models/smart_reply/overview
- name: API
- name: "API"
skip_translation: true
contents:
- title: API
- title: "API"
path: /api_docs/python/tf/lite
- include: /api_docs/_upper_tabs_api.yaml

View File

@ -1,4 +1,4 @@
# Converter command-line examples
# Converter command line examples
This page shows how to use the TensorFlow Lite Converter in the command line.

View File

@ -1,4 +1,4 @@
# Converter command-line reference
# Converter command line reference
This page is a complete reference of command-line flags used by the TensorFlow
Lite Converter's command line starting from TensorFlow 1.9 up until the most

View File

@ -1,4 +1,4 @@
# TensorFlow Lite Converter
# TensorFlow Lite converter
TensorFlow Lite uses the optimized
[FlatBuffer](https://google.github.io/flatbuffers/) format to represent graphs.

View File

@ -1,5 +1,4 @@
# Android Demo App
# Android quickstart
An example Android application using TensorFlow Lite is available
[on GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/java/demo).

View File

@ -1,4 +1,4 @@
# TensorFlow Lite for generic ARM64 boards
# Build TensorFlow Lite for ARM64 boards
## Cross compiling

View File

@ -1,4 +1,4 @@
# TensorFlow Lite for Raspberry Pi
# Build TensorFlow Lite for Raspberry Pi
## Cross compiling

View File

@ -1,4 +1,4 @@
# TF Lite Developer Guide
# Get started with TensorFlow Lite
Using a TensorFlow Lite model in your mobile app requires multiple
considerations: you must choose a pre-trained or custom model, convert the model

View File

@ -1,5 +1,5 @@
# Introduction to TensorFlow Lite
# TensorFlow Lite guide
TensorFlow Lite is TensorFlow's lightweight solution for mobile and embedded
devices. It enables on-device machine learning inference with low latency and a

View File

@ -1,4 +1,4 @@
# TensorFlow Lite Inference
# TensorFlow Lite inference
[TOC]

View File

@ -1,4 +1,4 @@
# iOS Demo App
# iOS quickstart
This tutorial provides a simple iOS mobile application to classify images using
the iOS device camera. In this tutorial, you will download the demo application

View File

@ -1,4 +1,4 @@
# TensorFlow Lite & TensorFlow Compatibility Guide
# TensorFlow Lite and TensorFlow operator compatibility
TensorFlow Lite supports a number of TensorFlow operations used in common
inference models. As they are processed by the TensorFlow Lite Optimizing

View File

@ -1,4 +1,6 @@
# [Experimental] Using TensorFlow Lite with select TensorFlow ops
# Select TensorFlow operators to use in TensorFlow Lite
Caution: This feature is experimental.
The TensorFlow Lite builtin op library has grown rapidly, and will continue to
grow, but there remains a long tail of TensorFlow ops that are not yet natively
@ -196,9 +198,7 @@ Python support is actively under development.
When using a mixture of both builtin and select TensorFlow ops, all of the same
TensorFlow Lite optimizations and optimized builtin kernels will be available
and usable with the converted model. For the TensorFlow ops, performance should
generally be comparable to that of
[TensorFlow Mobile](https://www.tensorflow.org/lite/tfmobile/).
and usable with the converted model.
The following table describes the average time taken to run inference on
MobileNet on a Pixel 2. The listed times are an average of 100 runs. These

View File

@ -1,5 +1,4 @@
# TensorFlow Lite Ops Versioning
# TensorFlow Lite operator versions
This document describes TensorFlow Lite's op versioning schema. Op
versioning enables developers to add new functionalities and parameters into

View File

@ -1,5 +1,4 @@
# Performance
# Performance benchmarks
This document lists TensorFlow Lite performance benchmarks when running well
known models on some Android and iOS devices.

View File

@ -1,4 +1,4 @@
# TensorFlow Lite GPU Delegate Tutorial
# TensorFlow Lite GPU delegate
[TensorFlow Lite](https://www.tensorflow.org/lite) supports several hardware
accelerators. This document describes how to preview the experimental GPU backend using the

View File

@ -1,195 +0,0 @@
# Building TensorFlow on Android
Warning: We expect to deprecate TensorFlow Mobile in early 2019
<div class="caution">
<p>
<a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
working hard to close the feature gap between TensorFlow Mobile and
TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
will give ample notice to our users when we get to that point and will
provide help and support to ensure easy migrations.
</p>
<p>
In the meantime, please use TensorFlow Lite. If you have a feature request,
such as a missing op, please post to our <a
href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
</p>
</div>
To get you started working with TensorFlow on Android, we'll walk through two
ways to build our TensorFlow mobile demos and deploying them on an Android
device. The first is Android Studio, which lets you build and deploy in an
IDE. The second is building with Bazel and deploying with ADB on the command
line.
Why choose one or the other of these methods?
The simplest way to use TensorFlow on Android is to use Android Studio. If you
aren't planning to customize your TensorFlow build at all, or if you want to use
Android Studio's editor and other features to build an app and just want to add
TensorFlow to it, we recommend using Android Studio.
If you are using custom ops, or have some other reason to build TensorFlow from
scratch, scroll down and see our instructions
for [building the demo with Bazel](#build_the_demo_using_bazel).
## Build the demo using Android Studio
**Prerequisites**
If you haven't already, do the following two things:
- Install [Android Studio](https://developer.android.com/studio/index.html),
following the instructions on their website.
- Clone the TensorFlow repository from GitHub:
git clone https://github.com/tensorflow/tensorflow
**Building**
1. Open Android Studio, and from the Welcome screen, select **Open an existing
Android Studio project**.
2. From the **Open File or Project** window that appears, navigate to and select
the `tensorflow/examples/android` directory from wherever you cloned the
TensorFlow GitHub repo. Click OK.
If it asks you to do a Gradle Sync, click OK.
You may also need to install various platforms and tools, if you get
errors like "Failed to find target with hash string 'android-23' and similar.
3. Open the `build.gradle` file (you can go to **1:Project** in the side panel
and find it under the **Gradle Scripts** zippy under **Android**). Look for
the `nativeBuildSystem` variable and set it to `none` if it isn't already:
// set to 'bazel', 'cmake', 'makefile', 'none'
def nativeBuildSystem = 'none'
4. Click the *Run* button (the green arrow) or select *Run > Run 'android'* from the
top menu. You may need to rebuild the project using *Build > Rebuild Project*.
If it asks you to use Instant Run, click **Proceed Without Instant Run**.
Also, you need to have an Android device plugged in with developer options
enabled at this
point. See [here](https://developer.android.com/studio/run/device.html) for
more details on setting up developer devices.
This installs three apps on your phone that are all part of the TensorFlow
Demo. See [Android Sample Apps](#android_sample_apps) for more information about
them.
## Adding TensorFlow to your apps using Android Studio
To add TensorFlow to your own apps on Android, the simplest way is to add the
following lines to your Gradle build file:
allprojects {
repositories {
jcenter()
}
}
dependencies {
implementation 'org.tensorflow:tensorflow-android:+'
}
This automatically downloads the latest stable version of TensorFlow as an AAR
and installs it in your project.
## Build the demo using Bazel
Another way to use TensorFlow on Android is to build an APK
using [Bazel](https://bazel.build/) and load it onto your device
using [ADB](https://developer.android.com/studio/command-line/adb.html). This
requires some knowledge of build systems and Android developer tools, but we'll
guide you through the basics here.
- First, follow our instructions for
<a href="http://www.tensorflow.org/install/source">installing from sources</a>.
This will also guide you through installing Bazel and cloning the
TensorFlow code.
- Download the Android [SDK](https://developer.android.com/studio/index.html)
and [NDK](https://developer.android.com/ndk/downloads/index.html) if you do
not already have them. You need at least version 12b of the NDK, and 23 of the
SDK.
- In your copy of the TensorFlow source, update the
[WORKSPACE](https://github.com/tensorflow/tensorflow/blob/master/WORKSPACE)
file with the location of your SDK and NDK, where it says &lt;PATH_TO_NDK&gt;
and &lt;PATH_TO_SDK&gt;.
- Run Bazel to build the demo APK:
bazel build -c opt //tensorflow/examples/android:tensorflow_demo
- Use [ADB](https://developer.android.com/studio/command-line/adb.html#move) to
install the APK onto your device:
adb install -r bazel-bin/tensorflow/examples/android/tensorflow_demo.apk
Note: In general when compiling for Android with Bazel you need
`--config=android` on the Bazel command line, though in this case this
particular example is Android-only, so you don't need it here.
This installs three apps on your phone that are all part of the TensorFlow
Demo. See [Android Sample Apps](#android_sample_apps) for more information about
them.
## Android Sample Apps
The
[Android example code](https://www.tensorflow.org/code/tensorflow/examples/android/) is
a single project that builds and installs three sample apps which all use the
same underlying code. The sample apps all take video input from a phone's
camera:
- **TF Classify** uses the Inception v3 model to label the objects it's pointed
at with classes from Imagenet. There are only 1,000 categories in Imagenet,
which misses most everyday objects and includes many things you're unlikely to
encounter often in real life, so the results can often be quite amusing. For
example there's no person category, so instead it will often guess things it
does know that are often associated with pictures of people, like a seat belt
or an oxygen mask. If you do want to customize this example to recognize
objects you care about, you can use
the
[TensorFlow for Poets codelab](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/index.html#0) as
an example for how to train a model based on your own data.
- **TF Detect** uses a multibox model to try to draw bounding boxes around the
locations of people in the camera. These boxes are annotated with the
confidence for each detection result. Results will not be perfect, as this
kind of object detection is still an active research topic. The demo also
includes optical tracking for when objects move between frames, which runs
more frequently than the TensorFlow inference. This improves the user
experience since the apparent frame rate is faster, but it also gives the
ability to estimate which boxes refer to the same object between frames, which
is important for counting objects over time.
- **TF Stylize** implements a real-time style transfer algorithm on the camera
feed. You can select which styles to use and mix between them using the
palette at the bottom of the screen, and also switch out the resolution of the
processing to go higher or lower rez.
When you build and install the demo, you'll see three app icons on your phone,
one for each of the demos. Tapping on them should open up the app and let you
explore what they do. You can enable profiling statistics on-screen by tapping
the volume up button while they're running.
### Android Inference Library
Because Android apps need to be written in Java, and core TensorFlow is in C++,
TensorFlow has a JNI library to interface between the two. Its interface is aimed
only at inference, so it provides the ability to load a graph, set up inputs,
and run the model to calculate particular outputs. You can see the full
documentation for the minimal set of methods in
[TensorFlowInferenceInterface.java](https://www.tensorflow.org/code/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java)
The demos applications use this interface, so theyre a good place to look for
example usage. You can download prebuilt binary jars
at
[ci.tensorflow.org](https://ci.tensorflow.org/view/Nightly/job/nightly-android/).

View File

@ -1,298 +0,0 @@
# Overview
Warning: We expect to deprecate TensorFlow Mobile in early 2019
<div class="caution">
<p>
<a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
working hard to close the feature gap between TensorFlow Mobile and
TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
will give ample notice to our users when we get to that point and will
provide help and support to ensure easy migrations.
</p>
<p>
In the meantime, please use TensorFlow Lite. If you have a feature request,
such as a missing op, please post to our <a
href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
</p>
</div>
TensorFlow was designed to be a good deep learning solution for mobile
platforms. Currently we have two solutions for deploying machine learning
applications on mobile and embedded devices: TensorFlow for Mobile and
<a href="../../lite">TensorFlow Lite</a>.
## TensorFlow Lite versus TensorFlow Mobile
Here are a few of the differences between the two:
- TensorFlow Lite is an evolution of TensorFlow Mobile. In most cases, apps
developed with TensorFlow Lite will have a smaller binary size, fewer
dependencies, and better performance.
- TensorFlow Lite is in developer preview, so not all use cases are covered yet.
We expect you to use TensorFlow Mobile to cover production cases.
- TensorFlow Lite supports only a limited set of operators, so not all models
will work on it by default. TensorFlow for Mobile has a fuller set of
supported functionality.
TensorFlow Lite provides better performance and a small binary size on mobile
platforms as well as the ability to leverage hardware acceleration if available
on their platforms. In addition, it has many fewer dependencies so it can be
built and hosted on simpler, more constrained device scenarios. TensorFlow Lite
also allows targeting accelerators through the [Neural Networks
API](https://developer.android.com/ndk/guides/neuralnetworks/index.html).
TensorFlow Lite currently has coverage for a limited set of operators. While
TensorFlow for Mobile supports only a constrained set of ops by default, in
principle if you use an arbitrary operator in TensorFlow, it can be customized
to build that kernel. Thus use cases which are not currently supported by
TensorFlow Lite should continue to use TensorFlow for Mobile. As TensorFlow Lite
evolves, it will gain additional operators, and the decision will be easier to
make.
## Introduction to TensorFlow Mobile
TensorFlow was designed from the ground up to be a good deep learning solution
for mobile platforms like Android and iOS. This mobile guide should help you
understand how machine learning can work on mobile platforms and how to
integrate TensorFlow into your mobile apps effectively and efficiently.
## About this Guide
This guide is aimed at developers who have a TensorFlow model that's
successfully working in a desktop environment, who want to integrate it into
a mobile application, and cannot use TensorFlow Lite. Here are the
main challenges you'll face during that process:
- Understanding how to use Tensorflow for mobile.
- Building TensorFlow for your platform.
- Integrating the TensorFlow library into your application.
- Preparing your model file for mobile deployment.
- Optimizing for latency, RAM usage, model file size, and binary size.
## Common use cases for mobile machine learning
**Why run TensorFlow on mobile?**
Traditionally, deep learning has been associated with data centers and giant
clusters of high-powered GPU machines. However, it can be very expensive and
time-consuming to send all of the data a device has access to across a network
connection. Running on mobile makes it possible to deliver very interactive
applications in a way that's not possible when you have to wait for a network
round trip.
Here are some common use cases for on-device deep learning:
### Speech Recognition
There are a lot of interesting applications that can be built with a
speech-driven interface, and many of these require on-device processing. Most of
the time a user isn't giving commands, and so streaming audio continuously to a
remote server would be a waste of bandwidth, since it would mostly be silence or
background noises. To solve this problem it's common to have a small neural
network running on-device
[listening out for a particular keyword](../tutorials/sequences/audio_recognition).
Once that keyword has been spotted, the rest of the
conversation can be transmitted over to the server for further processing if
more computing power is needed.
### Image Recognition
It can be very useful for a mobile app to be able to make sense of a camera
image. If your users are taking photos, recognizing what's in them can help your
camera apps apply appropriate filters, or label the photos so they're easily
findable. It's important for embedded applications too, since you can use image
sensors to detect all sorts of interesting conditions, whether it's spotting
endangered animals in the wild
or
[reporting how late your train is running](https://svds.com/tensorflow-image-recognition-raspberry-pi/).
TensorFlow comes with several examples of recognizing the types of objects
inside images along with a variety of different pre-trained models, and they can
all be run on mobile devices. You can try out
our
[Tensorflow for Poets](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/index.html#0) and
[Tensorflow for Poets 2: Optimize for Mobile](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2/index.html#0) codelabs to
see how to take a pretrained model and run some very fast and lightweight
training to teach it to recognize specific objects, and then optimize it to
run on mobile.
### Object Localization
Sometimes it's important to know where objects are in an image as well as what
they are. There are lots of augmented reality use cases that could benefit a
mobile app, such as guiding users to the right component when offering them
help fixing their wireless network or providing informative overlays on top of
landscape features. Embedded applications often need to count objects that are
passing by them, whether it's pests in a field of crops, or people, cars and
bikes going past a street lamp.
TensorFlow offers a pretrained model for drawing bounding boxes around people
detected in images, together with tracking code to follow them over time. The
tracking is especially important for applications where you're trying to count
how many objects are present over time, since it gives you a good idea when a
new object enters or leaves the scene. We have some sample code for this
available for Android [on
GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android),
and also a [more general object detection
model](https://github.com/tensorflow/models/tree/master/research/object_detection/README.md)
available as well.
### Gesture Recognition
It can be useful to be able to control applications with hand or other
gestures, either recognized from images or through analyzing accelerometer
sensor data. Creating those models is beyond the scope of this guide, but
TensorFlow is an effective way of deploying them.
### Optical Character Recognition
Google Translate's live camera view is a great example of how effective
interactive on-device detection of text can be.
<div class="video-wrapper">
<iframe class="devsite-embedded-youtube-video" data-video-id="06olHmcJjS0"
data-autohide="1" data-showinfo="0" frameborder="0" allowfullscreen>
</iframe>
</div>
There are multiple steps involved in recognizing text in images. You first have
to identify the areas where the text is present, which is a variation on the
object localization problem, and can be solved with similar techniques. Once you
have an area of text, you then need to interpret it as letters, and then use a
language model to help guess what words they represent. The simplest way to
estimate what letters are present is to segment the line of text into individual
letters, and then apply a simple neural network to the bounding box of each. You
can get good results with the kind of models used for MNIST, which you can find
in TensorFlow's tutorials, though you may want a higher-resolution input. A
more advanced alternative is to use an LSTM model to process a whole line of
text at once, with the model itself handling the segmentation into different
characters.
### Translation
Translating from one language to another quickly and accurately, even if you
don't have a network connection, is an important use case. Deep networks are
very effective at this sort of task, and you can find descriptions of a lot of
different models in the literature. Often these are sequence-to-sequence
recurrent models where you're able to run a single graph to do the whole
translation, without needing to run separate parsing stages.
### Text Classification
If you want to suggest relevant prompts to users based on what they're typing or
reading, it can be very useful to understand the meaning of the text. This is
where text classification comes in. Text classification is an umbrella term
that covers everything from sentiment analysis to topic discovery. You're likely
to have your own categories or labels that you want to apply, so the best place
to start is with an example
like
[Skip-Thoughts](https://github.com/tensorflow/models/tree/master/research/skip_thoughts/),
and then train on your own examples.
### Voice Synthesis
A synthesized voice can be a great way of giving users feedback or aiding
accessibility, and recent advances such as
[WaveNet](https://deepmind.com/blog/wavenet-generative-model-raw-audio/) show
that deep learning can offer very natural-sounding speech.
## Mobile machine learning and the cloud
These examples of use cases give an idea of how on-device networks can
complement cloud services. Cloud has a great deal of computing power in a
controlled environment, but running on devices can offer higher interactivity.
In situations where the cloud is unavailable, or your cloud capacity is limited,
you can provide an offline experience, or reduce cloud workload by processing
easy cases on device.
Doing on-device computation can also signal when it's time to switch to working
on the cloud. A good example of this is hotword detection in speech. Since
devices are able to constantly listen out for the keywords, this then triggers a
lot of traffic to cloud-based speech recognition once one is recognized. Without
the on-device component, the whole application wouldn't be feasible, and this
pattern exists across several other applications as well. Recognizing that some
sensor input is interesting enough for further processing makes a lot of
interesting products possible.
## What hardware and software should you have?
TensorFlow runs on Ubuntu Linux, Windows 10, and OS X. For a list of all
supported operating systems and instructions to install TensorFlow, see
<a href="https://www.tensorflow.org/install">Installing TensorFlow</a>.
Note that some of the sample code we provide for mobile TensorFlow requires you
to compile TensorFlow from source, so you'll need more than just `pip install`
to work through all the sample code.
To try out the mobile examples, you'll need a device set up for development,
using
either [Android Studio](https://developer.android.com/studio/install.html),
or [XCode](https://developer.apple.com/xcode/) if you're developing for iOS.
## What should you do before you get started?
Before thinking about how to get your solution on mobile:
1. Determine whether your problem is solvable by mobile machine learning
2. Create a labelled dataset to define your problem
3. Pick an effective model for the problem
We'll discuss these in more detail below.
### Is your problem solvable by mobile machine learning?
Once you have an idea of the problem you want to solve, you need to make a plan
of how to build your solution. The most important first step is making sure that
your problem is actually solvable, and the best way to do that is to mock it up
using humans in the loop.
For example, if you want to drive a robot toy car using voice commands, try
recording some audio from the device and listen back to it to see if you can
make sense of what's being said. Often you'll find there are problems in the
capture process, such as the motor drowning out speech or not being able to hear
at a distance, and you should tackle these problems before investing in the
modeling process.
Another example would be giving photos taken from your app to people to see if
they can classify what's in them, in the way you're looking for. If they can't do
that (for example, trying to estimate calories in food from photos may be
impossible because all white soups look the same), then you'll need to redesign
your experience to cope with that. A good rule of thumb is that if a human can't
handle the task, then it will be difficult to train a computer to do better.
### Create a labelled dataset
After you've solved any fundamental issues with your use case, you need to
create a labeled dataset to define what problem you're trying to solve. This
step is extremely important, even more than picking which model to use. You want it
to be as representative as possible of your actual use case, since the model
will only be effective at the task you teach it. It's also worth investing in
tools to make labeling the data as efficient and accurate as possible. For
example, if you're able to switch from having to click a button on a web
interface to simple keyboard shortcuts, you may be able to speed up the
generation process a lot. You should also start by doing the initial labeling
yourself, so you can learn about the difficulties and likely errors, and
possibly change your labeling or data capture process to avoid them. Once you
and your team are able to consistently label examples (that is, once you
generally agree on the same labels for most examples), you can then try to
capture your knowledge in a manual and teach external raters how to run the same
process.
### Pick an effective model
The next step is to pick an effective model to use. You might be able to avoid
training a model from scratch if someone else has already implemented a model
similar to what you need; we have a repository of models implemented in
TensorFlow [on GitHub](https://github.com/tensorflow/models) that you can look
through. Lean towards the simplest model you can find, and try to get started as
soon as you have even a small amount of labelled data, since you'll get the best
results when you're able to iterate quickly. The shorter the time it takes to
try training a model and running it in its real application, the better overall
results you'll see. It's common for an algorithm to get great training accuracy
numbers but then fail to be useful within a real application because there's a
mismatch between the dataset and real usage. Prototype end-to-end usage as soon
as possible to create a consistent user experience.
View File
@ -1,124 +0,0 @@
# Building TensorFlow on iOS
Warning: We expect to deprecate TensorFlow Mobile in early 2019
<div class="caution">
<p>
<a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
working hard to close the feature gap between TensorFlow Mobile and
TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
will give ample notice to our users when we get to that point and will
provide help and support to ensure easy migrations.
</p>
<p>
In the meantime, please use TensorFlow Lite. If you have a feature request,
such as a missing op, please post to our <a
href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
</p>
</div>
## Using CocoaPods
The simplest way to get started with TensorFlow on iOS is using the CocoaPods
package management system. You can add the `TensorFlow-experimental` pod to your
Podfile, which installs a universal binary framework. This makes it easy to get
started but has the disadvantage of being hard to customize, which is important
in case you want to shrink your binary size. If you do need the ability to
customize your libraries, see later sections on how to do that.
## Creating your own app
If you'd like to add TensorFlow capabilities to your own app, do the following:
- Create your own app or load your already-created app in XCode.
- Add a file named Podfile at the project root directory with the following content:
target 'YourProjectName'
pod 'TensorFlow-experimental'
- Run `pod install` to download and install the `TensorFlow-experimental` pod.
- Open `YourProjectName.xcworkspace` and add your code.
- In your app's **Build Settings**, make sure to add `$(inherited)` to the
**Other Linker Flags**, and **Header Search Paths** sections.
## Running the Samples
You'll need Xcode 7.3 or later to run our iOS samples.
There are currently three examples: simple, benchmark, and camera. For now, you
can download the sample code by cloning the main tensorflow repository (we are
planning to make the samples available as a separate repository later).
From the root of the tensorflow folder, download [Inception
v1](https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip),
and extract the label and graph files into the data folders inside the simple,
benchmark, and camera examples using these steps:
mkdir -p ~/graphs
curl -o ~/graphs/inception5h.zip \
https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip \
&& unzip ~/graphs/inception5h.zip -d ~/graphs/inception5h
cp ~/graphs/inception5h/* tensorflow/examples/ios/benchmark/data/
cp ~/graphs/inception5h/* tensorflow/examples/ios/camera/data/
cp ~/graphs/inception5h/* tensorflow/examples/ios/simple/data/
Change into one of the sample directories, download the
[Tensorflow-experimental](https://cocoapods.org/pods/TensorFlow-experimental)
pod, and open the Xcode workspace. Note that installing the pod can take a long
time since it is big (~450MB). If you want to run the simple example, then:
cd tensorflow/examples/ios/simple
pod install
open tf_simple_example.xcworkspace # note .xcworkspace, not .xcodeproj
# this is created by pod install
Run the simple app in the Xcode simulator. You should see a single-screen app
with a **Run Model** button. Tap that, and you should see some debug output
appear below indicating that the example Grace Hopper image in the `data` directory
has been analyzed, with a military uniform recognized.
Run the other samples using the same process. The camera example requires a real
device connected. Once you build and run that, you should get a live camera view
that you can point at objects to get real-time recognition results.
### iOS Example details
There are three demo applications for iOS, all defined in Xcode projects inside
[tensorflow/examples/ios](https://www.tensorflow.org/code/tensorflow/examples/ios/).
- **Simple**: This is a minimal example showing how to load and run a TensorFlow
model in as few lines as possible. It just consists of a single view with a
button that executes the model loading and inference when it's pressed.
- **Camera**: This is very similar to the Android TF Classify demo. It loads
Inception v3 and outputs its best label estimate for what's in the live camera
view. As with the Android version, you can train your own custom model using
TensorFlow for Poets and drop it into this example with minimal code changes.
- **Benchmark**: This is quite close to Simple, but it runs the graph repeatedly and
outputs similar statistics to the benchmark tool on Android.
### Troubleshooting
- Make sure you use the TensorFlow-experimental pod (and not TensorFlow).
- The TensorFlow-experimental pod is currently about 450MB. The reason it is so
big is because we are bundling multiple platforms, and the pod includes all
TensorFlow functionality (e.g. operations). The final app size after build is
substantially smaller though (~25MB). Working with the complete pod is
convenient during development, but see the section below on how you can build your
own custom TensorFlow library to reduce the size.
## Building the TensorFlow iOS libraries from source
While CocoaPods is the quickest and easiest way of getting started, you sometimes
need more flexibility to determine which parts of TensorFlow your app should be
shipped with. For such cases, you can build the iOS libraries from the
sources. [This
guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/ios#building-the-tensorflow-ios-libraries-from-source)
contains detailed instructions on how to do that.
View File
@ -1,270 +0,0 @@
# Integrating TensorFlow libraries
Warning: We expect to deprecate TensorFlow Mobile in early 2019
<div class="caution">
<p>
<a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
working hard to close the feature gap between TensorFlow Mobile and
TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
will give ample notice to our users when we get to that point and will
provide help and support to ensure easy migrations.
</p>
<p>
In the meantime, please use TensorFlow Lite. If you have a feature request,
such as a missing op, please post to our <a
href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
</p>
</div>
Once you have made some progress on a model that addresses the problem you're
trying to solve, it's important to test it out inside your application
immediately. There are often unexpected differences between your training data
and what users actually encounter in the real world, and getting a clear picture
of the gap as soon as possible improves the product experience.
This page talks about how to integrate the TensorFlow libraries into your own
mobile applications, once you have already successfully built and deployed the
TensorFlow mobile demo apps.
## Linking the library
After you've managed to build the examples, you'll probably want to call
TensorFlow from one of your existing applications. The very easiest way to do
this is to use the Pod installation steps described in
<a href="./ios_build.md">Building TensorFlow on iOS</a>, but if you want to build
TensorFlow from source (for example to customize which operators are included)
you'll need to break out TensorFlow as a framework, include the right header
files, and link against the built libraries and dependencies.
### Android
For Android, you just need to link in a Java library contained in a JAR file
called `libandroid_tensorflow_inference_java.jar`. There are three ways to
include this functionality in your program:
1. Include the jcenter AAR which contains it, as in this
[example app](https://github.com/googlecodelabs/tensorflow-for-poets-2/blob/master/android/tfmobile/build.gradle#L59-L65)
2. Download the nightly precompiled version from
[ci.tensorflow.org](http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/).
3. Build the JAR file yourself using the instructions [in our Android GitHub repo](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/android)
### iOS
Pulling in the TensorFlow libraries on iOS is a little more complicated. Here is
a checklist of what you'll need to do to your iOS app:
- Link against `tensorflow/contrib/makefile/gen/lib/libtensorflow-core.a`, usually
by adding `-L/your/path/tensorflow/contrib/makefile/gen/lib/` and
`-ltensorflow-core` to your linker flags.
- Link against the generated protobuf libraries by adding
`-L/your/path/tensorflow/contrib/makefile/gen/protobuf_ios/lib` and
`-lprotobuf` and `-lprotobuf-lite` to your command line.
- For the include paths, you need the root of your TensorFlow source folder as
the first entry, followed by
`tensorflow/contrib/makefile/downloads/protobuf/src`,
`tensorflow/contrib/makefile/downloads`,
`tensorflow/contrib/makefile/downloads/eigen`, and
`tensorflow/contrib/makefile/gen/proto`.
- Make sure your binary is built with `-force_load` (or the equivalent on your
platform), aimed at the TensorFlow library to ensure that it's linked
correctly. More detail on why this is necessary can be found in the next
section, [Global constructor magic](#global_constructor_magic). On Linux-like
platforms, you'll need different flags, more like
`-Wl,--allow-multiple-definition -Wl,--whole-archive`.
You'll also need to link in the Accelerate framework, since this is used to
speed up some of the operations.
## Global constructor magic
One of the subtlest problems you may run up against is the “No session factory
registered for the given session options” error when trying to call TensorFlow
from your own application. To understand why this is happening and how to fix
it, you need to know a bit about the architecture of TensorFlow.
The framework is designed to be very modular, with a thin core and a large
number of specific objects that are independent and can be mixed and matched as
needed. To enable this, the coding pattern in C++ had to let modules easily
notify the framework about the services they offer, without requiring a central
list that has to be updated separately from each implementation. It also had to
allow separate libraries to add their own implementations without needing a
recompile of the core.
To achieve this capability, TensorFlow uses a registration pattern in a lot of
places. In the code, it looks like this:
```
class MulKernel : OpKernel {
  Status Compute(OpKernelContext* context) { … }
};
REGISTER_KERNEL(MulKernel, "Mul");
```
This would be in a standalone `.cc` file linked into your application, either
as part of the main set of kernels or as a separate custom library. The magic
part is that the `REGISTER_KERNEL()` macro is able to inform the core of
TensorFlow that it has an implementation of the Mul operation, so that it can be
called in any graphs that require it.
From a programming point of view, this setup is very convenient. The
implementation and registration code live in the same file, and adding new
implementations is as simple as compiling and linking it in. The difficult part
comes from the way that the `REGISTER_KERNEL()` macro is implemented. C++
doesn't offer a good mechanism for doing this sort of registration, so we have
to resort to some tricky code. Under the hood, the macro is implemented so that
it produces something like this:
```
class RegisterMul {
 public:
  RegisterMul() {
    global_kernel_registry()->Register("Mul", []() {
      return new MulKernel();
    });
  }
};
RegisterMul g_register_mul;
```
This sets up a class `RegisterMul` with a constructor that tells the global
kernel registry what function to call when somebody asks it how to create a
“Mul” kernel. Then there's a global object of that class, and so the constructor
should be called at the start of any program.
While this may sound sensible, the unfortunate part is that the global object
that's defined is not used by any other code, so linkers not designed with this
in mind will decide that it can be deleted. As a result, the constructor is
never called, and the class is never registered. All sorts of modules use this
pattern in TensorFlow, and it happens that `Session` implementations are the
first to be looked for when the code is run, which is why it shows up as the
characteristic error when this problem occurs.
The solution is to force the linker to not strip any code from the library, even
if it believes it's unused. On iOS, this step can be accomplished with the
`-force_load` flag, specifying a library path, and on Linux you need
`--whole-archive`. These persuade the linker to not be as aggressive about
stripping, and should retain the globals.
The actual implementation of the various `REGISTER_*` macros is a bit more
complicated in practice, but they all suffer the same underlying problem. If
you're interested in how they work, [op_kernel.h](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/op_kernel.h#L1091)
is a good place to start investigating.
## Protobuf problems
TensorFlow relies on
the [Protocol Buffer](https://developers.google.com/protocol-buffers/) library,
commonly known as protobuf. This library takes definitions of data structures
and produces serialization and access code for them in a variety of
languages. The tricky part is that this generated code needs to be linked
against shared libraries for the exact same version of the framework that was
used for the generator. This can be an issue when `protoc`, the tool used to
generate the code, is from a different version of protobuf than the libraries in
the standard linking and include paths. For example, you might be using a copy
of `protoc` that was built locally in `~/projects/protobuf-3.0.1.a`, but you have
libraries installed at `/usr/local/lib` and `/usr/local/include` that are from
3.0.0.
The symptoms of this issue are errors during the compilation or linking phases
with protobufs. Usually, the build tools take care of this, but if you're using
the makefile, make sure you're building the protobuf library locally and using
it, as shown in [this Makefile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/makefile/Makefile#L18).
Another situation that can cause problems is when protobuf headers and source
files need to be generated as part of the build process. This process makes
building more complex, since the first phase has to be a pass over the protobuf
definitions to create all the needed code files, and only after that can you go
ahead and do a build of the library code.
### Multiple versions of protobufs in the same app
Protobufs generate headers that are needed as part of the C++ interface to the
overall TensorFlow library. This complicates using the library as a standalone
framework.
If your application is already using version 1 of the protocol buffers library,
you may have trouble integrating TensorFlow because it requires version 2. If
you just try to link both versions into the same binary, youll see linking
errors because some of the symbols clash. To solve this particular problem, we
have an experimental script at [rename_protobuf.sh](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/makefile/rename_protobuf.sh).
You need to run this as part of the makefile build, after you've downloaded all
the dependencies:
```
tensorflow/contrib/makefile/download_dependencies.sh
tensorflow/contrib/makefile/rename_protobuf.sh
```
## Calling the TensorFlow API
Once you have the framework available, you then need to call into it. The usual
pattern is that you first load your model, which represents a preset set of
numeric computations, and then you run inputs through that model (for example,
images from a camera) and receive outputs (for example, predicted labels).
On Android, we provide the Java Inference Library that is focused on just this
use case, while on iOS and Raspberry Pi you call directly into the C++ API.
### Android
Here's what a typical Inference Library sequence looks like on Android:
```
// Load the model from disk.
TensorFlowInferenceInterface inferenceInterface =
new TensorFlowInferenceInterface(assetManager, modelFilename);
// Copy the input data into TensorFlow.
inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);
// Run the inference call.
inferenceInterface.run(outputNames, logStats);
// Copy the output Tensor back into the output array.
inferenceInterface.fetch(outputName, outputs);
```
You can find the source of this code in the [Android examples](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java#L107).
### iOS and Raspberry Pi
Here's the equivalent code for iOS and Raspberry Pi:
```
// Load the model.
PortableReadFileToProto(file_path, &tensorflow_graph);
// Create a session from the model.
tensorflow::Status s = session->Create(tensorflow_graph);
if (!s.ok()) {
LOG(FATAL) << "Could not create TensorFlow Graph: " << s;
}
// Run the model.
std::string input_layer = "input";
std::string output_layer = "output";
std::vector<tensorflow::Tensor> outputs;
tensorflow::Status run_status = session->Run({{input_layer, image_tensor}},
                                             {output_layer}, {}, &outputs);
if (!run_status.ok()) {
LOG(FATAL) << "Running model failed: " << run_status;
}
// Access the output data.
tensorflow::Tensor* output = &outputs[0];
```
This is all based on the
[iOS sample code](https://www.tensorflow.org/code/tensorflow/examples/ios/simple/RunModelViewController.mm),
but there's nothing iOS-specific; the same code should be usable on any platform
that supports C++.
You can also find specific examples for Raspberry Pi
[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/pi_examples/label_image/label_image.cc).
View File
@ -1,518 +0,0 @@
# Optimizing for mobile
Warning: We expect to deprecate TensorFlow Mobile in early 2019
<div class="caution">
<p>
<a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
working hard to close the feature gap between TensorFlow Mobile and
TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
will give ample notice to our users when we get to that point and will
provide help and support to ensure easy migrations.
</p>
<p>
In the meantime, please use TensorFlow Lite. If you have a feature request,
such as a missing op, please post to our <a
href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
</p>
</div>
There are some special issues that you have to deal with when you're trying to
ship on mobile or embedded devices, and you'll need to think about these as
you're developing your model.
These issues are:
- Model and Binary Size
- App speed and model loading speed
- Performance and threading
We'll discuss a few of these below.
## What are the minimum device requirements for TensorFlow?
You need at least one megabyte of program memory and several megabytes of RAM to
run the base TensorFlow runtime, so it's not suitable for DSPs or
microcontrollers. Other than those, the biggest constraint is usually the
calculation speed of the device, and whether you can run the model you need for
your application with a low enough latency. You can use the benchmarking tools
in [How to Profile your Model](#how_to_profile_your_model) to get an idea of how
many FLOPs are required for a model, and then use that to make rule-of-thumb
estimates of how fast they will run on different devices. For example, a modern
smartphone might be able to run 10 GFLOPs per second, so the best you could hope
for from a 5 GFLOP model is two frames per second, though you may do worse
depending on what the exact computation patterns are.
This model dependence means that it's possible to run TensorFlow even on very
old or constrained phones, as long as you optimize your network to fit within
the latency budget and possibly within limited RAM too. For memory usage, you
mostly need to make sure that the intermediate buffers that TensorFlow creates
aren't too large, which you can examine in the benchmark output too.
## Speed
One of the highest priorities of most model deployments is figuring out how to
run the inference fast enough to give a good user experience. The first place to
start is by looking at the total number of floating point operations that are
required to execute the graph. You can get a very rough estimate of this by
using the `benchmark_model` tool:
bazel build -c opt tensorflow/tools/benchmark:benchmark_model && \
bazel-bin/tensorflow/tools/benchmark/benchmark_model \
--graph=/tmp/inception_graph.pb --input_layer="Mul:0" \
--input_layer_shape="1,299,299,3" --input_layer_type="float" \
--output_layer="softmax:0" --show_run_order=false --show_time=false \
--show_memory=false --show_summary=true --show_flops=true --logtostderr
This should show you an estimate of how many operations are needed to run the
graph. You can then use that information to figure out how feasible your model
is to run on the devices you're targeting. For example, a high-end phone from
2016 might be able to do 20 billion FLOPs per second, so the best speed you
could hope for from a model that requires 10 billion FLOPs is around 500ms. On a
device like the Raspberry Pi 3 that can do about 5 billion FLOPs, you may only
get one inference every two seconds.
Having this estimate helps you plan for what you'll be able to realistically
achieve on a device. If the model is using too many ops, then there are a lot of
opportunities to optimize the architecture to reduce that number.
Advanced techniques include [SqueezeNet](https://arxiv.org/abs/1602.07360)
and [MobileNet](https://arxiv.org/abs/1704.04861), which are architectures
designed to produce models for mobile -- lean and fast but with a small accuracy
cost. You can also just look at alternative models, even older ones, which may
be smaller. For example, Inception v1 only has around 7 million parameters,
compared to Inception v3's 24 million, and requires only 3 billion FLOPs rather
than 9 billion for v3.
## Model Size
Models that run on a device need to be stored somewhere on the device, and very
large neural networks can be hundreds of megabytes. Most users are reluctant to
download very large app bundles from app stores, so you want to make your model
as small as possible. Furthermore, smaller neural networks can persist in and
out of a mobile device's memory faster.
To understand how large your network will be on disk, start by looking at the
size on disk of your `GraphDef` file after you've run `freeze_graph` and
`strip_unused_nodes` on it (see <a href="./prepare_models.md">Preparing models</a> for
more details on these tools), since then it should only contain
inference-related nodes. To double-check that your results are as expected, run
the `summarize_graph` tool to see how many parameters are in constants:
bazel build tensorflow/tools/graph_transforms:summarize_graph && \
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
--in_graph=/tmp/tensorflow_inception_graph.pb
That command should give you output that looks something like this:
No inputs spotted.
Found 1 possible outputs: (name=softmax, op=Softmax)
Found 23885411 (23.89M) const parameters, 0 (0) variable parameters,
and 99 control_edges
Op types used: 489 Const, 99 CheckNumerics, 99 Identity, 94
BatchNormWithGlobalNormalization, 94 Conv2D, 94 Relu, 11 Concat, 9 AvgPool,
5 MaxPool, 1 Sub, 1 Softmax, 1 ResizeBilinear, 1 Reshape, 1 Mul, 1 MatMul,
1 ExpandDims, 1 DecodeJpeg, 1 Cast, 1 BiasAdd
The important part for our current purposes is the number of const
parameters. In most models these will be stored as 32-bit floats to start, so if
you multiply the number of const parameters by four, you should get something
that's close to the size of the file on disk. You can often get away with only
eight bits per parameter with very little loss of accuracy in the final result,
so if your file size is too large you can try using
<a href="https://www.tensorflow.org/performance/quantization">quantize_weights</a>
to transform the parameters down.
bazel build tensorflow/tools/graph_transforms:transform_graph && \
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=/tmp/tensorflow_inception_optimized.pb \
--out_graph=/tmp/tensorflow_inception_quantized.pb \
--inputs='Mul:0' --outputs='softmax:0' --transforms='quantize_weights'
If you look at the resulting file size, you should see that it's about a quarter
of the original at 23MB.
Another transform is `round_weights`, which doesn't make the file smaller, but it
makes the file compressible to about the same size as when `quantize_weights` is
used. This is particularly useful for mobile development, taking advantage of
the fact that app bundles are compressed before they're downloaded by consumers.
The original file does not compress well with standard algorithms, because the
bit patterns of even very similar numbers can be very different. The
`round_weights` transform keeps the weight parameters stored as floats, but
rounds them to a set number of step values. This means there are a lot more
repeated byte patterns in the stored model, and so compression can often bring
the size down dramatically, in many cases to near the size it would be if they
were stored as eight bits.
Another advantage of `round_weights` is that the framework doesn't have to
allocate a temporary buffer to unpack the parameters into, as we have to when
we just use `quantize_weights`. This saves a little bit of latency (though the
results should be cached so it's only costly on the first run) and makes it
possible to use memory mapping, as described later.
## Binary Size
One of the biggest differences between mobile and server development is the
importance of binary size. On desktop machines it's not unusual to have
executables that are hundreds of megabytes on disk, but for mobile and embedded
apps it's vital to keep the binary as small as possible so that user downloads
are easy. As mentioned above, TensorFlow only includes a subset of op
implementations by default, but this still results in a 12 MB final
executable. To reduce this, you can set up the library to only include the
implementations of the ops that you actually need, based on automatically
analyzing your model. To use it:
- Run `tools/print_required_ops/print_selective_registration_header.py` on your
model to produce a header file that only enables the ops it uses.
- Place the `ops_to_register.h` file somewhere that the compiler can find
it. This can be in the root of your TensorFlow source folder.
- Build TensorFlow with `SELECTIVE_REGISTRATION` defined, for example by passing
in `--copts="-DSELECTIVE_REGISTRATION"` to your Bazel build command.
This process recompiles the library so that only the needed ops and types are
included, which can dramatically reduce the executable size. For example, with
Inception v3, the new size is only 1.5MB.
## How to Profile your Model
Once you have an idea of what your device's peak performance range is, it's
worth looking at its actual current performance. Using a standalone TensorFlow
benchmark, rather than running it inside a larger app, helps isolate just the
TensorFlow contribution to the
latency. The
[tensorflow/tools/benchmark](https://www.tensorflow.org/code/tensorflow/tools/benchmark/) tool
is designed to help you do this. To run it against Inception v3 on your desktop
machine, build and run the benchmark tool like this:
bazel build -c opt tensorflow/tools/benchmark:benchmark_model && \
bazel-bin/tensorflow/tools/benchmark/benchmark_model \
--graph=/tmp/tensorflow_inception_graph.pb --input_layer="Mul" \
--input_layer_shape="1,299,299,3" --input_layer_type="float" \
--output_layer="softmax:0" --show_run_order=false --show_time=false \
--show_memory=false --show_summary=true --show_flops=true --logtostderr
You should see output that looks something like this:
<pre>
============================== Top by Computation Time ==============================
[node type]  [start]  [first]  [avg ms]  [%]  [cdf%]  [mem KB]  [Name]
Conv2D 22.859 14.212 13.700 4.972% 4.972% 3871.488 conv_4/Conv2D
Conv2D 8.116 8.964 11.315 4.106% 9.078% 5531.904 conv_2/Conv2D
Conv2D 62.066 16.504 7.274 2.640% 11.717% 443.904 mixed_3/conv/Conv2D
Conv2D 2.530 6.226 4.939 1.792% 13.510% 2765.952 conv_1/Conv2D
Conv2D 55.585 4.605 4.665 1.693% 15.203% 313.600 mixed_2/tower/conv_1/Conv2D
Conv2D 127.114 5.469 4.630 1.680% 16.883% 81.920 mixed_10/conv/Conv2D
Conv2D 47.391 6.994 4.588 1.665% 18.548% 313.600 mixed_1/tower/conv_1/Conv2D
Conv2D 39.463 7.878 4.336 1.574% 20.122% 313.600 mixed/tower/conv_1/Conv2D
Conv2D 127.113 4.192 3.894 1.413% 21.535% 114.688 mixed_10/tower_1/conv/Conv2D
Conv2D 70.188 5.205 3.626 1.316% 22.850% 221.952 mixed_4/conv/Conv2D
============================== Summary by node type ==============================
[Node type] [count] [avg ms] [avg %] [cdf %] [mem KB]
Conv2D 94 244.899 88.952% 88.952% 35869.953
BiasAdd 95 9.664 3.510% 92.462% 35873.984
AvgPool 9 7.990 2.902% 95.364% 7493.504
Relu 94 5.727 2.080% 97.444% 35869.953
MaxPool 5 3.485 1.266% 98.710% 3358.848
Const 192 1.727 0.627% 99.337% 0.000
Concat 11 1.081 0.393% 99.730% 9892.096
MatMul 1 0.665 0.242% 99.971% 4.032
Softmax 1 0.040 0.015% 99.986% 4.032
<> 1 0.032 0.012% 99.997% 0.000
Reshape 1 0.007 0.003% 100.000% 0.000
Timings (microseconds): count=50 first=330849 curr=274803 min=232354 max=415352 avg=275563 std=44193
Memory (bytes): count=50 curr=128366400(all same)
514 nodes defined 504 nodes observed
</pre>
This is the summary view, which is enabled by the `--show_summary` flag. To
interpret it, the first table is a list of the nodes that took the most time, in
order by how long they took. From left to right, the columns are:
- Node type, what kind of operation this was.
- Start time of the op, showing where it falls in the sequence of operations.
- First time in milliseconds. This is how long the operation took on the first
run of the benchmark, since by default 20 runs are executed to get more
reliable statistics. The first time is useful to spot which ops are doing
expensive calculations on the first run, and then caching the results.
- Average time for the operation across all runs, in milliseconds.
- What percentage of the total time for one run the op took. This is useful to
understand where the hotspots are.
- The cumulative total time of this and the previous ops in the table. This is
handy for understanding what the distribution of work is across the layers, to
see if just a few of the nodes are taking up most of the time.
- The amount of memory consumed by outputs of this type of op.
- Name of the node.
The second table is similar, but instead of breaking down the timings by
particular named nodes, it groups them by the kind of op. This is very useful to
understand which op implementations you might want to optimize or eliminate from
your graph. The table is arranged with the most costly operations at the start,
and only shows the top ten entries, with a placeholder for other nodes. The
columns from left to right are:
- Type of the nodes being analyzed.
- Accumulated average time taken by all nodes of this type, in milliseconds.
- What percentage of the total time was taken by this type of operation.
- Cumulative time taken by this and op types higher in the table, so you can
understand the distribution of the workload.
- How much memory the outputs of this op type took up.
Both of these tables are set up so that you can easily copy and paste their
results into spreadsheet documents, since they are output with tabs as
separators between the columns. The summary by node type can be the most useful
when looking for optimization opportunities, since it's a pointer to the code
that's taking the most time. In this case, you can see that the Conv2D ops are
almost 90% of the execution time. This is a sign that the graph is pretty
optimal, since convolutions and matrix multiplies are expected to be the bulk of
a neural network's computing workload.
As a rule of thumb, it's more worrying if you see a lot of other operations
taking up more than a small fraction of the time. For neural networks, the ops
that don't involve large matrix multiplications should usually be dwarfed by the
ones that do, so if you see a lot of time going into those, it's a sign that
either your network is non-optimally constructed, or the code implementing those
ops is not as optimized as it could
be. [Performance bugs](https://github.com/tensorflow/tensorflow/issues) or
patches are always welcome if you do encounter this situation, especially if
they include an attached model exhibiting this behavior and the command line
used to run the benchmark tool on it.
The run above was on your desktop, but the tool also works on Android, which is
where it's most useful for mobile development. Here's an example command line to
run it on a 64-bit ARM device:
bazel build -c opt --config=android_arm64 \
tensorflow/tools/benchmark:benchmark_model
adb push bazel-bin/tensorflow/tools/benchmark/benchmark_model /data/local/tmp
adb push /tmp/tensorflow_inception_graph.pb /data/local/tmp/
adb shell '/data/local/tmp/benchmark_model \
--graph=/data/local/tmp/tensorflow_inception_graph.pb --input_layer="Mul" \
--input_layer_shape="1,299,299,3" --input_layer_type="float" \
--output_layer="softmax:0" --show_run_order=false --show_time=false \
--show_memory=false --show_summary=true'
You can interpret the results in exactly the same way as the desktop version
above. If you have any trouble figuring out what the right input and output
names and types are, take a look at the
<a href="./prepare_models.md">Preparing models</a>
page for details about detecting these for your model, and look at the
`summarize_graph` tool which may give you
helpful information.
There isn't good support for command line tools on iOS, so instead there's a
separate example
at
[tensorflow/examples/ios/benchmark](https://www.tensorflow.org/code/tensorflow/examples/ios/benchmark) that
packages the same functionality inside a standalone app. This outputs the
statistics to both the screen of the device and the debug log. If you want
on-screen statistics for the Android example apps, you can turn them on by
pressing the volume-up button.
## Profiling within your own app
The output you see from the benchmark tool is generated from modules that are
included as part of the standard TensorFlow runtime, which means you have access
to them within your own applications too. You can see an example of how to do
that [here](https://www.tensorflow.org/code/tensorflow/examples/ios/benchmark/BenchmarkViewController.mm?l=139).
The basic steps are:
1. Create a StatSummarizer object:
tensorflow::StatSummarizer stat_summarizer(tensorflow_graph);
2. Set up the options:
tensorflow::RunOptions run_options;
run_options.set_trace_level(tensorflow::RunOptions::FULL_TRACE);
tensorflow::RunMetadata run_metadata;
3. Run the graph:
run_status = session->Run(run_options, inputs, output_layer_names, {},
output_layers, &run_metadata);
4. Calculate the results and print them out:
    assert(run_metadata.has_step_stats());
    const tensorflow::StepStats& step_stats = run_metadata.step_stats();
    stat_summarizer.ProcessStepStats(step_stats);
    stat_summarizer.PrintStepStats();
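Putting those steps together, here is a minimal sketch (not taken from the example app) of what a traced profiling run might look like in your own code. It assumes you already have a loaded `GraphDef` named `tensorflow_graph`, an open `session`, and your own `inputs`, `output_layer_names`, and `output_layers` variables:
```
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/stat_summarizer.h"

// Run the graph once with full tracing enabled, then summarize per-op stats.
void ProfileOneRun(
    const tensorflow::GraphDef& tensorflow_graph, tensorflow::Session* session,
    const std::vector<std::pair<std::string, tensorflow::Tensor>>& inputs,
    const std::vector<std::string>& output_layer_names,
    std::vector<tensorflow::Tensor>* output_layers) {
  tensorflow::StatSummarizer stat_summarizer(tensorflow_graph);

  tensorflow::RunOptions run_options;
  run_options.set_trace_level(tensorflow::RunOptions::FULL_TRACE);
  tensorflow::RunMetadata run_metadata;

  tensorflow::Status run_status = session->Run(
      run_options, inputs, output_layer_names, {}, output_layers,
      &run_metadata);
  if (!run_status.ok()) {
    LOG(ERROR) << "Profiling run failed: " << run_status;
    return;
  }

  // The step stats collected during the traced run drive the summarizer.
  if (run_metadata.has_step_stats()) {
    stat_summarizer.ProcessStepStats(run_metadata.step_stats());
    stat_summarizer.PrintStepStats();
  }
}
```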
## Visualizing Models
The most effective way to speed up your code is by altering your model so it
does less work. To do that, you need to understand what your model is doing, and
visualizing it is a good first step. To get a high-level overview of your graph,
use [TensorBoard](https://github.com/tensorflow/tensorboard).
## Threading
The desktop version of TensorFlow has a sophisticated threading model, and will
try to run multiple operations in parallel if it can. In our terminology this is
called “inter-op parallelism” (though to avoid confusion with “intra-op”, you
could think of it as “between-op” instead), and can be set by specifying
`inter_op_parallelism_threads` in the session options.
By default, mobile devices run operations serially; that is,
`inter_op_parallelism_threads` is set to 1. Mobile processors usually have few
cores and a small cache, so running multiple operations accessing disjoint parts
of memory usually doesn't help performance. “Intra-op parallelism” (or
“within-op”) can be very helpful though, especially for computation-bound
operations like convolutions where different threads can feed off the same small
set of memory.
On mobile, how many threads an op will use is set to the number of cores by
default, or 2 when the number of cores can't be determined. You can override the
default number of threads that ops are using by setting
`intra_op_parallelism_threads` in the session options. It's a good idea to
reduce the default if your app has its own threads doing heavy processing, so
that they don't interfere with each other.
To see more details on session options, look at [ConfigProto](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto).
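As a rough illustration (the thread counts here are only examples, not recommendations), this is how both settings can be applied when creating a session from C++:
```
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"

// Run ops one at a time ("between-op" parallelism of 1), but allow each op
// to use up to two threads internally.
tensorflow::SessionOptions options;
options.config.set_inter_op_parallelism_threads(1);
options.config.set_intra_op_parallelism_threads(2);

tensorflow::Session* session = nullptr;
tensorflow::Status status = tensorflow::NewSession(options, &session);
```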
## Retrain with mobile data
The biggest cause of accuracy problems when running models on mobile apps is
unrepresentative training data. For example, most of the Imagenet photos are
well-framed so that the object is in the center of the picture, well-lit, and
shot with a normal lens. Photos from mobile devices are often poorly framed,
badly lit, and can have fisheye distortions, especially selfies.
The solution is to expand your training set with data actually captured from
your application. This step can involve extra work, since you'll have to label
the examples yourself, but even if you just use it to expand your original
training data, it can improve the training set dramatically. Improving the training
set by doing this, and by fixing other quality issues like duplicates or badly
labeled examples, is the single best way to improve accuracy. It's usually a
bigger help than altering your model architecture or using different techniques.
## Reducing model loading time and/or memory footprint
Most operating systems allow you to load a file using memory mapping, rather
than going through the usual I/O APIs. Instead of allocating an area of memory
on the heap and then copying bytes from disk into it, you simply tell the
operating system to make the entire contents of a file appear directly in
memory. This has several advantages:
* Speeds loading
* Reduces paging (increases performance)
* Does not count towards RAM budget for your app
TensorFlow has support for memory mapping the weights that form the bulk of most
model files. Because of limitations in the `ProtoBuf` serialization format, we
have to make a few changes to our model loading and processing code. The
way memory mapping works is that we have a single file where the first part is a
normal `GraphDef` serialized into the protocol buffer wire format, but then the
weights are appended in a form that can be directly mapped.
To create this file, run the
`tensorflow/contrib/util:convert_graphdef_memmapped_format` tool. This takes in
a `GraphDef` file that's been run through `freeze_graph` and converts it to the
format that has the weights appended at the end. Since that file's no longer a
standard `GraphDef` protobuf, you then need to make some changes to the loading
code. You can see an example of this in
the
[iOS Camera demo app](https://www.tensorflow.org/code/tensorflow/examples/ios/camera/tensorflow_utils.mm?l=147),
in the `LoadMemoryMappedModel()` function.
The same code (with the Objective C calls for getting the filenames substituted)
can be used on other platforms too. Because we're using memory mapping, we need
to start by creating a special TensorFlow environment object that's set up with
the file we'll be using:
    std::unique_ptr<tensorflow::MemmappedEnv> memmapped_env;
    memmapped_env.reset(
        new tensorflow::MemmappedEnv(tensorflow::Env::Default()));
    tensorflow::Status mmap_status =
        memmapped_env->InitializeFromFile(file_path);
You then need to pass in this environment to subsequent calls, like this one for
loading the graph:
    tensorflow::GraphDef tensorflow_graph;
    tensorflow::Status load_graph_status = ReadBinaryProto(
        memmapped_env.get(),
        tensorflow::MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
        &tensorflow_graph);
You also need to create the session with a pointer to the environment you've
created:
    tensorflow::SessionOptions options;
    options.config.mutable_graph_options()
        ->mutable_optimizer_options()
        ->set_opt_level(::tensorflow::OptimizerOptions::L0);
    options.env = memmapped_env.get();
    tensorflow::Session* session_pointer = nullptr;
    tensorflow::Status session_status =
        tensorflow::NewSession(options, &session_pointer);
One thing to notice here is that we're also disabling automatic optimizations,
since in some cases these will fold constant sub-trees, and so create copies of
tensor values that we don't want and use up more RAM.
Once you've gone through these steps, you can use the session and graph as
normal, and you should see a reduction in loading time and memory usage.
## Protecting model files from easy copying
By default, your models will be stored in the standard serialized protobuf
format on disk. In theory this means that anybody can copy your model, which you
may not want. However, in practice, most models are so application-specific and
obfuscated by optimizations that the risk is similar to that of competitors
disassembling and reusing your code. If you do want to make it tougher for
casual users to access your files, though, it is possible to take some basic steps.
Most of our examples use
the
[ReadBinaryProto()](https://www.tensorflow.org/code/tensorflow/core/platform/env.cc?q=core/platform/env.cc&l=409) convenience
call to load a `GraphDef` from disk. This does require an unencrypted protobuf on
disk. Luckily though, the implementation of the call is pretty straightforward
and it should be easy to write an equivalent that can decrypt in memory. Here's
some code that shows how you can read and decrypt a protobuf using your own
decryption routine:
    Status ReadEncryptedProto(Env* env, const string& fname,
                              ::tensorflow::protobuf::MessageLite* proto) {
      string data;
      TF_RETURN_IF_ERROR(ReadFileToString(env, fname, &data));
      DecryptData(&data);  // Your own function here.
      if (!proto->ParseFromString(data)) {
        return errors::DataLoss("Can't parse ", fname, " as binary proto");
      }
      return Status::OK();
    }
To use this you'd need to define the `DecryptData()` function yourself. It could
be as simple as something like:
    void DecryptData(string* data) {
      for (size_t i = 0; i < data->size(); ++i) {
        (*data)[i] = (*data)[i] ^ 0x23;
      }
    }
You may want something more complex, but exactly what you'll need is outside the
current scope here.
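As a rough sketch of how the pieces fit together, loading and running an encrypted graph might look something like this. The file path is just a placeholder, and `ReadEncryptedProto()` is the function defined above:
```
#include <memory>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/session.h"

// Load the encrypted, frozen GraphDef and create a session from it.
tensorflow::GraphDef tensorflow_graph;
tensorflow::Status load_status = ReadEncryptedProto(
    tensorflow::Env::Default(), "/data/local/tmp/my_model.encrypted.pb",
    &tensorflow_graph);
if (!load_status.ok()) {
  LOG(FATAL) << "Could not load model: " << load_status;
}

std::unique_ptr<tensorflow::Session> session(
    tensorflow::NewSession(tensorflow::SessionOptions()));
tensorflow::Status create_status = session->Create(tensorflow_graph);
if (!create_status.ok()) {
  LOG(FATAL) << "Could not create TensorFlow graph: " << create_status;
}
```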
View File
@ -1,318 +0,0 @@
# Preparing models for mobile deployment
Warning: We expect to deprecate TensorFlow Mobile in early 2019
<div class="caution">
<p>
<a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
working hard to close the feature gap between TensorFlow Mobile and
TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
will give ample notice to our users when we get to that point and will
provide help and support to ensure easy migrations.
</p>
<p>
In the meantime, please use TensorFlow Lite. If you have a feature request,
such as a missing op, please post to our <a
href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
</p>
</div>
The requirements for storing model information during training are very
different from when you want to release it as part of a mobile app. This section
covers the tools involved in converting from a training model to something
releasable in production.
## What is up with all the different saved file formats?
You may find yourself getting very confused by all the different ways that
TensorFlow can save out graphs. To help, here's a rundown of some of the
different components, and what they are used for. The objects are mostly defined
and serialized as protocol buffers:
- [NodeDef](https://www.tensorflow.org/code/tensorflow/core/framework/node_def.proto):
Defines a single operation in a model. It has a unique name, a list of the
names of other nodes it pulls inputs from, the operation type it implements
(for example `Add`, or `Mul`), and any attributes that are needed to control
that operation. This is the basic unit of computation for TensorFlow, and all
work is done by iterating through a network of these nodes, applying each one
in turn. One particular operation type that's worth knowing about is `Const`,
since this holds information about a constant. This may be a single, scalar
number or string, but it can also hold an entire multi-dimensional tensor
array. The values for a `Const` are stored inside the `NodeDef`, and so large
constants can take up a lot of room when serialized.
- [Checkpoint](https://www.tensorflow.org/code/tensorflow/core/util/tensor_bundle/tensor_bundle.h). Another
way of storing values for a model is by using `Variable` ops. Unlike `Const`
ops, these don't store their content as part of the `NodeDef`, so they take up
very little space within the `GraphDef` file. Instead their values are held in
RAM while a computation is running, and then saved out to disk as checkpoint
files periodically. This typically happens as a neural network is being
trained and weights are updated, so it's a time-critical operation, and it may
happen in a distributed fashion across many workers, so the file format has to
be both fast and flexible. They are stored as multiple checkpoint files,
together with metadata files that describe what's contained within the
checkpoints. When you're referring to a checkpoint in the API (for example
when passing a filename in as a command line argument), you'll use the common
prefix for a set of related files. If you had these files:
/tmp/model/model-chkpt-1000.data-00000-of-00002
/tmp/model/model-chkpt-1000.data-00001-of-00002
/tmp/model/model-chkpt-1000.index
/tmp/model/model-chkpt-1000.meta
You would refer to them as `/tmp/model/model-chkpt-1000`.
- [GraphDef](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto):
Has a list of `NodeDefs`, which together define the computational graph to
execute. During training, some of these nodes will be `Variables`, and so if
you want to have a complete graph you can run, including the weights, you'll
need to call a restore operation to pull those values from
checkpoints. Because checkpoint loading has to be flexible to deal with all of
the training requirements, this can be tricky to implement on mobile and
embedded devices, especially those with no proper file system available like
iOS. This is where
the
[`freeze_graph.py`](https://www.tensorflow.org/code/tensorflow/python/tools/freeze_graph.py) script
comes in handy. As mentioned above, `Const` ops store their values as part of
the `NodeDef`, so if all the `Variable` weights are converted to `Const` nodes,
then we only need a single `GraphDef` file to hold the model architecture and
the weights. Freezing the graph handles the process of loading the
checkpoints, and then converts all Variables to Consts. You can then load the
resulting file in a single call, without having to restore variable values
from checkpoints. One thing to watch out for with `GraphDef` files is that
sometimes they're stored in text format for easy inspection. These versions
usually have a `.pbtxt` filename suffix, whereas the binary files end with
`.pb`.
- [FunctionDefLibrary](https://www.tensorflow.org/code/tensorflow/core/framework/function.proto):
This appears in `GraphDef`, and is effectively a set of sub-graphs, each with
information about their input and output nodes. Each sub-graph can then be
used as an op in the main graph, allowing easy instantiation of different
nodes, in a similar way to how functions encapsulate code in other languages.
- [MetaGraphDef](https://www.tensorflow.org/code/tensorflow/core/protobuf/meta_graph.proto):
A plain `GraphDef` only has information about the network of computations, but
doesn't have any extra information about the model or how it can be
used. `MetaGraphDef` contains a `GraphDef` defining the computation part of
the model, but also includes information like signatures, which are
suggestions about which inputs and outputs you may want to call the model
with, data on how and where any checkpoint files are saved, and convenience
tags for grouping ops together for ease of use.
- [SavedModel](https://www.tensorflow.org/code/tensorflow/core/protobuf/saved_model.proto):
It's common to want to have different versions of a graph that rely on a
common set of variable checkpoints. For example, you might need a GPU and a
CPU version of the same graph, but keep the same weights for both. You might
also need some extra files (like label names) as part of your
model. The
[SavedModel](https://www.tensorflow.org/code/tensorflow/python/saved_model/README.md) format
addresses these needs by letting you save multiple versions of the same graph
without duplicating variables, and also storing asset files in the same
bundle. Under the hood, it uses `MetaGraphDef` and checkpoint files, along
with extra metadata files. It's the format that you'll want to use if you're
deploying a web API using TensorFlow Serving, for example; a minimal C++ loading
sketch follows this list.
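As a rough sketch of the loading side mentioned above (the export directory path is just a placeholder), a `SavedModel` tagged for serving can be opened from C++ like this:
```
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/tag_constants.h"

// Load the graph tagged for serving, together with its variables and assets.
tensorflow::SavedModelBundle bundle;
tensorflow::Status status = tensorflow::LoadSavedModel(
    tensorflow::SessionOptions(), tensorflow::RunOptions(),
    "/tmp/my_saved_model", {tensorflow::kSavedModelTagServe}, &bundle);
if (status.ok()) {
  // bundle.session is ready to use, and bundle.meta_graph_def holds the
  // MetaGraphDef with its signatures.
  tensorflow::Session* session = bundle.session.get();
  (void)session;
}
```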
## How do you get a model you can use on mobile?
In most situations, training a model with TensorFlow will give you a folder
containing a `GraphDef` file (usually ending with the `.pb` or `.pbtxt` extension) and
a set of checkpoint files. What you need for mobile or embedded deployment is a
single `GraphDef` file that's been frozen, or had its variables converted into
inline constants so everything's in one file. To handle the conversion, you'll
need the `freeze_graph.py` script, which is held in
[`tensorflow/python/tools/freeze_graph.py`](https://www.tensorflow.org/code/tensorflow/python/tools/freeze_graph.py). You'll run it like this:
bazel build tensorflow/python/tools:freeze_graph
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=/tmp/model/my_graph.pb \
--input_checkpoint=/tmp/model/model.ckpt-1000 \
--output_graph=/tmp/frozen_graph.pb \
--output_node_names=output_node
The `input_graph` argument should point to the `GraphDef` file that holds your
model architecture. It's possible that your `GraphDef` has been stored in a text
format on disk, in which case it's likely to end in `.pbtxt` instead of `.pb`,
and you should add an extra `--input_binary=false` flag to the command.
The `input_checkpoint` should be the most recent saved checkpoint. As mentioned
in the checkpoint section, you need to give the common prefix to the set of
checkpoints here, rather than a full filename.
`output_graph` defines where the resulting frozen `GraphDef` will be
saved. Because it's likely to contain a lot of weight values that take up a
large amount of space in text format, it's always saved as a binary protobuf.
`output_node_names` is a list of the names of the nodes that you want to extract
the results of your graph from. This is needed because the freezing process
needs to understand which parts of the graph are actually needed, and which are
artifacts of the training process, like summarization ops. Only ops that
contribute to calculating the given output nodes will be kept. If you know how
your graph is going to be used, these should just be the names of the nodes you
pass into `Session::Run()` as your fetch targets. The easiest way to find the
node names is to inspect the Node objects while building your graph in Python.
Inspecting your graph in TensorBoard is another simple way. You can get some
suggestions on likely outputs by running the [`summarize_graph` tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms/README.md#inspecting-graphs).
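If all you have is a frozen `GraphDef` on disk, a small standalone C++ sketch like this one (a hypothetical helper, not one of the shipped tools) can also list candidate node names:
```
#include <iostream>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/platform/env.h"

int main(int argc, char* argv[]) {
  if (argc < 2) {
    std::cerr << "Usage: list_nodes <graph.pb>" << std::endl;
    return 1;
  }
  // Read the binary GraphDef and print every node's name and op type.
  tensorflow::GraphDef graph_def;
  tensorflow::Status status = tensorflow::ReadBinaryProto(
      tensorflow::Env::Default(), argv[1], &graph_def);
  if (!status.ok()) {
    std::cerr << "Could not read graph: " << status.ToString() << std::endl;
    return 1;
  }
  for (const tensorflow::NodeDef& node : graph_def.node()) {
    std::cout << node.name() << " (" << node.op() << ")" << std::endl;
  }
  return 0;
}
```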
Because the output format for TensorFlow has changed over time, there are a
variety of other less commonly used flags available too, like `input_saver`, but
hopefully you shouldn't need these on graphs trained with modern versions of the
framework.
## Using the Graph Transform Tool
A lot of the things you need to do to efficiently run a model on device are
available through the [Graph Transform
Tool](https://www.tensorflow.org/code/tensorflow/tools/graph_transforms/README.md). This
command-line tool takes an input `GraphDef` file, applies the set of rewriting
rules you request, and then writes out the result as a `GraphDef`. See the
documentation for more information on how to build and run this tool.
### Removing training-only nodes
TensorFlow `GraphDefs` produced by the training code contain all of the
computation that's needed for back-propagation and updates of weights, as well
as the queuing and decoding of inputs, and the saving out of checkpoints. All of
these nodes are no longer needed during inference, and some of the operations,
like checkpoint saving, aren't even supported on mobile platforms. To create a
model file that you can load on devices, you need to delete those unneeded
operations by running the `strip_unused_nodes` rule in the Graph Transform Tool.
The trickiest part of this process is figuring out the names of the nodes you
want to use as inputs and outputs during inference. You'll need these anyway
once you start to run inference, but you also need them here so that the
transform can calculate which nodes are not needed on the inference-only
path. These may not be obvious from the training code. The easiest way to
determine the node name is to explore the graph with TensorBoard.
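If your training script doesn't already log its graph, here's a minimal sketch
of writing it out so TensorBoard can display it; the toy graph and the
`/tmp/tensorboard_logs` directory are stand-ins for your own:

```python
import tensorflow as tf

# A stand-in graph; in practice this is the graph your training code builds.
x = tf.placeholder(tf.float32, shape=[None, 299, 299, 3], name='input')
y = tf.identity(x, name='output')

with tf.Session() as sess:
    # Writes the graph structure so TensorBoard's Graphs tab can render it.
    writer = tf.summary.FileWriter('/tmp/tensorboard_logs', sess.graph)
    writer.close()
```

You can then point TensorBoard at it with
`tensorboard --logdir=/tmp/tensorboard_logs` and browse the node names in the
Graphs tab.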
Remember that mobile applications typically gather their data from sensors and
have it as arrays in memory, whereas training typically involves loading and
decoding representations of the data stored on disk. In the case of Inception v3
for example, there's a `DecodeJpeg` op at the start of the graph that's designed
to take JPEG-encoded data from a file retrieved from disk and turn it into an
arbitrary-sized image. After that there's a `ResizeBilinear` op to scale it to
the expected size, followed by a couple of other ops that convert the byte data
into float and scale the value magnitudes in the way the rest of the graph
expects. A typical mobile app will skip most of these steps because it's getting
its input directly from a live camera, so the input node you will actually
supply will be the output of the `Mul` node in this case.
<img src="../images/inception_input.png" width="300">
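As a rough sketch of what that app-side preprocessing amounts to, here's the
kind of conversion your code would do before feeding the graph. The 299x299
size and the mean/scale of 128 match the common Inception v3 setup; they are
assumptions here, so check your own graph for the exact values it expects:

```python
import numpy as np

def preprocess(rgb_pixels, height=299, width=299):
    """Cast camera bytes to float and rescale them to roughly [-1, 1]."""
    image = np.asarray(rgb_pixels, dtype=np.float32)
    image = image.reshape(1, height, width, 3)
    # Stand-in for the mean-subtract and rescale ops the original graph
    # applied; 128 is a common choice, not a guarantee for your model.
    return (image - 128.0) / 128.0
```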
You'll need to do a similar process of inspection to figure out the correct
output nodes.
If you've just been given a frozen `GraphDef` file, and are not sure about the
contents, try using the `summarize_graph` tool to print out information
about the inputs and outputs it finds from the graph structure. Here's an
example with the original Inception v3 file:

    bazel run tensorflow/tools/graph_transforms:summarize_graph -- \
      --in_graph=tensorflow_inception_graph.pb

Once you have an idea of what the input and output nodes are, you can feed them
into the graph transform tool as the `--input_names` and `--output_names`
arguments, and call the `strip_unused_nodes` transform, like this:

    bazel run tensorflow/tools/graph_transforms:transform_graph -- \
      --in_graph=tensorflow_inception_graph.pb \
      --out_graph=optimized_inception_graph.pb \
      --inputs='Mul' \
      --outputs='softmax' \
      --transforms='
        strip_unused_nodes(type=float, shape="1,299,299,3")
        fold_constants(ignore_errors=true)
        fold_batch_norms
        fold_old_batch_norms'

One thing to look out for here is that you need to specify the size and type
that you want your inputs to be. This is because any values that you're going to
be passing in as inputs to inference need to be fed to special `Placeholder` op
nodes, and the transform may need to create them if they don't already exist. In
the case of Inception v3 for example, a `Placeholder` node replaces the old
`Mul` node that used to output the resized and rescaled image array, since we're
going to be doing that processing ourselves before we call TensorFlow. It keeps
the original name though, which is why we always feed in inputs to `Mul` when we
run a session with our modified Inception graph.
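To make that concrete, here's a hedged sketch of loading the transformed graph
and feeding the renamed `Mul` placeholder directly; the file name matches the
transform example above, and the zero-filled array stands in for a real
preprocessed camera frame:

```python
import numpy as np
import tensorflow as tf

graph_def = tf.GraphDef()
with tf.gfile.GFile('optimized_inception_graph.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')

# A dummy image standing in for real, preprocessed input.
image = np.zeros((1, 299, 299, 3), dtype=np.float32)

with tf.Session(graph=graph) as sess:
    predictions = sess.run('softmax:0', feed_dict={'Mul:0': image})
print(predictions.shape)
```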
After you've run this process, you'll have a graph that only contains the actual
nodes you need to run your prediction process. This is the point where it
becomes useful to run metrics on the graph, so it's worth running
`summarize_graph` again to understand what's in your model.
## What ops should you include on mobile?
There are hundreds of operations available in TensorFlow, and each one has
multiple implementations for different data types. On mobile platforms, the size
of the executable binary that's produced after compilation is important, because
app download bundles need to be as small as possible for the best user
experience. If all of the ops and data types are compiled into the TensorFlow
library then the total size of the compiled library can be tens of megabytes, so
by default only a subset of ops and data types are included.
That means that if you load a model file that's been trained on a desktop
machine, you may see the error “No OpKernel was registered to support Op” when
you load it on mobile. The first thing to try is to make sure you've stripped
out any training-only nodes, since the error will occur at load time even if the
op is never executed. If you're still hitting the same problem once that's done,
you'll need to look at adding the op to the library you build.
The criteria for including ops and types fall into several categories:
- Are they only useful in back-propagation, for gradients? Since mobile is
  focused on inference, we don't include these.
- Are they useful mainly for other training needs, such as checkpoint saving?
  These we leave out.
- Do they rely on frameworks that aren't always available on mobile, such as
  libjpeg? To avoid extra dependencies we don't include ops like `DecodeJpeg`.
- Are there types that aren't commonly used? We don't include boolean variants
  of ops for example, since we don't see much use of them in typical inference
  graphs.
These ops are trimmed by default to optimize for inference on mobile, but it is
possible to alter some build files to change the default. After altering the
build files, you will need to recompile TensorFlow. See below for more details
on how to do this, and also see <a href="./optimizing.md">optimizing binary size</a>
for more on reducing your binary size.
### Locate the implementation
Operations are broken into two parts. The first is the op definition, which
declares the operation's signature: which inputs, outputs, and attributes it
has. These take up very little space, and so all are included by default. The
implementations of the op computations are done in kernels, which live in the
`tensorflow/core/kernels` folder. You need to compile the C++ file containing
the kernel implementation of the op you need into the library. To figure out
which file that is, you can search for the operation name in the source
files.
[Heres an example search in github](https://github.com/search?utf8=%E2%9C%93&q=repo%3Atensorflow%2Ftensorflow+extension%3Acc+path%3Atensorflow%2Fcore%2Fkernels+REGISTER+Mul&type=Code&ref=searchresults).
You'll see that this search is looking for the `Mul` op implementation, and it
finds it in `tensorflow/core/kernels/cwise_op_mul_1.cc`. You need to look for
macros beginning with `REGISTER`, with the op name you care about as one of the
string arguments.
In this case, the implementations are actually broken up across multiple `.cc`
files, so you'd need to include all of them in your build. If you're more
comfortable using the command line for code search, here's a grep command that
also locates the right files if you run it from the root of your TensorFlow
repository:
`grep 'REGISTER.*"Mul"' tensorflow/core/kernels/*.cc`
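If you'd rather script the search, a small stand-alone Python equivalent of
that grep (run from the repository root) could look like this:

```python
import glob
import re

# Find kernel files that register an op called "Mul".
pattern = re.compile(r'REGISTER.*"Mul"')
for path in glob.glob('tensorflow/core/kernels/*.cc'):
    with open(path) as source_file:
        if pattern.search(source_file.read()):
            print(path)
```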
### Add the implementation to the build
If you're using Bazel, and building for Android, you'll want to add the files
you've found to the
[`android_extended_ops_group1`](https://www.tensorflow.org/code/tensorflow/core/kernels/BUILD#L3565) or
[`android_extended_ops_group2`](https://www.tensorflow.org/code/tensorflow/core/kernels/BUILD#L3632) targets. You
may also need to include any .cc files they depend on. If the build
complains about missing header files, add the needed .h files to the
[`android_extended_ops`](https://www.tensorflow.org/code/tensorflow/core/kernels/BUILD#L3525) target.
If you're using a makefile targeting iOS, Raspberry Pi, etc., go to
[`tensorflow/contrib/makefile/tf_op_files.txt`](https://www.tensorflow.org/code/tensorflow/contrib/makefile/tf_op_files.txt) and
add the right implementation files there.