STT-tensorflow/tensorflow/contrib/lite/build_def.bzl
Rohan Jain ed494f17fc Rolling back tensorflow .bzl file changes
END_PUBLIC

BEGIN_PUBLIC
Automated g4 rollback of changelist 203459720

PiperOrigin-RevId: 203501636
2018-07-06 11:17:47 -07:00

320 lines
8.8 KiB
Python

"""Generate Flatbuffer binary from json."""
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
)
def tflite_copts():
    """Returns the C/C++ compiler flags shared by TFLite targets.

    The result is a fixed flag list concatenated with two `select()`s: one
    keyed on the target platform and one keyed on
    `//tensorflow:with_default_optimizations`.

    Returns:
      A list + select expression suitable for a `copts` attribute.
    """
    per_platform = {
        str(Label("//tensorflow:android_arm64")): ["-std=c++11", "-O3"],
        str(Label("//tensorflow:android_arm")): [
            "-mfpu=neon",
            "-mfloat-abi=softfp",
            "-std=c++11",
            "-O3",
        ],
        str(Label("//tensorflow:android_x86")): [
            "-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK",
        ],
        str(Label("//tensorflow:ios_x86_64")): ["-msse4.1"],
        "//conditions:default": [],
    }

    # GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK is defined for every configuration
    # except the one with default optimizations enabled.
    per_optimization = {
        str(Label("//tensorflow:with_default_optimizations")): [],
        "//conditions:default": ["-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK"],
    }

    common = ["-DFARMHASH_NO_CXX_STRING"]
    return common + select(per_platform) + select(per_optimization)
# Default linker version script for tflite_jni_binary(); it is handed to the
# linker as a --version-script to restrict which symbols the JNI library
# exports.
LINKER_SCRIPT = "//tensorflow/contrib/lite/java/src/main/native:version_script.lds"
def tflite_linkopts_unstripped():
    """Returns size-reducing linker flags for TFLite, without stripping symbols.

    Keeping the symbol table makes these flags useful when investigating the
    relative size of individual symbols in a TFLite binary.

    Returns:
      A select object with the linkopts for the current platform.
    """
    android_opts = [
        "-Wl,--no-export-dynamic",  # Only include syms referenced by dynamic objs.
        "-Wl,--exclude-libs,ALL",  # Keep syms from all libs out of auto export.
        "-Wl,--gc-sections",  # Drop unused code and data sections.
        "-Wl,--as-needed",  # Skip linking libs that are not used.
    ]
    default_opts = [
        "-Wl,--icf=all",  # Fold identical code.
    ]
    return select({
        "//tensorflow:android": android_opts,
        # MIPS targets receive no extra linker flags.
        "//tensorflow/contrib/lite:mips": [],
        "//tensorflow/contrib/lite:mips64": [],
        "//conditions:default": default_opts,
    })
def tflite_jni_linkopts_unstripped():
    """Returns size-reducing linker flags for a TFLite JNI library, unstripped.

    Keeping the symbol table makes these flags useful when investigating the
    relative size of individual symbols in a TFLite binary.

    Returns:
      A select object with the linkopts for the current platform.
    """
    android_opts = [
        "-Wl,--gc-sections",  # Drop unused code and data sections.
        "-Wl,--as-needed",  # Skip linking libs that are not used.
    ]
    default_opts = [
        "-Wl,--icf=all",  # Fold identical code.
    ]
    return select({
        "//tensorflow:android": android_opts,
        # MIPS targets receive no extra linker flags.
        "//tensorflow/contrib/lite:mips": [],
        "//tensorflow/contrib/lite:mips64": [],
        "//conditions:default": default_opts,
    })
def tflite_linkopts():
    """Returns size-reducing linker flags for TFLite, stripping on Android.

    Returns:
      The flags from tflite_linkopts_unstripped() followed by a select that
      adds "-s" (omit the symbol table) on Android only.
    """
    strip_opts = select({
        "//tensorflow:android": ["-s"],  # Omit the symbol table.
        "//conditions:default": [],
    })
    return tflite_linkopts_unstripped() + strip_opts
def tflite_jni_linkopts():
    """Returns size-reducing linker flags for a TFLite JNI library.

    Returns:
      The flags from tflite_jni_linkopts_unstripped() followed by a select
      that, on Android only, strips symbols and links libatomic.
    """
    android_extra = [
        "-s",  # Omit the symbol table.
        "-latomic",  # Required for some uses of ISO C++11 <atomic> in x86.
    ]
    platform_extra = select({
        "//tensorflow:android": android_extra,
        "//conditions:default": [],
    })
    return tflite_jni_linkopts_unstripped() + platform_extra
def tflite_jni_binary(
        name,
        copts = tflite_copts(),
        linkopts = tflite_jni_linkopts(),
        linkscript = LINKER_SCRIPT,
        linkshared = 1,
        linkstatic = 1,
        deps = []):
    """Builds a JNI binary for TFLite, exporting only the symbols in `linkscript`.

    Args:
      name: Name of the resulting cc_binary.
      copts: Compiler flags; defaults to tflite_copts().
      linkopts: Linker flags; defaults to tflite_jni_linkopts(). The
        version-script flag is appended to these.
      linkscript: Label of the linker version script; it is also added to
        `deps` so that $(location) can resolve it.
      linkshared: Forwarded to cc_binary.linkshared.
      linkstatic: Forwarded to cc_binary.linkstatic.
      deps: Additional dependencies of the binary.
    """
    # Pass "--version-script" and its path to the linker as ONE option. With
    # two separate linkopts, the path is seen by the compiler driver as a
    # stray input file and reaches the linker only by positional luck.
    version_script_opt = "-Wl,--version-script,$(location {})".format(linkscript)
    native.cc_binary(
        name = name,
        copts = copts,
        linkshared = linkshared,
        linkstatic = linkstatic,
        deps = deps + [linkscript],
        linkopts = linkopts + [version_script_opt],  # Export only JNI functions & classes.
    )
def tf_to_tflite(name, src, options, out):
    """Converts a frozen TensorFlow GraphDef to a TF Lite flatbuffer with TOCO.

    Args:
      name: Name of rule.
      src: Label of the input GraphDef file.
      options: List of extra command-line flags appended to the TOCO call.
      out: Name of the output flatbuffer file.
    """
    toco = "//tensorflow/contrib/lite/toco:toco"
    toco_cmdline = " ".join([
        # Genrule commands only expand $(location ...); a bare label is not a
        # runnable path, so the tool must be referenced through $(location).
        "$(location %s)" % toco,
        "--input_format=TENSORFLOW_GRAPHDEF",
        "--output_format=TFLITE",
        "--input_file=$(location %s)" % src,
        "--output_file=$(location %s)" % out,
    ] + options)
    native.genrule(
        name = name,
        srcs = [src],
        outs = [out],
        cmd = toco_cmdline,
        tools = [toco],
    )
def tflite_to_json(name, src, out):
    """Converts a TF Lite flatbuffer to JSON using flatc and the TFLite schema.

    Args:
      name: Name of rule.
      src: Label of the input flatbuffer file.
      out: Name of the output JSON file.
    """
    flatc = "@flatbuffers//:flatc"
    schema = "//tensorflow/contrib/lite/schema:schema.fbs"

    # The input is copied next to a fresh `mktemp` file so that flatc's output
    # (<basename>.json in the -o directory) has a predictable name. The -o
    # directory must be the directory mktemp actually used ($TMPDIR when set),
    # not a hard-coded /tmp. The temp files are removed once the copy succeeds.
    cmd = ("TMP=`mktemp` && cp $(location %s) $${TMP}.bin && " +
           "$(location %s) --raw-binary --strict-json -t" +
           " -o $$(dirname $${TMP}) $(location %s) -- $${TMP}.bin && " +
           "cp $${TMP}.json $(location %s) && " +
           "rm -f $${TMP} $${TMP}.bin $${TMP}.json") % (src, flatc, schema, out)
    native.genrule(
        name = name,
        srcs = [schema, src],
        outs = [out],
        cmd = cmd,
        tools = [flatc],
    )
def json_to_tflite(name, src, out):
    """Converts a JSON file to a TF Lite flatbuffer using flatc and the TFLite schema.

    Args:
      name: Name of rule.
      src: Label of the input JSON file.
      out: Name of the output flatbuffer file.
    """
    flatc = "@flatbuffers//:flatc"
    schema = "//tensorflow/contrib/lite/schema:schema_fbs"

    # The input is copied next to a fresh `mktemp` file so that flatc's binary
    # output (<basename>.bin in the -o directory) has a predictable name. The
    # -o directory must be the directory mktemp actually used ($TMPDIR when
    # set), not a hard-coded /tmp. The temp files are removed once the copy
    # succeeds.
    cmd = ("TMP=`mktemp` && cp $(location %s) $${TMP}.json && " +
           "$(location %s) --raw-binary --unknown-json --allow-non-utf8 -b" +
           " -o $$(dirname $${TMP}) $(location %s) $${TMP}.json && " +
           "cp $${TMP}.bin $(location %s) && " +
           "rm -f $${TMP} $${TMP}.json $${TMP}.bin") % (src, flatc, schema, out)
    native.genrule(
        name = name,
        srcs = [schema, src],
        outs = [out],
        cmd = cmd,
        tools = [flatc],
    )
# Master list of generated examples that are turned into tests. Every entry
# "XXX" needs a matching make_XXX_tests() function in generate_examples.py.
# To disable a test, comment it out and add a link to a bug or issue.
def generated_test_models():
    """Returns the names of the generated-example models to build tests for.

    Returns:
      A new list of model-name strings, in the order listed below.
    """
    models = [
        "add",
        "arg_max",
        "avg_pool",
        "batch_to_space_nd",
        "concat",
        "constant",
        "control_dep",
        "conv",
        "depthwiseconv",
        "div",
        "equal",
        "exp",
        "expand_dims",
        "floor",
        "fully_connected",
        "fused_batch_norm",
        "gather",
        "global_batch_norm",
        "greater",
        "greater_equal",
        "sum",
        "l2norm",
        "l2_pool",
        "less",
        "less_equal",
        "local_response_norm",
        "log_softmax",
        "log",
        "lstm",
        "max_pool",
        "maximum",
        "mean",
        "minimum",
        "mul",
        "neg",
        "not_equal",
        "pad",
        "padv2",
        # TODO: "prelu" is disabled without the bug/issue link that the
        # header comment asks for; add one.
        # "prelu",
        "pow",
        "relu",
        "relu1",
        "relu6",
        "reshape",
        "resize_bilinear",
        "rsqrt",
        "shape",
        "sigmoid",
        "sin",
        "slice",
        "softmax",
        "space_to_batch_nd",
        "space_to_depth",
        "sparse_to_dense",
        "split",
        "sqrt",
        "squeeze",
        "strided_slice",
        "strided_slice_1d_exhaustive",
        "sub",
        "tile",
        "topk",
        "transpose",
        "transpose_conv",
        "where",
    ]
    return models
def gen_zip_test(name, test_name, **kwargs):
    """Creates the zipped examples for `test_name` plus a cc_test named `name`.

    Args:
      name: Name of the resulting tf_cc_test target.
      test_name: Model whose examples are zipped; one of the names returned by
        generated_test_models().
      **kwargs: Forwarded unchanged to tf_cc_test.
    """
    # Produces the filegroup "zip_<test_name>" wrapping "<test_name>.zip".
    # NOTE(review): this macro does not add the zip to the test's data;
    # presumably callers do that through kwargs. Confirm at the call sites.
    zip_target = "zip_%s" % test_name
    zip_file = "%s.zip" % test_name
    gen_zipped_test_file(name = zip_target, file = zip_file)

    tf_cc_test(name, **kwargs)
def gen_zipped_test_file(name, file):
    """Generates a zip of example tests by running :generate_examples.

    Args:
      name: Name of the filegroup that exposes the generated zip. The genrule
        that builds it is named "`file`.files".
      file: Output zip name of one of the generated examples, e.g.
        "transpose.zip".
    """
    generator = ":generate_examples"
    toco = "//tensorflow/contrib/lite/toco:toco"

    # The generator receives the toco binary and writes `file` into $(@D).
    command = "$(locations {}) --toco $(locations {})  --zip_to_output {} $(@D)".format(
        generator,
        toco,
        file,
    )
    native.genrule(
        name = file + ".files",
        cmd = command,
        outs = [file],
        tools = [generator, toco],
    )

    native.filegroup(
        name = name,
        srcs = [file],
    )
def gen_selected_ops(name, model):
    """Generates a registration source that includes only the ops `model` uses.

    Args:
      name: Name of the genrule; its output is "<name>_registration.cc".
      model: Label of the TFLite model file whose ops are registered.
    """
    registration_src = name + "_registration.cc"
    generator = "//tensorflow/contrib/lite/tools:generate_op_registrations"
    tflite_label = "//tensorflow/contrib/lite"

    # The tool expects the package path without the leading "//".
    tflite_pkg = tflite_label[2:]

    command = ("$(location {}) --input_model=$(location {}) " +
               "--output_registration=$(location {}) --tflite_path={}").format(
        generator,
        model,
        registration_src,
        tflite_pkg,
    )
    native.genrule(
        name = name,
        srcs = [model],
        outs = [registration_src],
        cmd = command,
        tools = [generator],
    )