Merge and fix workspace.bzl for cmake, makefile use.

Commit b992ff69e2

The third-party *.BUILD files below drop their prefix_dir indirection so that the
same BUILD paths work for Bazel and for the cmake/makefile builds (the archives
are presumably unpacked with their version prefix stripped via workspace.bzl).
The merge also brings in the C API GraphDef import functions, the C++ gradient
checker, and the C++ SavedModel loader.

Changed files:
  RELEASE.md, avro.BUILD, boost.BUILD, bzip2.BUILD, eigen.BUILD, farmhash.BUILD,
  gif.BUILD, gmock.BUILD, grpc.BUILD, jpeg.BUILD, jsoncpp.BUILD, linenoise.BUILD,
  nanopb.BUILD, png.BUILD, six.BUILD
  tensorflow/: BUILD, c, cc, contrib (BUILD, __init__.py, bayesflow/python/ops,
  cmake, crf, distributions/python/{kernel_tests,ops}, ffmpeg, learn/python/learn,
  linear_optimizer/kernels, makefile (Makefile, build_all_android.sh,
  download_dependencies.sh, proto_text_cc_files.txt, proto_text_pb_h_files.txt,
  tf_pb_text_files.txt, tf_proto_files.txt), metrics,
  quantization/kernels/hexagon, slim/python/slim, tensorboard, tfprof)
RELEASE.md

@@ -17,6 +17,9 @@
   instead of graph.proto.
 * ops.NoGradient was renamed ops.NotDifferentiable. ops.NoGradient will
   be removed soon.
+* dot.h / DotGraph was removed (it was an early analysis tool prior
+  to TensorBoard, no longer that useful).  It remains in history
+  should someone find the code useful.

 # Release 0.10.0
avro.BUILD (14 lines changed)

@@ -2,21 +2,19 @@ package(default_visibility = ["//visibility:public"])
 
 licenses(["notice"])  # Apache 2.0
 
-prefix_dir = "avro-cpp-1.8.0"
-
 cc_library(
     name = "avrocpp",
     srcs = glob(
         [
-            prefix_dir + "/impl/**/*.cc",
-            prefix_dir + "/impl/**/*.hh",
+            "impl/**/*.cc",
+            "impl/**/*.hh",
         ],
         exclude = [
-            prefix_dir + "/impl/avrogencpp.cc",
+            "impl/avrogencpp.cc",
         ],
     ),
-    hdrs = glob([prefix_dir + "/api/**/*.hh"]),
-    includes = [prefix_dir + "/api"],
+    hdrs = glob(["api/**/*.hh"]),
+    includes = ["api"],
     deps = [
         "@boost_archive//:boost",
         "@boost_archive//:filesystem",
@@ -27,7 +25,7 @@ cc_library(
 
 cc_binary(
     name = "avrogencpp",
-    srcs = [prefix_dir + "/impl/avrogencpp.cc"],
+    srcs = ["impl/avrogencpp.cc"],
     deps = [
         ":avrocpp",
         "@boost_archive//:program_options",
boost.BUILD (26 lines changed)

@@ -10,21 +10,19 @@ package(default_visibility = ["@avro_archive//:__subpackages__"])
 
 licenses(["notice"])  # Boost software license
 
-prefix_dir = "boost_1_61_0"
-
 cc_library(
     name = "boost",
     hdrs = glob([
-        prefix_dir + "/boost/**/*.hpp",
-        prefix_dir + "/boost/**/*.h",
-        prefix_dir + "/boost/**/*.ipp",
+        "boost/**/*.hpp",
+        "boost/**/*.h",
+        "boost/**/*.ipp",
     ]),
-    includes = [prefix_dir],
+    includes = ["."],
 )
 
 cc_library(
     name = "filesystem",
-    srcs = glob([prefix_dir + "/libs/filesystem/src/*.cpp"]),
+    srcs = glob(["libs/filesystem/src/*.cpp"]),
     deps = [
         ":boost",
         ":system",
@@ -33,7 +31,7 @@ cc_library(
 
 cc_library(
     name = "iostreams",
-    srcs = glob([prefix_dir + "/libs/iostreams/src/*.cpp"]),
+    srcs = glob(["libs/iostreams/src/*.cpp"]),
     deps = [
         ":boost",
         "@bzip2_archive//:bz2lib",
@@ -43,16 +41,12 @@ cc_library(
 
 cc_library(
     name = "program_options",
-    srcs = glob([prefix_dir + "/libs/program_options/src/*.cpp"]),
-    deps = [
-        ":boost",
-    ],
+    srcs = glob(["libs/program_options/src/*.cpp"]),
+    deps = [":boost"],
 )
 
 cc_library(
     name = "system",
-    srcs = glob([prefix_dir + "/libs/system/src/*.cpp"]),
-    deps = [
-        ":boost",
-    ],
+    srcs = glob(["libs/system/src/*.cpp"]),
+    deps = [":boost"],
 )
bzip2.BUILD (42 lines changed)

@@ -2,35 +2,27 @@ package(default_visibility = ["//visibility:public"])
 
 licenses(["notice"])  # BSD derivative
 
-prefix_dir = "bzip2-1.0.6"
-
-BZ2LIB_SRCS = [
-    # these are in the same order as their corresponding .o files are in OBJS in
-    # Makefile (rather than lexicographic order) for easy comparison (that they
-    # are identical).
-    "blocksort.c",
-    "huffman.c",
-    "crctable.c",
-    "randtable.c",
-    "compress.c",
-    "decompress.c",
-    "bzlib.c",
-]
-
 cc_library(
     name = "bz2lib",
-    srcs = [prefix_dir + "/" + source for source in BZ2LIB_SRCS] +
-           [prefix_dir + "/bzlib_private.h"],
-    hdrs = [prefix_dir + "/bzlib.h"],
-    includes = [prefix_dir],
+    srcs = [
+        # These are in the same order as their corresponding .o files are in
+        # OBJS in Makefile (rather than lexicographic order) for easy
+        # comparison (that they are identical.)
+        "blocksort.c",
+        "huffman.c",
+        "crctable.c",
+        "randtable.c",
+        "compress.c",
+        "decompress.c",
+        "bzlib.c",
+        "bzlib_private.h",
+    ],
+    hdrs = ["bzlib.h"],
+    includes = ["."],
 )
 
 cc_binary(
     name = "bzip2",
-    srcs = [
-        "bzip2.c",
-    ],
-    deps = [
-        ":bz2lib",
-    ],
+    srcs = ["bzip2.c"],
+    deps = [":bz2lib"],
 )
eigen.BUILD (68 lines changed)

@@ -1,8 +1,70 @@
-package(default_visibility = ["//visibility:public"])
+# Description:
+#   Eigen is a C++ template library for linear algebra: vectors,
+#   matrices, and related algorithms.
+
+licenses([
+    # Note: Eigen is an MPL2 library that includes GPL v3 and LGPL v2.1+ code.
+    #       We've taken special care to not reference any restricted code.
+    "reciprocal",  # MPL2
+    "notice",  # Portions BSD
+])
+
+# License-restricted (i.e. not reciprocal or notice) files inside Eigen/...
+EIGEN_RESTRICTED_FILES = [
+    "Eigen/src/OrderingMethods/Amd.h",
+    "Eigen/src/SparseCholesky/**",
+]
+
+# Notable transitive dependencies of restricted files inside Eigen/...
+EIGEN_RESTRICTED_DEPS = [
+    "Eigen/Eigen",
+    "Eigen/IterativeLinearSolvers",
+    "Eigen/MetisSupport",
+    "Eigen/Sparse",
+    "Eigen/SparseCholesky",
+    "Eigen/SparseLU",
+]
+
+# Note: unsupported/Eigen is unsupported and might go away at any time.
+EIGEN_FILES = [
+    "Eigen/**",
+    "unsupported/Eigen/CXX11/**",
+    "unsupported/Eigen/FFT",
+    "unsupported/Eigen/KroneckerProduct",
+    "unsupported/Eigen/src/FFT/**",
+    "unsupported/Eigen/src/KroneckerProduct/**",
+    "unsupported/Eigen/MatrixFunctions",
+    "unsupported/Eigen/SpecialFunctions",
+    "unsupported/Eigen/src/SpecialFunctions/**",
+]
+
+# List of files picked up by glob but actually part of another target.
+EIGEN_EXCLUDE_FILES = [
+    "Eigen/src/Core/arch/AVX/PacketMathGoogleTest.cc",
+]
+
+# Files known to be under MPL2 license.
+EIGEN_MPL2_HEADER_FILES = glob(
+    EIGEN_FILES,
+    exclude = EIGEN_EXCLUDE_FILES +
+              EIGEN_RESTRICTED_FILES +
+              EIGEN_RESTRICTED_DEPS + [
+        # Guarantees any file missed by excludes above will not compile.
+        "Eigen/src/Core/util/NonMPL2.h",
+        "Eigen/**/CMakeLists.txt",
+    ],
+)
 
 cc_library(
     name = "eigen",
-    hdrs = glob(["**/*.h", "unsupported/Eigen/*", "unsupported/Eigen/CXX11/*", "Eigen/*"]),
-    includes = [ '.' ],
+    hdrs = EIGEN_MPL2_HEADER_FILES,
+    defines = [
+        # This define (mostly) guarantees we don't link any problematic
+        # code. We use it, but we do not rely on it, as evidenced above.
+        "EIGEN_MPL2_ONLY",
+        # TODO(jart): Use EIGEN_USE_NONBLOCKING_THREAD_POOL but first add an
+        # eigen_initialize.cc file and alwayslink=1.
+    ],
+    includes = ["."],
+    visibility = ["//visibility:public"],
 )
farmhash.BUILD

@@ -1,21 +1,9 @@
 package(default_visibility = ["//visibility:public"])
 
-prefix_dir = "farmhash-34c13ddfab0e35422f4c3979f360635a8c050260"
-
-genrule(
-    name = "configure",
-    srcs = glob(
-        ["**/*"],
-        exclude = [prefix_dir + "/config.h"],
-    ),
-    outs = [prefix_dir + "/config.h"],
-    cmd = "pushd external/farmhash_archive/%s; workdir=$$(mktemp -d -t tmp.XXXXXXXXXX); cp -a * $$workdir; pushd $$workdir; ./configure; popd; popd; cp $$workdir/config.h $(@D); rm -rf $$workdir;" % prefix_dir,
-)
+licenses(["notice"])  # MIT
 
 cc_library(
     name = "farmhash",
-    srcs = [prefix_dir + "/src/farmhash.cc"],
-    hdrs = [prefix_dir + "/src/farmhash.h"] + [":configure"],
-    includes = [prefix_dir],
-    visibility = ["//visibility:public"]
+    srcs = ["farmhash.cc"],
+    hdrs = ["farmhash.h"],
+    includes = ["."],
+    visibility = ["//visibility:public"],
 )
gif.BUILD (95 lines changed)

@@ -1,65 +1,44 @@
-SOURCES = [
-    "dgif_lib.c",
-    "egif_lib.c",
-    "gif_font.c",
-    "gif_hash.c",
-    "gifalloc.c",
-    "openbsd-reallocarray.c",
-    "gif_err.c",
-    "quantize.c",
-]
+# Description:
+#   A library for decoding and encoding GIF images
 
-HEADERS = [
-    "gif_hash.h",
-    "gif_lib.h",
-    "gif_lib_private.h",
-]
-
-config_setting(
-    name = "windows",
-    values = {
-        "cpu": "x64_windows_msvc",
-    },
-    visibility = ["//visibility:public"],
-)
-
-prefix_dir = "giflib-5.1.4/lib"
-prefix_dir_windows = "windows/giflib-5.1.4/lib"
-
-genrule(
-    name = "srcs_without_unistd",
-    srcs = [prefix_dir + "/" + source for source in SOURCES],
-    outs = [prefix_dir_windows + "/" + source for source in SOURCES],
-    cmd = "for f in $(SRCS); do " +
-          "  sed 's/#include <unistd.h>//g' $$f > $(@D)/%s/$$(basename $$f);" % prefix_dir_windows +
-          "done",
-)
-
-genrule(
-    name = "hdrs_without_unistd",
-    srcs = [prefix_dir + "/" + hdrs for hdrs in HEADERS],
-    outs = [prefix_dir_windows + "/" + hdrs for hdrs in HEADERS],
-    cmd = "for f in $(SRCS); do " +
-          "  sed 's/#include <unistd.h>//g' $$f > $(@D)/%s/$$(basename $$f);" % prefix_dir_windows +
-          "done",
-)
+licenses(["notice"])  # MIT
 
 cc_library(
     name = "gif",
-    srcs = select({
-        "//conditions:default" : [prefix_dir + "/" + source for source in SOURCES],
-        ":windows" : [":srcs_without_unistd"],
-    }),
-    hdrs = select({
-        "//conditions:default" : [prefix_dir + "/" + hdrs for hdrs in HEADERS],
-        ":windows" : [":hdrs_without_unistd"],
-    }),
-    includes = select({
-        "//conditions:default" : [prefix_dir],
-        ":windows" : [prefix_dir_windows],
-    }),
-    defines = [
-        "HAVE_CONFIG_H",
+    srcs = [
+        "dgif_lib.c",
+        "egif_lib.c",
+        "gif_err.c",
+        "gif_font.c",
+        "gif_hash.c",
+        "gif_hash.h",
+        "gif_lib_private.h",
+        "gifalloc.c",
+        "openbsd-reallocarray.c",
+        "quantize.c",
     ],
+    hdrs = ["gif_lib.h"],
+    includes = ["."],
+    visibility = ["//visibility:public"],
+    deps = select({
+        ":windows": [":windows_polyfill"],
+        "//conditions:default": [],
+    }),
 )
+
+cc_library(
+    name = "windows_polyfill",
+    hdrs = ["windows/unistd.h"],
+    includes = ["windows"],
+)
+
+genrule(
+    name = "windows_unistd_h",
+    outs = ["windows/unistd.h"],
+    cmd = "touch $@",
+)
+
+config_setting(
+    name = "windows",
+    values = {"cpu": "x64_windows_msvc"},
+)
gmock.BUILD (26 lines changed)

@@ -1,19 +1,25 @@
 # Description:
 #   Google C++ Mocking Framework, a library for creating and using C++
 #   mock classes.
 
 licenses(["notice"])  # 3-clause BSD
 
 cc_library(
     name = "gtest",
     srcs = [
-        "gmock-1.7.0/gtest/src/gtest-all.cc",
-        "gmock-1.7.0/src/gmock-all.cc",
+        "gtest/src/gtest-all.cc",
+        "src/gmock-all.cc",
     ],
     hdrs = glob([
-        "gmock-1.7.0/**/*.h",
-        "gmock-1.7.0/gtest/src/*.cc",
-        "gmock-1.7.0/src/*.cc",
+        "**/*.h",
+        "gtest/src/*.cc",
+        "src/*.cc",
     ]),
     includes = [
-        "gmock-1.7.0",
-        "gmock-1.7.0/gtest",
-        "gmock-1.7.0/gtest/include",
-        "gmock-1.7.0/include",
+        ".",
+        "gtest",
+        "gtest/include",
+        "include",
     ],
     linkopts = ["-pthread"],
     visibility = ["//visibility:public"],
@@ -21,7 +27,7 @@ cc_library(
 
 cc_library(
     name = "gtest_main",
-    srcs = ["gmock-1.7.0/src/gmock_main.cc"],
+    srcs = ["src/gmock_main.cc"],
     linkopts = ["-pthread"],
     visibility = ["//visibility:public"],
     deps = [":gtest"],
grpc.BUILD (32 lines changed)

@@ -3,6 +3,7 @@
 # ...with small modifications to fix the build rules for :grpc++_unsecure.
 #
 # TODO(mrry): Upstream these fixes back to the gRPC repository.
+# TODO(jart): Fix nanopb's BUILD file. Fix grpc BUILD file.
 
 # GRPC Bazel BUILD file.
 # This currently builds C, C++ and Objective-C code.
@@ -44,9 +45,26 @@ licenses(["notice"])  # 3-clause BSD
 
 package(default_visibility = ["//visibility:public"])
 
+genrule(
+    name = "pb_h",
+    outs = ["third_party/nanopb/pb.h"],
+    cmd = "echo '#include <pb.h>' >$@",
+    visibility = ["//visibility:private"],
+)
+
+genrule(
+    name = "pb_decode_h",
+    outs = ["third_party/nanopb/pb_decode.h"],
+    cmd = "echo '#include <pb_decode.h>' >$@",
+    visibility = ["//visibility:private"],
+)
+
+genrule(
+    name = "pb_encode_h",
+    outs = ["third_party/nanopb/pb_encode.h"],
+    cmd = "echo '#include <pb_encode.h>' >$@",
+    visibility = ["//visibility:private"],
+)
+
 cc_library(
     name = "gpr",
@@ -499,6 +517,9 @@ cc_library(
         "src/core/ext/census/placeholders.c",
         "src/core/ext/census/tracing.c",
         "src/core/plugin_registry/grpc_plugin_registry.c",
+        "third_party/nanopb/pb.h",
+        "third_party/nanopb/pb_decode.h",
+        "third_party/nanopb/pb_encode.h",
     ],
     hdrs = [
         "include/grpc/byte_buffer.h",
@@ -856,6 +877,9 @@ cc_library(
         "src/core/lib/tsi/ssl_transport_security.c",
         "src/core/lib/tsi/transport_security.c",
        "src/core/plugin_registry/grpc_cronet_plugin_registry.c",
+        "third_party/nanopb/pb.h",
+        "third_party/nanopb/pb_decode.h",
+        "third_party/nanopb/pb_encode.h",
     ],
     hdrs = [
         "include/grpc/byte_buffer.h",
@@ -1185,6 +1209,9 @@ cc_library(
         "src/core/ext/census/placeholders.c",
         "src/core/ext/census/tracing.c",
         "src/core/plugin_registry/grpc_unsecure_plugin_registry.c",
+        "third_party/nanopb/pb.h",
+        "third_party/nanopb/pb_decode.h",
+        "third_party/nanopb/pb_encode.h",
     ],
     hdrs = [
         "include/grpc/byte_buffer.h",
@@ -2313,6 +2340,9 @@ objc_library(
         "src/core/ext/census/grpc_filter.h",
         "src/core/ext/census/mlog.h",
         "src/core/ext/census/rpc_metric_id.h",
+        "third_party/nanopb/pb.h",
+        "third_party/nanopb/pb_decode.h",
+        "third_party/nanopb/pb_encode.h",
     ],
     includes = [
         "include",
jpeg.BUILD (160 lines changed)

@@ -1,83 +1,89 @@
-SOURCES = [
-    "jaricom.c",
-    "jcapimin.c",
-    "jcapistd.c",
-    "jcarith.c",
-    "jccoefct.c",
-    "jccolor.c",
-    "jcdctmgr.c",
-    "jchuff.c",
-    "jcinit.c",
-    "jcmainct.c",
-    "jcmarker.c",
-    "jcmaster.c",
-    "jcomapi.c",
-    "jcparam.c",
-    "jcprepct.c",
-    "jcsample.c",
-    "jctrans.c",
-    "jdarith.c",
-    "jdapimin.c",
-    "jdapistd.c",
-    "jdatadst.c",
-    "jdatasrc.c",
-    "jdcoefct.c",
-    "jdcolor.c",
-    "jddctmgr.c",
-    "jdhuff.c",
-    "jdinput.c",
-    "jdmainct.c",
-    "jdmarker.c",
-    "jdmaster.c",
-    "jdmerge.c",
-    "jdpostct.c",
-    "jdsample.c",
-    "jdtrans.c",
-    "jerror.c",
-    "jfdctflt.c",
-    "jfdctfst.c",
-    "jfdctint.c",
-    "jidctflt.c",
-    "jidctfst.c",
-    "jidctint.c",
-    "jmemmgr.c",
-    "jmemnobs.c",
-    "jquant1.c",
-    "jquant2.c",
-    "jutils.c",
-]
+# Description:
+#   The Independent JPEG Group's JPEG runtime library.
 
-HEADERS = [
-    "cderror.h",
-    "cdjpeg.h",
-    "jconfig.h",
-    "jdct.h",
-    "jerror.h",
-    "jinclude.h",
-    "jmemsys.h",
-    "jmorecfg.h",
-    "jpegint.h",
-    "jpeglib.h",
-    "jversion.h",
-    "transupp.h",
-]
-
-prefix_dir = "jpeg-9a"
-
-genrule(
-    name = "configure",
-    srcs = glob(
-        ["**/*"],
-        exclude = [prefix_dir + "/jconfig.h"],
-    ),
-    outs = [prefix_dir + "/jconfig.h"],
-    cmd = "pushd external/jpeg_archive/%s; workdir=$$(mktemp -d -t tmp.XXXXXXXXXX); cp -a * $$workdir; pushd $$workdir; ./configure; popd; popd; cp $$workdir/jconfig.h $(@D); rm -rf $$workdir;" % prefix_dir,
-)
+licenses(["notice"])  # custom notice-style license, see LICENSE
 
 cc_library(
     name = "jpeg",
-    srcs = [prefix_dir + "/" + source for source in SOURCES],
-    hdrs = glob(["**/*.h"]) + [":configure"],
-    includes = [prefix_dir],
+    srcs = [
+        "cderror.h",
+        "cdjpeg.h",
+        "jaricom.c",
+        "jcapimin.c",
+        "jcapistd.c",
+        "jcarith.c",
+        "jccoefct.c",
+        "jccolor.c",
+        "jcdctmgr.c",
+        "jchuff.c",
+        "jcinit.c",
+        "jcmainct.c",
+        "jcmarker.c",
+        "jcmaster.c",
+        "jcomapi.c",
+        "jconfig.h",
+        "jcparam.c",
+        "jcprepct.c",
+        "jcsample.c",
+        "jctrans.c",
+        "jdapimin.c",
+        "jdapistd.c",
+        "jdarith.c",
+        "jdatadst.c",
+        "jdatasrc.c",
+        "jdcoefct.c",
+        "jdcolor.c",
+        "jdct.h",
+        "jddctmgr.c",
+        "jdhuff.c",
+        "jdinput.c",
+        "jdmainct.c",
+        "jdmarker.c",
+        "jdmaster.c",
+        "jdmerge.c",
+        "jdpostct.c",
+        "jdsample.c",
+        "jdtrans.c",
+        "jerror.c",
+        "jfdctflt.c",
+        "jfdctfst.c",
+        "jfdctint.c",
+        "jidctflt.c",
+        "jidctfst.c",
+        "jidctint.c",
+        "jinclude.h",
+        "jmemmgr.c",
+        "jmemnobs.c",
+        "jmemsys.h",
+        "jmorecfg.h",
+        "jquant1.c",
+        "jquant2.c",
+        "jutils.c",
+        "jversion.h",
+        "transupp.h",
+    ],
+    hdrs = [
+        "jerror.h",
+        "jpegint.h",
+        "jpeglib.h",
+    ],
+    includes = ["."],
     visibility = ["//visibility:public"],
 )
+
+genrule(
+    name = "configure",
+    outs = ["jconfig.h"],
+    cmd = "cat <<EOF >$@\n" +
+          "#define HAVE_PROTOTYPES 1\n" +
+          "#define HAVE_UNSIGNED_CHAR 1\n" +
+          "#define HAVE_UNSIGNED_SHORT 1\n" +
+          "#define HAVE_STDDEF_H 1\n" +
+          "#define HAVE_STDLIB_H 1\n" +
+          "#ifdef WIN32\n" +
+          "#define INLINE __inline\n" +
+          "#else\n" +
+          "#define INLINE __inline__\n" +
+          "#endif\n" +
+          "EOF\n",
+)
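For reference, the heredoc in the new configure genrule expands to the following
jconfig.h (derived directly from the cmd string above; nothing here is beyond
what the genrule writes):

/* jconfig.h as emitted by the ":configure" genrule. */
#define HAVE_PROTOTYPES 1
#define HAVE_UNSIGNED_CHAR 1
#define HAVE_UNSIGNED_SHORT 1
#define HAVE_STDDEF_H 1
#define HAVE_STDLIB_H 1
#ifdef WIN32
#define INLINE __inline
#else
#define INLINE __inline__
#endif

This replaces running the upstream ./configure script in a scratch directory,
which is what makes the target usable from the cmake and makefile builds.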
jsoncpp.BUILD

@@ -1,34 +1,31 @@
-licenses(["notice"])  # MIT
-
-JSON_HEADERS = [
-    "include/json/assertions.h",
-    "include/json/autolink.h",
-    "include/json/config.h",
-    "include/json/features.h",
-    "include/json/forwards.h",
-    "include/json/json.h",
-    "src/lib_json/json_batchallocator.h",
-    "include/json/reader.h",
-    "include/json/value.h",
-    "include/json/writer.h",
-]
-
-JSON_SOURCES = [
-    "src/lib_json/json_reader.cpp",
-    "src/lib_json/json_value.cpp",
-    "src/lib_json/json_writer.cpp",
-    "src/lib_json/json_tool.h",
-]
-
-INLINE_SOURCES = [
-    "src/lib_json/json_valueiterator.inl",
-]
+licenses(["unencumbered"])  # Public Domain or MIT
 
 cc_library(
     name = "jsoncpp",
-    srcs = JSON_SOURCES,
-    hdrs = JSON_HEADERS,
+    srcs = [
+        "include/json/assertions.h",
+        "src/lib_json/json_batchallocator.h",
+        "src/lib_json/json_reader.cpp",
+        "src/lib_json/json_tool.h",
+        "src/lib_json/json_value.cpp",
+        "src/lib_json/json_writer.cpp",
+    ],
+    hdrs = [
+        "include/json/autolink.h",
+        "include/json/config.h",
+        "include/json/features.h",
+        "include/json/forwards.h",
+        "include/json/json.h",
+        "include/json/reader.h",
+        "include/json/value.h",
+        "include/json/writer.h",
+    ],
     includes = ["include"],
-    textual_hdrs = INLINE_SOURCES,
     visibility = ["//visibility:public"],
+    deps = [":private"],
 )
+
+cc_library(
+    name = "private",
+    textual_hdrs = ["src/lib_json/json_valueiterator.inl"],
+)
linenoise.BUILD (new file, 13 lines)

@@ -0,0 +1,13 @@
licenses(["notice"])  # 2-clause BSD

exports_files(["LICENSE"])

package(
    default_visibility = ["//visibility:public"],
)

cc_library(
    name = "linenoise",
    srcs = ["linenoise.c"],
    hdrs = ["linenoise.h"],
)
nanopb.BUILD (28 lines changed)

@@ -1,19 +1,21 @@
-SOURCES = [
-    "pb_common.c",
-    "pb_decode.c",
-    "pb_encode.c",
-]
+# Description:
+#   Nanopb, a tiny ANSI C protobuf implementation for use on embedded devices.
 
-HEADERS = [
-    "pb.h",
-    "pb_common.h",
-    "pb_decode.h",
-    "pb_encode.h",
-]
+licenses(["notice"])  # zlib license
 
 cc_library(
     name = "nanopb",
-    srcs = SOURCES,
-    hdrs = HEADERS,
+    srcs = [
+        "pb_common.c",
+        "pb_decode.c",
+        "pb_encode.c",
+    ],
+    hdrs = [
+        "pb.h",
+        "pb_common.h",
+        "pb_decode.h",
+        "pb_encode.h",
+    ],
     includes = ["."],
     visibility = ["//visibility:public"],
 )
png.BUILD (61 lines changed)

@@ -1,40 +1,33 @@
-package(default_visibility = ["//visibility:public"])
+# Description:
+#   libpng is the official PNG reference library.
 
-prefix_dir = "libpng-1.2.53"
-
-PNG_SOURCES = [
-    "png.c",
-    "pngerror.c",
-    "pngget.c",
-    "pngmem.c",
-    "pngpread.c",
-    "pngread.c",
-    "pngrio.c",
-    "pngrtran.c",
-    "pngrutil.c",
-    "pngset.c",
-    "pngtrans.c",
-    "pngwio.c",
-    "pngwrite.c",
-    "pngwtran.c",
-    "pngwutil.c",
-]
-
-genrule(
-    name = "configure",
-    srcs = glob(
-        ["**/*"],
-        exclude = [prefix_dir + "/config.h"],
-    ),
-    outs = [prefix_dir + "/config.h"],
-    cmd = "pushd external/png_archive/%s; workdir=$$(mktemp -d -t tmp.XXXXXXXXXX); cp -a * $$workdir; pushd $$workdir; ./configure --enable-shared=no --with-pic=no; popd; popd; cp $$workdir/config.h $(@D); rm -rf $$workdir;" % prefix_dir,
-)
+licenses(["notice"])  # BSD/MIT-like license
 
 cc_library(
     name = "png",
-    srcs = [prefix_dir + "/" + source for source in PNG_SOURCES],
-    hdrs = glob(["**/*.h"]) + [":configure"],
-    includes = [prefix_dir],
-    linkopts = ["-lz"],
+    srcs = [
+        "png.c",
+        "pngerror.c",
+        "pngget.c",
+        "pngmem.c",
+        "pngpread.c",
+        "pngread.c",
+        "pngrio.c",
+        "pngrtran.c",
+        "pngrutil.c",
+        "pngset.c",
+        "pngtrans.c",
+        "pngwio.c",
+        "pngwrite.c",
+        "pngwtran.c",
+        "pngwutil.c",
+    ],
+    hdrs = [
+        "png.h",
+        "pngconf.h",
+    ],
+    includes = ["."],
+    linkopts = ["-lm"],
+    visibility = ["//visibility:public"],
+    deps = ["@zlib_archive//:zlib"],
 )
six.BUILD (13 lines changed)

@@ -1,13 +1,12 @@
-genrule(
-    name = "copy_six",
-    srcs = ["six-1.10.0/six.py"],
-    outs = ["six.py"],
-    cmd = "cp $< $(@)",
-)
+# Description:
+#   Six provides simple utilities for wrapping over differences between Python 2
+#   and Python 3.
+
+licenses(["notice"])  # MIT
 
 py_library(
     name = "six",
     srcs = ["six.py"],
-    visibility = ["//visibility:public"],
     srcs_version = "PY2AND3",
+    visibility = ["//visibility:public"],
 )
tensorflow/BUILD

@@ -93,10 +93,12 @@ filegroup(
         ":all_files",
         "//tensorflow/c:all_files",
         "//tensorflow/cc:all_files",
+        "//tensorflow/cc/saved_model:all_files",
         "//tensorflow/contrib:all_files",
         "//tensorflow/contrib/android:all_files",
         "//tensorflow/contrib/bayesflow:all_files",
         "//tensorflow/contrib/copy_graph:all_files",
         "//tensorflow/contrib/crf:all_files",
+        "//tensorflow/contrib/cudnn_rnn:all_files",
         "//tensorflow/contrib/distributions:all_files",
         "//tensorflow/contrib/factorization:all_files",
@@ -129,7 +131,11 @@ filegroup(
         "//tensorflow/contrib/slim/python/slim/nets:all_files",
         "//tensorflow/contrib/tensor_forest:all_files",
         "//tensorflow/contrib/tensor_forest/hybrid:all_files",
+        "//tensorflow/contrib/tensorboard:all_files",
         "//tensorflow/contrib/testing:all_files",
+        "//tensorflow/contrib/tfprof/python/tools/tfprof:all_files",
+        "//tensorflow/contrib/tfprof/tools/tfprof:all_files",
+        "//tensorflow/contrib/tfprof/tools/tfprof/internal:all_files",
         "//tensorflow/contrib/training:all_files",
         "//tensorflow/contrib/util:all_files",
         "//tensorflow/core:all_files",
@@ -142,6 +148,7 @@ filegroup(
         "//tensorflow/core/platform/default/build_config:all_files",
         "//tensorflow/core/platform/hadoop:all_files",
         "//tensorflow/core/util/ctc:all_files",
+        "//tensorflow/core/util/tensor_bundle:all_files",
         "//tensorflow/examples/android:all_files",
         "//tensorflow/examples/how_tos/reading_data:all_files",
         "//tensorflow/examples/image_retraining:all_files",
@@ -166,6 +173,7 @@ filegroup(
         "//tensorflow/python/debug:all_files",
         "//tensorflow/python/kernel_tests:all_files",
         "//tensorflow/python/saved_model:all_files",
+        "//tensorflow/python/saved_model/example:all_files",
         "//tensorflow/python/tools:all_files",
         "//tensorflow/tensorboard:all_files",
         "//tensorflow/tensorboard/app:all_files",
@@ -176,7 +184,6 @@ filegroup(
         "//tensorflow/tensorboard/lib:all_files",
         "//tensorflow/tensorboard/lib/python:all_files",
         "//tensorflow/tensorboard/scripts:all_files",
-        "//tensorflow/third_party/hadoop:all_files",
         "//tensorflow/tools/dist_test/server:all_files",
         "//tensorflow/tools/docker:all_files",
         "//tensorflow/tools/docker/notebooks:all_files",
@@ -185,6 +192,7 @@ filegroup(
         "//tensorflow/tools/proto_text:all_files",
         "//tensorflow/tools/test:all_files",
         "//tensorflow/user_ops:all_files",
+        "//third_party/hadoop:all_files",
     ],
     visibility = [":__subpackages__"],
 )
tensorflow/c/c_api.cc

@@ -27,6 +27,9 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/node_builder.h"
 #include "tensorflow/core/lib/core/coding.h"
 #include "tensorflow/core/lib/core/errors.h"
@@ -1574,6 +1575,40 @@ void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def,
   status->status = MessageToBuffer(def, output_graph_def);
 }
 
+struct TF_ImportGraphDefOptions {
+  tensorflow::ImportGraphDefOptions opts;
+};
+
+TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() {
+  return new TF_ImportGraphDefOptions;
+}
+void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts) {
+  delete opts;
+}
+void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
+                                       const char* prefix) {
+  opts->opts.prefix = prefix;
+}
+
+void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
+                            const TF_ImportGraphDefOptions* opts,
+                            TF_Status* status) {
+  GraphDef def;
+  if (!def.ParseFromArray(graph_def->data, graph_def->length)) {
+    status->status = InvalidArgument("Invalid GraphDef");
+    return;
+  }
+  mutex_lock l(graph->mu);
+  const int last_node_id = graph->graph.num_node_ids();
+  status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
+                                              &graph->refiner);
+  if (!status->status.ok()) return;
+  for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
+    auto* node = graph->graph.FindNodeId(i);
+    if (node != nullptr) graph->name_map[node->name()] = node;
+  }
+}
+
 // TF_SessionWithGraph functions ----------------------------------------------
 
 TF_SessionWithGraph* TF_NewSessionWithGraph(TF_Graph* graph,
tensorflow/c/c_api.h

@@ -739,20 +739,36 @@ extern TF_Operation* TF_GraphOperationByName(TF_Graph* graph,
 // }
 extern TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos);
 
-// Note: The following two functions may fail on very large protos in the
-// future.
-
 // Write out a serialized representation of `graph` (as a GraphDef protocol
 // message) to `output_graph_def`.
+//
+// May fail on very large graphs in the future.
 extern void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def,
                                TF_Status* status);
 
+// TF_ImportGraphDefOptions holds options that can be passed to
+// TF_GraphImportGraphDef.
+typedef struct TF_ImportGraphDefOptions TF_ImportGraphDefOptions;
+
+extern TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions();
+extern void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts);
+
+// Set the prefix to be prepended to the names of nodes in `graph_def` that will
+// be imported into `graph`.
+extern void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
+                                              const char* prefix);
+
+// Import the graph serialized in `graph_def` into `graph`.
+extern void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
+                                   const TF_ImportGraphDefOptions* options,
+                                   TF_Status* status);
+
+// Note: The following function may fail on very large protos in the future.
+
 extern void TF_OperationToNodeDef(TF_Operation* oper,
                                   TF_Buffer* output_node_def,
                                   TF_Status* status);
 
 // TODO(cwhipkey): Query shape for operation outputs.
 
-// TODO(ashankar): Import GraphDef into TF_Graph.
-
 // TODO(andydavis): Function to add gradients to a graph.
 
 // TODO(josh11b): Register OpDef, available to all operations added
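The test added below round-trips a GraphDef in memory. As a complement, here is
a minimal sketch that imports a GraphDef read from disk using only the
functions declared above plus the public TF_Buffer struct fields; the file name
"graph.pb", the bare main, and the error handling are illustrative assumptions,
not part of this commit:

#include <cstdio>
#include <fstream>
#include <sstream>
#include <string>

#include "tensorflow/c/c_api.h"

int main() {
  // Read a serialized GraphDef from disk (the path is a placeholder).
  std::ifstream in("graph.pb", std::ios::binary);
  std::ostringstream bytes;
  bytes << in.rdbuf();
  const std::string graph_def = bytes.str();

  // Wrap the bytes in a TF_Buffer. The buffer merely points at our string;
  // no deallocator is set, so TF_DeleteBuffer won't free the storage.
  TF_Buffer* buf = TF_NewBuffer();
  buf->data = graph_def.data();
  buf->length = graph_def.size();

  // Import into a fresh graph under the "imported" name prefix.
  TF_Status* s = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();
  TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
  TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
  TF_GraphImportGraphDef(graph, buf, opts, s);
  if (TF_GetCode(s) != TF_OK) {
    std::fprintf(stderr, "import failed: %s\n", TF_Message(s));
  }

  TF_DeleteImportGraphDefOptions(opts);
  TF_DeleteBuffer(buf);
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
  return 0;
}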
tensorflow/c/c_api_test.cc

@@ -45,7 +45,7 @@ TF_Tensor* TF_Tensor_EncodeStrings(const Tensor& src);
 
 namespace {
 
-TEST(CApi, Status) {
+TEST(CAPI, Status) {
   TF_Status* s = TF_NewStatus();
   EXPECT_EQ(TF_OK, TF_GetCode(s));
   EXPECT_EQ(string(), TF_Message(s));
@@ -60,7 +60,7 @@ static void Deallocator(void* data, size_t, void* arg) {
   *reinterpret_cast<bool*>(arg) = true;
 }
 
-TEST(CApi, Tensor) {
+TEST(CAPI, Tensor) {
   const int num_bytes = 6 * sizeof(float);
   float* values =
       reinterpret_cast<float*>(tensorflow::cpu_allocator()->AllocateRaw(
@@ -80,7 +80,7 @@ TEST(CAPI, Tensor) {
   EXPECT_TRUE(deallocator_called);
 }
 
-TEST(CApi, AllocateTensor) {
+TEST(CAPI, AllocateTensor) {
   const int num_bytes = 6 * sizeof(float);
   int64_t dims[] = {2, 3};
   TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, num_bytes);
@@ -92,7 +92,7 @@ TEST(CAPI, AllocateTensor) {
   TF_DeleteTensor(t);
 }
 
-TEST(CApi, LibraryLoadFunctions) {
+TEST(CAPI, LibraryLoadFunctions) {
   // Load the library.
   TF_Status* status = TF_NewStatus();
   TF_Library* lib =
@@ -139,7 +139,7 @@ static void TestEncodeDecode(int line, const std::vector<string>& data) {
   }
 }
 
-TEST(CApi, TensorEncodeDecodeStrings) {
+TEST(CAPI, TensorEncodeDecodeStrings) {
   TestEncodeDecode(__LINE__, {});
   TestEncodeDecode(__LINE__, {"hello"});
   TestEncodeDecode(__LINE__,
@@ -149,12 +149,12 @@ TEST(CAPI, TensorEncodeDecodeStrings) {
   TestEncodeDecode(__LINE__, {"small", big, "small2"});
 }
 
-TEST(CApi, SessionOptions) {
+TEST(CAPI, SessionOptions) {
   TF_SessionOptions* opt = TF_NewSessionOptions();
   TF_DeleteSessionOptions(opt);
 }
 
-TEST(CApi, SessionWithRunMetadata) {
+TEST(CAPI, SessionWithRunMetadata) {
   TF_Status* s = TF_NewStatus();
   TF_SessionOptions* opt = TF_NewSessionOptions();
   TF_Session* session = TF_NewSession(opt, s);
@@ -230,7 +230,7 @@ TEST(CAPI, StatusEnum) {
   EXPECT_EQ(TF_DATA_LOSS, static_cast<TF_Code>(tensorflow::error::DATA_LOSS));
 }
 
-TEST(CApi, GetAllOpList) {
+TEST(CAPI, GetAllOpList) {
   TF_Buffer* buf = TF_GetAllOpList();
   tensorflow::OpList op_list;
   EXPECT_TRUE(op_list.ParseFromArray(buf->data, buf->length));
@@ -646,6 +646,47 @@ TEST(CAPI, Graph) {
   TF_DeleteStatus(s);
 }
 
+TEST(CAPI, ImportGraphDef) {
+  TF_Status* s = TF_NewStatus();
+  TF_Graph* graph = TF_NewGraph();
+
+  // Create a graph with two nodes: x and 3
+  Placeholder(graph, s);
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr);
+  ScalarConst(3, graph, s);
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr);
+
+  // Export to a GraphDef
+  TF_Buffer* graph_def = TF_NewBuffer();
+  TF_GraphToGraphDef(graph, graph_def, s);
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+
+  // Import it again, with a prefix, in a fresh graph.
+  TF_DeleteGraph(graph);
+  graph = TF_NewGraph();
+  TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
+  TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
+  TF_GraphImportGraphDef(graph, graph_def, opts, s);
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+
+  TF_DeleteImportGraphDefOptions(opts);
+  TF_DeleteBuffer(graph_def);
+
+  TF_Operation* scalar = TF_GraphOperationByName(graph, "imported/scalar");
+  TF_Operation* feed = TF_GraphOperationByName(graph, "imported/feed");
+  ASSERT_TRUE(scalar != nullptr);
+  ASSERT_TRUE(feed != nullptr);
+
+  // Can add nodes to the imported graph without trouble.
+  Add(feed, scalar, graph, s);
+  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+
+  TF_DeleteGraph(graph);
+  TF_DeleteStatus(s);
+}
+
 class CSessionWithGraph {
  public:
   CSessionWithGraph(TF_Graph* graph, TF_Status* s) {
tensorflow/cc/BUILD

@@ -48,6 +48,42 @@ tf_cc_test(
     ],
 )
 
+cc_library(
+    name = "gradient_checker",
+    srcs = ["framework/gradient_checker.cc"],
+    hdrs = ["framework/gradient_checker.h"],
+    deps = [
+        ":cc_ops",
+        ":client_session",
+        ":grad_op_registry",
+        ":gradients",
+        ":ops",
+        ":scope",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_cc_test(
+    name = "framework_gradient_checker_test",
+    srcs = ["framework/gradient_checker_test.cc"],
+    deps = [
+        ":cc_ops",
+        ":grad_op_registry",
+        ":grad_ops",
+        ":gradient_checker",
+        ":testutil",
+        "//tensorflow/core:all_kernels",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
 cc_library(
     name = "grad_ops",
     deps = [
tensorflow/cc/framework/gradient_checker.cc (new file, 165 lines)

@@ -0,0 +1,165 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/framework/gradient_checker.h"

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/standard_ops.h"

namespace tensorflow {
using namespace ops;  // NOLINT(build/namespaces)

namespace {

// TODO(andydavis) Support returning relative error (as opposed to max error)
// between theoretical and numerical jacobians:
//   fabs(jac_t - jac_n) / max(fabs(jac_t), fabs(jac_n))

// TODO(andydavis) Vectorize and/or multi-thread Jacobian computations if
// performance becomes an issue.

template <typename T>
Status ComputeTheoreticalJacobianTranspose(
    const Scope& scope, const ops::Output& x, const TensorShape& x_shape,
    const Tensor& x_data, const ops::Output& y, const TensorShape& y_shape,
    Tensor* jacobian_t) {
  // Call AddSymbolicGradients to get 'dx' (we will feed 'dy').
  auto dy = Cast(scope, Const(scope, 1.0, y_shape), x.type());
  std::vector<ops::Output> outputs;
  TF_RETURN_IF_ERROR(AddSymbolicGradients(scope, {y}, {x}, {dy}, &outputs));
  auto dx = outputs[0];

  // Initialize 'dy_data' to zeros.
  Tensor dy_data(y.type(), y_shape);
  auto dy_data_flat = dy_data.flat<T>();
  dy_data_flat.setZero();

  // Compute the theoretical Jacobian one row at a time by backproping '1.0'
  // for each element of 'dy', while holding all other elements of 'dy' at zero.
  ClientSession session(scope);
  std::vector<Tensor> dxout;
  const int64 x_size = x_shape.num_elements();
  const int64 dy_size = y_shape.num_elements();
  auto jacobian = jacobian_t->matrix<T>();
  for (int c = 0; c < dy_size; ++c) {
    dy_data_flat(c) = 1.0;

    TF_RETURN_IF_ERROR(session.Run({{x, x_data}, {dy, dy_data}}, {dx}, &dxout));

    auto dx_flat = dxout[0].flat<T>();
    for (int r = 0; r < x_size; ++r) {
      jacobian(r, c) = dx_flat(r);
    }

    dy_data_flat(c) = 0.0;
  }
  return Status::OK();
}

template <typename T>
Status ComputeNumericJacobianTranspose(const Scope& scope, const ops::Output& x,
                                       const TensorShape& x_shape,
                                       const ops::Output& y,
                                       const TensorShape& y_shape,
                                       const T delta, Tensor* x_data,
                                       Tensor* jacobian_t) {
  const int64 x_size = x_shape.num_elements();
  const int64 y_size = y_shape.num_elements();
  auto x_data_flat = x_data->flat<T>();

  // Compute the numeric Jacobian one column at a time by perturbing each
  // element of 'x_data' (positively and negatively) by 'delta', and
  // updating the jacobian with the centered difference.
  ClientSession session(scope);
  std::vector<Tensor> yout;
  auto jacobian = jacobian_t->matrix<T>();
  for (int r = 0; r < x_size; ++r) {
    // Store current value of 'x' at 'r'.
    T v = x_data_flat(r);
    // Evaluate at positive delta.
    x_data_flat(r) = v + delta;
    TF_RETURN_IF_ERROR(session.Run({{x, *x_data}}, {y}, &yout));
    Tensor y_pos = yout[0];
    // Evaluate at negative delta.
    x_data_flat(r) = v - delta;
    TF_RETURN_IF_ERROR(session.Run({{x, *x_data}}, {y}, &yout));
    Tensor y_neg = yout[0];
    // Compute element-wise centered difference and store in Jacobian.
    auto y_pos_flat = y_pos.flat<T>();
    auto y_neg_flat = y_neg.flat<T>();
    const T scale = 2 * delta;
    for (int c = 0; c < y_size; ++c) {
      jacobian(r, c) = (y_pos_flat(c) - y_neg_flat(c)) / scale;
    }
    // Restore pre-perturbation value.
    x_data_flat(r) = v;
  }
  return Status::OK();
}

}  // namespace

template <typename T>
Status ComputeGradientError(const Scope& scope, const ops::Output& x,
                            const TensorShape& x_shape, const ops::Output& y,
                            const TensorShape& y_shape, T* max_error) {
  const int64 x_size = x_shape.num_elements();
  const int64 y_size = y_shape.num_elements();

  // Initialize 'x_data' to random values.
  Tensor x_data(x.type(), x_shape);
  auto x_data_flat = x_data.flat<T>();
  x_data_flat.setRandom();

  // Initialize theoretical Jacobian to zeros.
  Tensor jacobian_t(x.type(), {x_size, y_size});
  auto jacobian_t_flat = jacobian_t.flat<T>();
  jacobian_t_flat.setZero();

  // Compute theoretical Jacobian.
  TF_RETURN_IF_ERROR(ComputeTheoreticalJacobianTranspose<T>(
      scope, x, x_shape, x_data, y, y_shape, &jacobian_t));

  // Initialize numeric Jacobian to zeros.
  Tensor jacobian_n(x.type(), {x_size, y_size});
  auto jacobian_n_flat = jacobian_n.flat<T>();
  jacobian_n_flat.setZero();

  // Compute numeric Jacobian.
  TF_RETURN_IF_ERROR(ComputeNumericJacobianTranspose<T>(
      scope, x, x_shape, y, y_shape, 1e-3, &x_data, &jacobian_n));

  // Compute the maximum error between theoretical and numeric Jacobians.
  *max_error = 0.0;
  auto jac_t = jacobian_t.matrix<T>();
  auto jac_n = jacobian_n.matrix<T>();
  for (int r = 0; r < x_size; ++r) {
    for (int c = 0; c < y_size; ++c) {
      *max_error = std::max(*max_error, std::fabs(jac_t(r, c) - jac_n(r, c)));
    }
  }
  return Status::OK();
}

#define INSTANTIATE_GRAD_ERR_TYPE(T)                                        \
  template Status ComputeGradientError<T>(                                  \
      const Scope& scope, const ops::Output& x, const TensorShape& x_shape, \
      const ops::Output& y, const TensorShape& y_shape, T* max_error)

INSTANTIATE_GRAD_ERR_TYPE(float);
INSTANTIATE_GRAD_ERR_TYPE(double);

}  // namespace tensorflow
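Restated in math (this is just the computation the two routines above perform,
with x flattened to m elements, y to n elements, e_r the r-th standard basis
vector, and delta = 1e-3 as passed by ComputeGradientError):

\[
J^{t}_{rc} = \frac{\partial y_c}{\partial x_r}
\quad\text{(via backprop of } e_c \text{ through } dy\text{)},\qquad
J^{n}_{rc} = \frac{y_c(x + \delta e_r) - y_c(x - \delta e_r)}{2\delta},
\]
\[
\texttt{max\_error} = \max_{r,c}\,\bigl|J^{t}_{rc} - J^{n}_{rc}\bigr|.
\]

Both matrices are stored transposed (rows indexed by x, columns by y), which is
why the theoretical pass fills one column per element of dy and the numeric
pass fills one row per perturbed element of x.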
tensorflow/cc/framework/gradient_checker.h (new file, 35 lines)

@@ -0,0 +1,35 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_
#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {

// Returns in 'max_error' the maximum element-wise error for dy/dx between the
// computed and numeric Jacobian matrices where 'x' and 'y' are tensors.
// This function adds operations to the graph associated with 'scope'.
template <typename T>
Status ComputeGradientError(const Scope& scope, const ops::Output& x,
                            const TensorShape& x_shape, const ops::Output& y,
                            const TensorShape& y_shape, T* max_error);

}  // namespace tensorflow

#endif  // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_
tensorflow/cc/framework/gradient_checker_test.cc (new file, 70 lines)

@@ -0,0 +1,70 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/framework/gradient_checker.h"
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/equal_graph_def.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
using namespace ops;  // NOLINT(build/namespaces)

namespace {

TEST(GradientCheckerTest, BasicFloat) {
  Scope scope = Scope::NewRootScope();
  TensorShape shape({2, 4, 3});
  auto x = Placeholder(scope, DT_FLOAT, Placeholder::Shape(shape));
  auto y = Square(scope, x);
  float max_error;
  TF_ASSERT_OK(
      ComputeGradientError<float>(scope, x, shape, y, shape, &max_error));
  EXPECT_LT(max_error, 1e-4);
}

TEST(GradientCheckerTest, BasicDouble) {
  Scope scope = Scope::NewRootScope();
  TensorShape shape({2, 4, 3});
  auto x = Placeholder(scope, DT_DOUBLE, Placeholder::Shape(shape));
  auto y = Square(scope, x);
  double max_error;
  TF_ASSERT_OK(
      ComputeGradientError<double>(scope, x, shape, y, shape, &max_error));
  EXPECT_LT(max_error, 1e-10);
}

TEST(GradientCheckerTest, MatMulGrad) {
  Scope scope = Scope::NewRootScope();

  TensorShape x_shape({4, 3});
  TensorShape y_shape({3, 2});
  TensorShape z_shape({4, 2});

  auto x = Placeholder(scope, DT_DOUBLE, Placeholder::Shape(x_shape));
  auto y = Const(scope, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, y_shape);
  auto z = MatMul(scope, x, y);
  double max_error;
  TF_ASSERT_OK(
      ComputeGradientError<double>(scope, x, x_shape, z, z_shape, &max_error));
  EXPECT_LT(max_error, 1e-10);
}

}  // namespace
}  // namespace tensorflow
tensorflow/cc/saved_model/BUILD (new file, 68 lines)

@@ -0,0 +1,68 @@
# Description:
#   TensorFlow SavedModel.

package(
    default_visibility = ["//visibility:public"],
)

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

load("//tensorflow:tensorflow.bzl", "tf_cc_test")

cc_library(
    name = "constants",
    hdrs = ["constants.h"],
)

cc_library(
    name = "loader",
    srcs = ["loader.cc"],
    hdrs = ["loader.h"],
    deps = [
        ":constants",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:tensorflow",
    ],
)

tf_cc_test(
    name = "loader_test",
    srcs = ["loader_test.cc"],
    data = [
        ":saved_model_half_plus_two",
    ],
    linkstatic = 1,
    deps = [
        ":constants",
        ":loader",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)

filegroup(
    name = "saved_model_half_plus_two",
    srcs = glob(["testdata/half_plus_two/*"]),
)

# -----------------------------------------------------------------------------
# Google-internal targets.

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
tensorflow/cc/saved_model/constants.h (new file, 36 lines)

@@ -0,0 +1,36 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_
#define THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_

namespace tensorflow {

// SavedModel proto filename.
constexpr char kSavedModelFilenamePb[] = "saved_model.pb";

// SavedModel text format proto filename.
constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt";

// SavedModel variables filename.
constexpr char kSavedModelVariablesFilename[] = "saved_model_variables";

// Commonly used tags.
constexpr char kSavedModelTagServe[] = "serve";
constexpr char kSavedModelTagTrain[] = "train";

}  // namespace tensorflow

#endif  // THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_
tensorflow/cc/saved_model/loader.cc (new file, 126 lines)

@@ -0,0 +1,126 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/saved_model/loader.h"

#include <unordered_set>

#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace {

Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) {
  const string saved_model_path =
      io::JoinPath(export_dir, kSavedModelFilenamePb);
  return ReadBinaryProto(Env::Default(), saved_model_path, saved_model_proto);
}

Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto,
                              const std::unordered_set<string>& tags,
                              MetaGraphDef* meta_graph_def_to_load) {
  for (const MetaGraphDef& meta_graph_def : saved_model_proto.meta_graphs()) {
    // Get tags from the meta_graph_def.
    std::unordered_set<string> graph_tags;
    for (const string& tag : meta_graph_def.meta_info_def().tags()) {
      graph_tags.insert(tag);
    }
    // Match with the set of tags provided.
    if (graph_tags == tags) {
      *meta_graph_def_to_load = meta_graph_def;
      return Status::OK();
    }
  }
  return Status(error::Code::NOT_FOUND,
                "Could not find meta graph def matching supplied tags.");
}

Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def,
                                const SessionOptions& session_options,
                                std::unique_ptr<Session>* session) {
  session->reset(NewSession(session_options));
  return (*session)->Create(meta_graph_def.graph_def());
}

Status Restore(const RunOptions& run_options, const string& export_dir,
               const StringPiece restore_op_name,
               const StringPiece variable_filename_const_op_name,
               Session* session) {
  const string variables_path =
      io::JoinPath(export_dir, kSavedModelVariablesFilename);
  if (!Env::Default()->FileExists(variables_path)) {
    return Status(error::Code::NOT_FOUND,
                  "Could not find checkpointed variables.");
  }

  // Add variables to the graph.
  Tensor variables_path_tensor(DT_STRING, TensorShape({}));
  variables_path_tensor.scalar<string>()() = variables_path;

  std::vector<std::pair<string, Tensor>> inputs = {
      {variable_filename_const_op_name.ToString(), variables_path_tensor}};

  RunMetadata run_metadata;
  return session->Run(run_options, inputs, {}, {restore_op_name.ToString()},
                      nullptr /* outputs */, &run_metadata);
}

}  // namespace

Status LoadSavedModel(const string& export_dir,
                      const std::unordered_set<string>& tags,
                      const SessionOptions& session_options,
                      const RunOptions& run_options,
                      SavedModelBundle* const bundle) {
  if (!MaybeSavedModelDirectory(export_dir)) {
    return Status(error::Code::NOT_FOUND,
                  "SavedModel not found in export directory: " + export_dir);
  }
  LOG(INFO) << "Loading SavedModel from: " << export_dir;

  SavedModel saved_model_proto;
  TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto));

  TF_RETURN_IF_ERROR(
      FindMetaGraphDefToLoad(saved_model_proto, tags, &bundle->meta_graph_def));

  TF_RETURN_IF_ERROR(LoadMetaGraphIntoSession(
      bundle->meta_graph_def, session_options, &bundle->session));

  TF_RETURN_IF_ERROR(
      Restore(run_options, export_dir,
              bundle->meta_graph_def.saver_def().restore_op_name(),
              bundle->meta_graph_def.saver_def().filename_tensor_name(),
              bundle->session.get()));

  LOG(INFO) << "Done loading SavedModel.";
  return Status::OK();
}

bool MaybeSavedModelDirectory(const string& export_dir) {
  const string saved_model_pb_path =
      io::JoinPath(export_dir, kSavedModelFilenamePb);
  const string saved_model_pbtxt_path =
      io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
  return Env::Default()->FileExists(saved_model_pb_path) ||
         Env::Default()->FileExists(saved_model_pbtxt_path);
}

}  // namespace tensorflow
55
tensorflow/cc/saved_model/loader.h
Normal file
@ -0,0 +1,55 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// SavedModel loading functions and SavedModelBundle struct.

#ifndef THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_LOADER_H_
#define THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_LOADER_H_

#include <memory>
#include <string>
#include <unordered_set>

#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {

// SavedModel representation once the SavedModel is loaded from storage.
struct SavedModelBundle {
  std::unique_ptr<Session> session;
  MetaGraphDef meta_graph_def;
};

// Loads a SavedModel from the specified export directory. The meta graph def
// to be loaded is identified by the supplied tags, corresponding exactly to
// the set of tags used at SavedModel build time. Returns a SavedModel bundle
// with a session and the requested meta graph def, if found.
Status LoadSavedModel(const string& export_dir,
                      const std::unordered_set<string>& tags,
                      const SessionOptions& session_options,
                      const RunOptions& run_options,
                      SavedModelBundle* const bundle);

// Checks whether the provided directory could contain a SavedModel. Note that
// the method does not load any data by itself. If the method returns `false`,
// the export directory definitely does not contain a SavedModel. If the method
// returns `true`, the export directory may contain a SavedModel but provides
// no guarantee that it can be loaded.
bool MaybeSavedModelDirectory(const string& export_dir);

}  // namespace tensorflow

#endif  // THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_LOADER_H_
129
tensorflow/cc/saved_model/loader_test.cc
Normal file
@ -0,0 +1,129 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/saved_model/loader.h"

#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

constexpr char kTestData[] = "cc/saved_model/testdata/half_plus_two";

class LoaderTest : public ::testing::Test {
 protected:
  LoaderTest() {}

  void CheckSavedModelBundle(const SavedModelBundle& bundle) {
    // Validate the half plus two behavior.
    Tensor input = test::AsTensor<float>({0, 1, 2, 3}, TensorShape({4, 1}));

    // Retrieve the regression signature from meta graph def.
    const auto signature_def_map = bundle.meta_graph_def.signature_def();
    const auto signature_def = signature_def_map.at("regression");

    const string input_name = signature_def.inputs().at("input").name();
    const string output_name = signature_def.outputs().at("output").name();

    std::vector<Tensor> outputs;
    TF_ASSERT_OK(bundle.session->Run({{input_name, input}}, {output_name}, {},
                                     &outputs));
    ASSERT_EQ(outputs.size(), 1);
    test::ExpectTensorEqual<float>(
        outputs[0],
        test::AsTensor<float>({2, 2.5, 3, 3.5}, TensorShape({4, 1})));
  }
};

TEST_F(LoaderTest, TagMatch) {
  SavedModelBundle bundle;
  SessionOptions session_options;
  RunOptions run_options;

  const string export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestData);
  TF_ASSERT_OK(LoadSavedModel(export_dir, {kSavedModelTagServe},
                              session_options, run_options, &bundle));
  CheckSavedModelBundle(bundle);
}

TEST_F(LoaderTest, NoTagMatch) {
  SavedModelBundle bundle;
  RunOptions run_options;
  SessionOptions session_options;

  const string export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestData);
  Status st = LoadSavedModel(export_dir, {"missing-tag"}, session_options,
                             run_options, &bundle);
  EXPECT_FALSE(st.ok());
  EXPECT_TRUE(
      StringPiece(st.error_message())
          .contains("Could not find meta graph def matching supplied tags."))
      << st.error_message();
}

TEST_F(LoaderTest, NoTagMatchMultiple) {
  SavedModelBundle bundle;
  RunOptions run_options;
  SessionOptions session_options;

  const string export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestData);
  Status st = LoadSavedModel(export_dir, {kSavedModelTagServe, "missing-tag"},
                             session_options, run_options, &bundle);
  EXPECT_FALSE(st.ok());
  EXPECT_TRUE(
      StringPiece(st.error_message())
          .contains("Could not find meta graph def matching supplied tags."))
      << st.error_message();
}

TEST_F(LoaderTest, InvalidExportPath) {
  SavedModelBundle bundle;
  RunOptions run_options;
  SessionOptions session_options;

  const string export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path");
  Status st = LoadSavedModel(export_dir, {kSavedModelTagServe},
                             session_options, run_options, &bundle);
  EXPECT_FALSE(st.ok());
}

TEST_F(LoaderTest, MaybeSavedModelDirectory) {
  // Valid SavedModel directory.
  const string export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestData);
  EXPECT_TRUE(MaybeSavedModelDirectory(export_dir));

  // Directory that does not exist.
  const string missing_export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path");
  EXPECT_FALSE(MaybeSavedModelDirectory(missing_export_dir));

  // Directory that exists but is an invalid SavedModel location.
  const string invalid_export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), "cc/saved_model/testdata");
  EXPECT_FALSE(MaybeSavedModelDirectory(invalid_export_dir));
}

}  // namespace
}  // namespace tensorflow
2
tensorflow/cc/saved_model/testdata/half_plus_two/checkpoint
vendored
Normal file
@ -0,0 +1,2 @@
model_checkpoint_path: "/tmp/saved_model/half_plus_two/saved_model_variables"
all_model_checkpoint_paths: "/tmp/saved_model/half_plus_two/saved_model_variables"
BIN
tensorflow/cc/saved_model/testdata/half_plus_two/saved_model.pb
vendored
Normal file
Binary file not shown.
BIN
tensorflow/cc/saved_model/testdata/half_plus_two/saved_model_variables
vendored
Normal file
Binary file not shown.
@ -15,6 +15,7 @@ py_library(
    deps = [
        "//tensorflow/contrib/bayesflow:bayesflow_py",
        "//tensorflow/contrib/copy_graph:copy_graph_py",
        "//tensorflow/contrib/crf:crf_py",
        "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py",
        "//tensorflow/contrib/distributions:distributions_py",
        "//tensorflow/contrib/factorization:factorization_py",
@ -35,6 +36,7 @@ py_library(
        "//tensorflow/contrib/slim:nets",
        "//tensorflow/contrib/tensor_forest:tensor_forest_py",
        "//tensorflow/contrib/tensor_forest/hybrid:ops_lib",
        "//tensorflow/contrib/tensorboard",
        "//tensorflow/contrib/testing:testing_py",
        "//tensorflow/contrib/training:training_py",
        "//tensorflow/contrib/util:util_py",
@ -21,6 +21,7 @@ from __future__ import print_function
# Add projects here, they will show up under tf.contrib.
from tensorflow.contrib import bayesflow
from tensorflow.contrib import copy_graph
from tensorflow.contrib import crf
from tensorflow.contrib import cudnn_rnn
from tensorflow.contrib import distributions
from tensorflow.contrib import factorization
@ -38,6 +39,7 @@ from tensorflow.contrib import quantization
from tensorflow.contrib import rnn
from tensorflow.contrib import slim
from tensorflow.contrib import tensor_forest
from tensorflow.contrib import tensorboard
from tensorflow.contrib import testing
from tensorflow.contrib import training
from tensorflow.contrib import util
@ -55,7 +55,7 @@ Log E_q[ f(Z) p(Z) / q(Z) ]
      C := Max[ Log[f(Z)] + Log[p(Z)] - Log[q(Z)] ].
  ```

  The maximum value of the exponentiated term will be 0.0, and the the expecation
  The maximum value of the exponentiated term will be 0.0, and the expectation
  can be evaluated in a stable manner.

  ## Ops
@ -252,9 +252,7 @@ def expectation(f, p, z=None, n=None, seed=None, name='expectation'):
  User supplies either `Tensor` of samples `z`, or number of samples to draw `n`

  Args:
    f: Callable mapping samples from `sampling_dist_q` to `Tensors` with
        shape broadcastable to `q.batch_shape`.
        For example, `f` works "just like" `sampling_dist_q.log_prob`.
    f: Callable mapping samples from `p` to `Tensors`.
    p: `tf.contrib.distributions.BaseDistribution`.
    z: `Tensor` of samples from `p`, produced by `p.sample_n`.
    n: Integer `Tensor`. Number of samples to generate if `z` is not provided.
@ -262,7 +260,36 @@ def expectation(f, p, z=None, n=None, seed=None, name='expectation'):
    name: A name to give this `Op`.

  Returns:
    A `Tensor` with same `dtype` as `p`, and shape equal to `p.batch_shape`.
    A `Tensor` with the same `dtype` as `p`.

  Example:

  ```python
  N_samples = 10000

  distributions = tf.contrib.distributions

  dist = distributions.Uniform([0.0, 0.0], [1.0, 2.0])
  elementwise_mean = lambda x: x
  mean_sum = lambda x: tf.reduce_sum(x, 1)

  estimate_elementwise_mean_tf = monte_carlo.expectation(elementwise_mean,
                                                         dist,
                                                         n=N_samples)
  estimate_mean_sum_tf = monte_carlo.expectation(mean_sum,
                                                 dist,
                                                 n=N_samples)

  with tf.Session() as sess:
    estimate_elementwise_mean, estimate_mean_sum = (
        sess.run([estimate_elementwise_mean_tf, estimate_mean_sum_tf]))
  print(estimate_elementwise_mean)
  >>> np.array([ 0.50018013  1.00097895], dtype=np.float32)
  print(estimate_mean_sum)
  >>> 1.49571

  ```

  """
  with ops.name_scope(name, values=[n, z]):
    z = _get_samples(p, z, n, seed)
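The `C := Max[...]` shift in the monte_carlo docs above is the standard log-sum-exp stabilization. A minimal NumPy sketch of the idea, independent of the TensorFlow code in this diff (the function name and values are made up for illustration):

```python
import numpy as np

def stable_log_mean_exp(values):
  # Shift by the max so the largest exponentiated term is exp(0) = 1.0;
  # nothing overflows, and the shift is added back in log space.
  c = np.max(values)
  return c + np.log(np.mean(np.exp(values - c)))

log_terms = np.array([1000.0, 1001.0, 1002.0])  # np.exp(1000.) alone overflows
print(stable_log_mean_exp(log_terms))           # ~1001.31
```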
@ -19,27 +19,19 @@ from __future__ import print_function

import fnmatch
import os
import platform
import re
import sys

from setuptools import find_packages, setup, Command, Extension
from setuptools import find_packages, setup, Command
from setuptools.command.install import install as InstallCommandBase
from setuptools.dist import Distribution

_VERSION = '0.10.0-cmake-experimental'

numpy_version = "1.8.2"
if platform.system() == "Darwin":
  # There are bugs with numpy pip installation on OS X prior to
  # 1.10.1, so on mac we require a higher version than on other
  # platforms.
  numpy_version = "1.10.1"

REQUIRED_PACKAGES = [
    'numpy >= %s' % numpy_version,
    'numpy >= 1.11.0',
    'six >= 1.10.0',
    'protobuf == 3.0.0b2',
    'protobuf == 3.0.0',
]

# python3 requires wheel 0.26
40
tensorflow/contrib/crf/BUILD
Normal file
@ -0,0 +1,40 @@
# Description:
# Contains classes to construct a CRF layer
# APIs here are meant to evolve over time.

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

package(default_visibility = ["//tensorflow:__subpackages__"])

load("//tensorflow:tensorflow.bzl", "cuda_py_tests")

py_library(
    name = "crf_py",
    srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
    srcs_version = "PY2AND3",
)

cuda_py_tests(
    name = "crf_test",
    srcs = ["python/kernel_tests/crf_test.py"],
    additional_deps = [
        ":crf_py",
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform_test",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
76
tensorflow/contrib/crf/README.md
Normal file
@ -0,0 +1,76 @@
# CRF

The CRF module implements a linear-chain CRF layer for learning to predict tag sequences. This variant of the CRF is factored into unary potentials for every element in the sequence and binary potentials for every transition between output tags.

### Usage

Below is an example of the API, which learns a CRF for some random data. The linear layer in the example can be replaced by any neural network.


```python
import numpy as np
import tensorflow as tf

# Data settings.
num_examples = 10
num_words = 20
num_features = 100
num_tags = 5

# Random features.
x = np.random.rand(num_examples, num_words, num_features).astype(np.float32)

# Random tag indices representing the gold sequence.
y = np.random.randint(num_tags, size=[num_examples, num_words]).astype(np.int32)

# All sequences in this example have the same length, but they can be variable in a real model.
sequence_lengths = np.full(num_examples, num_words - 1, dtype=np.int32)

# Train and evaluate the model.
with tf.Graph().as_default():
  with tf.Session() as session:
    # Add the data to the TensorFlow graph.
    x_t = tf.constant(x)
    y_t = tf.constant(y)
    sequence_lengths_t = tf.constant(sequence_lengths)

    # Compute unary scores from a linear layer.
    weights = tf.get_variable("weights", [num_features, num_tags])
    matricized_x_t = tf.reshape(x_t, [-1, num_features])
    matricized_unary_scores = tf.batch_matmul(matricized_x_t, weights)
    unary_scores = tf.reshape(matricized_unary_scores,
                              [num_examples, num_words, num_tags])

    # Compute the log-likelihood of the gold sequences and keep the transition
    # params for inference at test time.
    log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
        unary_scores, y_t, sequence_lengths_t)

    # Add a training op to tune the parameters.
    loss = tf.reduce_mean(-log_likelihood)
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

    # Train for a fixed number of iterations.
    session.run(tf.initialize_all_variables())
    for i in range(1000):
      tf_unary_scores, tf_transition_params, _ = session.run(
          [unary_scores, transition_params, train_op])
      if i % 100 == 0:
        correct_labels = 0
        total_labels = 0
        for tf_unary_scores_, y_, sequence_length_ in zip(tf_unary_scores, y,
                                                          sequence_lengths):
          # Remove padding from the scores and tag sequence.
          tf_unary_scores_ = tf_unary_scores_[:sequence_length_]
          y_ = y_[:sequence_length_]

          # Compute the highest scoring sequence.
          viterbi_sequence, _ = tf.contrib.crf.viterbi_decode(
              tf_unary_scores_, tf_transition_params)

          # Evaluate word-level accuracy.
          correct_labels += np.sum(np.equal(viterbi_sequence, y_))
          total_labels += sequence_length_
        accuracy = 100.0 * correct_labels / float(total_labels)
        print("Accuracy: %.2f%%" % accuracy)
```
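In the README example all sequences share one length, but with genuinely variable-length batches the padded positions have to be excluded from the scores. A minimal NumPy sketch of the masking idea — `lengths_to_masks` here is a hypothetical standalone version of the `_lengths_to_masks` helper that crf.py below implements with TensorFlow ops:

```python
import numpy as np

def lengths_to_masks(lengths, max_length):
  # Row i is 1.0 for the first lengths[i] positions and 0.0 afterwards.
  positions = np.arange(max_length)[np.newaxis, :]
  return (positions < np.asarray(lengths)[:, np.newaxis]).astype(np.float32)

print(lengths_to_masks([4, 1, 3], 5))
# [[1. 1. 1. 1. 0.]
#  [1. 0. 0. 0. 0.]
#  [1. 1. 1. 0. 0.]]
```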
39
tensorflow/contrib/crf/__init__.py
Normal file
@ -0,0 +1,39 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Linear-chain CRF layer.

## This package provides functions for building a linear-chain CRF layer.

@@crf_sequence_score
@@crf_log_norm
@@crf_log_likelihood
@@crf_unary_score
@@crf_binary_score
@@CrfForwardRnnCell
@@viterbi_decode
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.crf.python.ops.crf import _lengths_to_masks
from tensorflow.contrib.crf.python.ops.crf import crf_binary_score
from tensorflow.contrib.crf.python.ops.crf import crf_log_likelihood
from tensorflow.contrib.crf.python.ops.crf import crf_log_norm
from tensorflow.contrib.crf.python.ops.crf import crf_sequence_score
from tensorflow.contrib.crf.python.ops.crf import crf_unary_score
from tensorflow.contrib.crf.python.ops.crf import CrfForwardRnnCell
from tensorflow.contrib.crf.python.ops.crf import viterbi_decode
18
tensorflow/contrib/crf/python/__init__.py
Normal file
@ -0,0 +1,18 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Linear-chain CRF."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
200
tensorflow/contrib/crf/python/kernel_tests/crf_test.py
Normal file
@ -0,0 +1,200 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for CRF."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

import numpy as np
import tensorflow as tf


class CrfTest(tf.test.TestCase):

  def testCrfSequenceScore(self):
    inputs = np.array(
        [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
    tag_indices = np.array([1, 2, 1, 0], dtype=np.int32)
    transition_params = np.array(
        [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
    sequence_lengths = np.array(3, dtype=np.int32)
    with self.test_session() as sess:
      sequence_score = tf.contrib.crf.crf_sequence_score(
          inputs=tf.expand_dims(inputs, 0),
          tag_indices=tf.expand_dims(tag_indices, 0),
          sequence_lengths=tf.expand_dims(sequence_lengths, 0),
          transition_params=tf.constant(transition_params))
      sequence_score = tf.squeeze(sequence_score, [0])
      tf_sequence_score = sess.run(sequence_score)
      expected_unary_score = sum(inputs[i][tag_indices[i]]
                                 for i in range(sequence_lengths))
      expected_binary_score = sum(
          transition_params[tag_indices[i], tag_indices[i + 1]]
          for i in range(sequence_lengths - 1))
      expected_sequence_score = expected_unary_score + expected_binary_score
      self.assertAllClose(tf_sequence_score, expected_sequence_score)

  def testCrfUnaryScore(self):
    inputs = np.array(
        [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
    tag_indices = np.array([1, 2, 1, 0], dtype=np.int32)
    sequence_lengths = np.array(3, dtype=np.int32)
    with self.test_session() as sess:
      unary_score = tf.contrib.crf.crf_unary_score(
          tag_indices=tf.expand_dims(tag_indices, 0),
          sequence_lengths=tf.expand_dims(sequence_lengths, 0),
          inputs=tf.expand_dims(inputs, 0))
      unary_score = tf.squeeze(unary_score, [0])
      tf_unary_score = sess.run(unary_score)
      expected_unary_score = sum(inputs[i][tag_indices[i]]
                                 for i in range(sequence_lengths))
      self.assertAllClose(tf_unary_score, expected_unary_score)

  def testCrfBinaryScore(self):
    tag_indices = np.array([1, 2, 1, 0], dtype=np.int32)
    transition_params = np.array(
        [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
    sequence_lengths = np.array(3, dtype=np.int32)
    with self.test_session() as sess:
      binary_score = tf.contrib.crf.crf_binary_score(
          tag_indices=tf.expand_dims(tag_indices, 0),
          sequence_lengths=tf.expand_dims(sequence_lengths, 0),
          transition_params=tf.constant(transition_params))
      binary_score = tf.squeeze(binary_score, [0])
      tf_binary_score = sess.run(binary_score)
      expected_binary_score = sum(
          transition_params[tag_indices[i], tag_indices[i + 1]]
          for i in range(sequence_lengths - 1))
      self.assertAllClose(tf_binary_score, expected_binary_score)

  def testCrfLogNorm(self):
    inputs = np.array(
        [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
    transition_params = np.array(
        [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
    num_words = inputs.shape[0]
    num_tags = inputs.shape[1]
    sequence_lengths = np.array(3, dtype=np.int32)
    with self.test_session() as sess:
      all_sequence_scores = []

      # Compare the dynamic program with brute force computation.
      for tag_indices in itertools.product(
          range(num_tags), repeat=sequence_lengths):
        tag_indices = list(tag_indices)
        tag_indices.extend([0] * (num_words - sequence_lengths))
        all_sequence_scores.append(
            tf.contrib.crf.crf_sequence_score(
                inputs=tf.expand_dims(inputs, 0),
                tag_indices=tf.expand_dims(tag_indices, 0),
                sequence_lengths=tf.expand_dims(sequence_lengths, 0),
                transition_params=tf.constant(transition_params)))

      brute_force_log_norm = tf.reduce_logsumexp(all_sequence_scores)
      log_norm = tf.contrib.crf.crf_log_norm(
          inputs=tf.expand_dims(inputs, 0),
          sequence_lengths=tf.expand_dims(sequence_lengths, 0),
          transition_params=tf.constant(transition_params))
      log_norm = tf.squeeze(log_norm, [0])
      tf_brute_force_log_norm, tf_log_norm = sess.run(
          [brute_force_log_norm, log_norm])

      self.assertAllClose(tf_log_norm, tf_brute_force_log_norm)

  def testCrfLogLikelihood(self):
    inputs = np.array(
        [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
    transition_params = np.array(
        [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
    sequence_lengths = np.array(3, dtype=np.int32)
    num_words = inputs.shape[0]
    num_tags = inputs.shape[1]
    with self.test_session() as sess:
      all_sequence_log_likelihoods = []

      # Make sure all probabilities sum to 1.
      for tag_indices in itertools.product(
          range(num_tags), repeat=sequence_lengths):
        tag_indices = list(tag_indices)
        tag_indices.extend([0] * (num_words - sequence_lengths))
        sequence_log_likelihood, _ = tf.contrib.crf.crf_log_likelihood(
            inputs=tf.expand_dims(inputs, 0),
            tag_indices=tf.expand_dims(tag_indices, 0),
            sequence_lengths=tf.expand_dims(sequence_lengths, 0),
            transition_params=tf.constant(transition_params))
        all_sequence_log_likelihoods.append(sequence_log_likelihood)
      total_log_likelihood = tf.reduce_logsumexp(all_sequence_log_likelihoods)
      tf_total_log_likelihood = sess.run(total_log_likelihood)
      self.assertAllClose(tf_total_log_likelihood, 0.0)

  def testLengthsToMasks(self):
    with self.test_session() as sess:
      sequence_lengths = [4, 1, 8, 2]
      max_sequence_length = max(sequence_lengths)

      mask = tf.contrib.crf._lengths_to_masks(sequence_lengths,
                                              max_sequence_length)
      tf_mask = sess.run(mask)
      self.assertEqual(len(tf_mask), len(sequence_lengths))
      for m, l in zip(tf_mask, sequence_lengths):
        self.assertAllEqual(m[:l], [1] * l)
        self.assertAllEqual(m[l:], [0] * (len(m) - l))

  def testViterbiDecode(self):
    inputs = np.array(
        [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
    transition_params = np.array(
        [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
    sequence_lengths = np.array(3, dtype=np.int32)
    num_words = inputs.shape[0]
    num_tags = inputs.shape[1]

    with self.test_session() as sess:
      all_sequence_scores = []
      all_sequences = []

      # Compare the dynamic program with brute force computation.
      for tag_indices in itertools.product(
          range(num_tags), repeat=sequence_lengths):
        tag_indices = list(tag_indices)
        tag_indices.extend([0] * (num_words - sequence_lengths))
        all_sequences.append(tag_indices)
        sequence_score = tf.contrib.crf.crf_sequence_score(
            inputs=tf.expand_dims(inputs, 0),
            tag_indices=tf.expand_dims(tag_indices, 0),
            sequence_lengths=tf.expand_dims(sequence_lengths, 0),
            transition_params=tf.constant(transition_params))
        sequence_score = tf.squeeze(sequence_score, [0])
        all_sequence_scores.append(sequence_score)

      tf_all_sequence_scores = sess.run(all_sequence_scores)

      expected_max_sequence_index = np.argmax(tf_all_sequence_scores)
      expected_max_sequence = all_sequences[expected_max_sequence_index]
      expected_max_score = tf_all_sequence_scores[expected_max_sequence_index]

      actual_max_sequence, actual_max_score = tf.contrib.crf.viterbi_decode(
          inputs[:sequence_lengths], transition_params)

      self.assertAllClose(actual_max_score, expected_max_score)
      self.assertEqual(actual_max_sequence,
                       expected_max_sequence[:sequence_lengths])


if __name__ == "__main__":
  tf.test.main()
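The brute-force comparison in testCrfLogNorm above is easy to replay outside of TensorFlow. A short NumPy sketch of the same check (toy potentials copied from the test; the helper names are made up), showing that the forward recursion — the dynamic program crf_log_norm below runs as an RNN — agrees with exhaustive enumeration on the log partition function:

```python
import itertools
import numpy as np

inputs = np.array([[4., 5., -3.], [3., -1., 3.], [-1., 2., 1.]])  # [seq_len, num_tags]
trans = np.array([[-3., 5., -2.], [3., 4., 1.], [1., 2., 1.]])    # [num_tags, num_tags]

def logsumexp(x, axis=None):
  c = np.max(x, axis=axis, keepdims=True)
  return np.squeeze(c, axis) + np.log(np.sum(np.exp(x - c), axis=axis))

# Forward algorithm: alphas[j] holds the log-mass of all prefixes ending in tag j.
alphas = inputs[0]
for unary in inputs[1:]:
  alphas = unary + logsumexp(alphas[:, np.newaxis] + trans, axis=0)
log_norm = logsumexp(alphas)

# Brute force: logsumexp over the score of every complete tag sequence.
def seq_score(tags):
  return (sum(inputs[t, tag] for t, tag in enumerate(tags)) +
          sum(trans[a, b] for a, b in zip(tags[:-1], tags[1:])))

brute = logsumexp(np.array([seq_score(t)
                            for t in itertools.product(range(3), repeat=3)]))
assert np.isclose(log_norm, brute)
```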
18
tensorflow/contrib/crf/python/ops/__init__.py
Normal file
@ -0,0 +1,18 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for building a linear-chain CRF layer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
311
tensorflow/contrib/crf/python/ops/crf.py
Normal file
@ -0,0 +1,311 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module for constructing a linear-chain CRF.

The following snippet is an example of a CRF layer on top of a batched sequence
of unary scores (logits for every word). This example also decodes the most
likely sequence at test time:

log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
    unary_scores, gold_tags, sequence_lengths)
loss = tf.reduce_mean(-log_likelihood)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

tf_unary_scores, tf_sequence_lengths, tf_transition_params, _ = session.run(
    [unary_scores, sequence_lengths, transition_params, train_op])
for tf_unary_scores_, tf_sequence_length_ in zip(tf_unary_scores,
                                                 tf_sequence_lengths):
  # Remove padding.
  tf_unary_scores_ = tf_unary_scores_[:tf_sequence_length_]

  # Compute the highest score and its tag sequence.
  viterbi_sequence, viterbi_score = tf.contrib.crf.viterbi_decode(
      tf_unary_scores_, tf_transition_params)
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope as vs

__all__ = ["crf_sequence_score", "crf_log_norm", "crf_log_likelihood",
           "crf_unary_score", "crf_binary_score", "CrfForwardRnnCell",
           "viterbi_decode"]


def _lengths_to_masks(lengths, max_length):
  """Creates a binary matrix that can be used to mask away padding.

  Args:
    lengths: A vector of integers representing lengths.
    max_length: An integer indicating the maximum length. All values in
      lengths should be at most max_length.
  Returns:
    masks: Masks that can be used to get rid of padding.
  """
  tiled_ranges = array_ops.tile(
      array_ops.expand_dims(math_ops.range(max_length), 0),
      [array_ops.shape(lengths)[0], 1])
  lengths = array_ops.expand_dims(lengths, 1)
  masks = math_ops.to_float(
      math_ops.to_int64(tiled_ranges) < math_ops.to_int64(lengths))
  return masks


def crf_sequence_score(inputs, tag_indices, sequence_lengths,
                       transition_params):
  """Computes the unnormalized score for a tag sequence.

  Args:
    inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        to use as input to the CRF layer.
    tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we
        compute the unnormalized score.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    transition_params: A [num_tags, num_tags] transition matrix.
  Returns:
    sequence_scores: A [batch_size] vector of unnormalized sequence scores.
  """
  # Compute the scores of the given tag sequence.
  unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs)
  binary_scores = crf_binary_score(tag_indices, sequence_lengths,
                                   transition_params)
  sequence_scores = unary_scores + binary_scores
  return sequence_scores


def crf_log_norm(inputs, sequence_lengths, transition_params):
  """Computes the normalization for a CRF.

  Args:
    inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        to use as input to the CRF layer.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    transition_params: A [num_tags, num_tags] transition matrix.
  Returns:
    log_norm: A [batch_size] vector of normalizers for a CRF.
  """
  # Split up the first and rest of the inputs in preparation for the forward
  # algorithm.
  first_input = array_ops.slice(inputs, [0, 0, 0], [-1, 1, -1])
  first_input = array_ops.squeeze(first_input, [1])
  rest_of_input = array_ops.slice(inputs, [0, 1, 0], [-1, -1, -1])

  # Compute the alpha values in the forward algorithm in order to get the
  # partition function.
  forward_cell = CrfForwardRnnCell(transition_params)
  _, alphas = rnn.dynamic_rnn(
      cell=forward_cell,
      inputs=rest_of_input,
      sequence_length=sequence_lengths - 1,
      initial_state=first_input,
      dtype=dtypes.float32)
  log_norm = math_ops.reduce_logsumexp(alphas, [1])
  return log_norm


def crf_log_likelihood(inputs,
                       tag_indices,
                       sequence_lengths,
                       transition_params=None):
  """Computes the log-likelihood of tag sequences in a CRF.

  Args:
    inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        to use as input to the CRF layer.
    tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we
        compute the log-likelihood.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    transition_params: A [num_tags, num_tags] transition matrix, if available.
  Returns:
    log_likelihood: A scalar containing the log-likelihood of the given sequence
        of tag indices.
    transition_params: A [num_tags, num_tags] transition matrix. This is either
        provided by the caller or created in this function.
  """
  # Get shape information.
  num_tags = inputs.get_shape()[2].value

  # Get the transition matrix if not provided.
  if transition_params is None:
    transition_params = vs.get_variable("transitions", [num_tags, num_tags])

  sequence_scores = crf_sequence_score(inputs, tag_indices, sequence_lengths,
                                       transition_params)
  log_norm = crf_log_norm(inputs, sequence_lengths, transition_params)

  # Normalize the scores to get the log-likelihood.
  log_likelihood = sequence_scores - log_norm
  return log_likelihood, transition_params


def crf_unary_score(tag_indices, sequence_lengths, inputs):
  """Computes the unary scores of tag sequences.

  Args:
    tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials.
  Returns:
    unary_scores: A [batch_size] vector of unary scores.
  """
  batch_size = array_ops.shape(inputs)[0]
  max_seq_len = array_ops.shape(inputs)[1]
  num_tags = array_ops.shape(inputs)[2]

  flattened_inputs = array_ops.reshape(inputs, [-1])

  offsets = array_ops.expand_dims(
      math_ops.range(batch_size) * max_seq_len * num_tags, 1)
  offsets += array_ops.expand_dims(math_ops.range(max_seq_len) * num_tags, 0)
  flattened_tag_indices = array_ops.reshape(offsets + tag_indices, [-1])

  unary_scores = array_ops.reshape(
      array_ops.gather(flattened_inputs, flattened_tag_indices),
      [batch_size, max_seq_len])

  masks = _lengths_to_masks(sequence_lengths, array_ops.shape(tag_indices)[1])

  unary_scores = math_ops.reduce_sum(unary_scores * masks, 1)
  return unary_scores


def crf_binary_score(tag_indices, sequence_lengths, transition_params):
  """Computes the binary scores of tag sequences.

  Args:
    tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    transition_params: A [num_tags, num_tags] matrix of binary potentials.
  Returns:
    binary_scores: A [batch_size] vector of binary scores.
  """
  # Get shape information.
  num_tags = transition_params.get_shape()[0]
  num_transitions = array_ops.shape(tag_indices)[1] - 1

  # Truncate by one on each side of the sequence to get the start and end
  # indices of each transition.
  start_tag_indices = array_ops.slice(tag_indices, [0, 0],
                                      [-1, num_transitions])
  end_tag_indices = array_ops.slice(tag_indices, [0, 1], [-1, num_transitions])

  # Encode the indices in a flattened representation.
  flattened_transition_indices = start_tag_indices * num_tags + end_tag_indices
  flattened_transition_params = array_ops.reshape(transition_params, [-1])

  # Get the binary scores based on the flattened representation.
  binary_scores = array_ops.gather(flattened_transition_params,
                                   flattened_transition_indices)

  masks = _lengths_to_masks(sequence_lengths, array_ops.shape(tag_indices)[1])
  truncated_masks = array_ops.slice(masks, [0, 1], [-1, -1])
  binary_scores = math_ops.reduce_sum(binary_scores * truncated_masks, 1)
  return binary_scores


class CrfForwardRnnCell(rnn_cell.RNNCell):
  """Computes the alpha values in a linear-chain CRF.

  See http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
  """

  def __init__(self, transition_params):
    """Initialize the CrfForwardRnnCell.

    Args:
      transition_params: A [num_tags, num_tags] matrix of binary potentials.
          This matrix is expanded into a [1, num_tags, num_tags] tensor in
          preparation for the broadcast summation occurring within the cell.
    """
    self._transition_params = array_ops.expand_dims(transition_params, 0)
    self._num_tags = transition_params.get_shape()[0].value

  @property
  def state_size(self):
    return self._num_tags

  @property
  def output_size(self):
    return self._num_tags

  def __call__(self, inputs, state, scope=None):
    """Build the CrfForwardRnnCell.

    Args:
      inputs: A [batch_size, num_tags] matrix of unary potentials.
      state: A [batch_size, num_tags] matrix containing the previous alpha
          values.
      scope: Unused variable scope of this cell.

    Returns:
      new_alphas, new_alphas: A pair of [batch_size, num_tags] matrices
          containing the new alpha values.
    """
    state = array_ops.expand_dims(state, 2)

    # This addition op broadcasts self._transition_params along the zeroth
    # dimension and state along the second dimension. This performs the
    # multiplication of previous alpha values and the current binary potentials
    # in log space.
    transition_scores = state + self._transition_params
    new_alphas = inputs + math_ops.reduce_logsumexp(transition_scores, [1])

    # Both the state and the output of this RNN cell contain the alpha values.
    # The output value is currently unused and simply satisfies the RNN API.
    # This could be useful in the future if we need to compute marginal
    # probabilities, which would require the accumulated alpha values at every
    # time step.
    return new_alphas, new_alphas


def viterbi_decode(score, transition_params):
  """Decode the highest scoring sequence of tags outside of TensorFlow.

  This should only be used at test time.

  Args:
    score: A [seq_len, num_tags] matrix of unary potentials.
    transition_params: A [num_tags, num_tags] matrix of binary potentials.

  Returns:
    viterbi: A [seq_len] list of integers containing the highest scoring tag
        indices.
    viterbi_score: A float containing the score for the viterbi sequence.
  """
  trellis = np.zeros_like(score)
  backpointers = np.zeros_like(score, dtype=np.int32)
  trellis[0] = score[0]

  for t in range(1, score.shape[0]):
    v = np.expand_dims(trellis[t - 1], 1) + transition_params
    trellis[t] = score[t] + np.max(v, 0)
    backpointers[t] = np.argmax(v, 0)

  viterbi = [np.argmax(trellis[-1])]
  for bp in reversed(backpointers[1:]):
    viterbi.append(bp[viterbi[-1]])
  viterbi.reverse()

  viterbi_score = np.max(trellis[-1])
  return viterbi, viterbi_score
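viterbi_decode above runs entirely in NumPy, so it can be sanity-checked without a session. A small sketch that pits it against exhaustive search over all 3^3 tag sequences, reusing the toy potentials from crf_test.py (the `seq_score` helper is made up for this check):

```python
import itertools
import numpy as np
import tensorflow as tf

score = np.array([[4., 5., -3.], [3., -1., 3.], [-1., 2., 1.]])  # [seq_len, num_tags]
trans = np.array([[-3., 5., -2.], [3., 4., 1.], [1., 2., 1.]])   # [num_tags, num_tags]

def seq_score(tags):
  return (sum(score[t, tag] for t, tag in enumerate(tags)) +
          sum(trans[a, b] for a, b in zip(tags[:-1], tags[1:])))

# Exhaustive search for the highest scoring tag sequence.
best = max(itertools.product(range(3), repeat=3), key=seq_score)

viterbi, viterbi_score = tf.contrib.crf.viterbi_decode(score, trans)
assert tuple(viterbi) == best                       # [1, 0, 1] for these inputs
assert np.isclose(viterbi_score, seq_score(best))   # score 18.0
```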
@ -58,13 +58,6 @@ class KLTest(tf.test.TestCase):
      self.assertAllEqual([float("nan")], kl_ok.eval())

  def testRegistrationFailures(self):
    with self.assertRaisesRegexp(TypeError, "is not a subclass of"):
      tf.contrib.distributions.RegisterKL(
          tf.contrib.distributions.Normal, object)(lambda x: x)
    with self.assertRaisesRegexp(TypeError, "is not a subclass of"):
      tf.contrib.distributions.RegisterKL(
          object, tf.contrib.distributions.Normal)(lambda x: x)

    class MyDist(tf.contrib.distributions.Normal):
      pass

@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function

import contextlib
import functools

import numpy as np
import tensorflow as tf
@ -69,9 +68,9 @@ def make_univariate_mixture(batch_shape, num_components):
  logits = tf.random_uniform(
      list(batch_shape) + [num_components], -1, 1, dtype=tf.float32) - 50.
  components = [
      (distributions_py.Normal,
       {"mu": np.float32(np.random.randn(*list(batch_shape))),
        "sigma": np.float32(10 * np.random.rand(*list(batch_shape)))})
      distributions_py.Normal(
          mu=np.float32(np.random.randn(*list(batch_shape))),
          sigma=np.float32(10 * np.random.rand(*list(batch_shape))))
      for _ in range(num_components)
  ]
  cat = distributions_py.Categorical(logits, dtype=tf.int32)
@ -82,10 +81,10 @@ def make_multivariate_mixture(batch_shape, num_components, event_shape):
  logits = tf.random_uniform(
      list(batch_shape) + [num_components], -1, 1, dtype=tf.float32) - 50.
  components = [
      (distributions_py.MultivariateNormalDiag,
       {"mu": np.float32(np.random.randn(*list(batch_shape + event_shape))),
        "diag_stdev": np.float32(10 * np.random.rand(
            *list(batch_shape + event_shape)))})
      distributions_py.MultivariateNormalDiag(
          mu=np.float32(np.random.randn(*list(batch_shape + event_shape))),
          diag_stdev=np.float32(10 * np.random.rand(
              *list(batch_shape + event_shape))))
      for _ in range(num_components)
  ]
  cat = distributions_py.Categorical(logits, dtype=tf.int32)
@ -116,7 +115,7 @@ class MixtureTest(tf.test.TestCase):
                                             r"cat.num_classes != len"):
      distributions_py.Mixture(
          distributions_py.Categorical([0.1, 0.5]),  # 2 classes
          [(distributions_py.Normal, {"mu": 1.0, "sigma": 2.0})])
          [distributions_py.Normal(mu=1.0, sigma=2.0)])
    with self.assertRaisesWithPredicateMatch(
        ValueError, r"\(\) and \(2,\) are not compatible"):
      # The value error is raised because the batch shapes of the
@ -124,13 +123,13 @@ class MixtureTest(tf.test.TestCase):
      # vector of size (2,).
      distributions_py.Mixture(
          distributions_py.Categorical([-0.5, 0.5]),  # scalar batch
          [(distributions_py.Normal, {"mu": 1.0, "sigma": 2.0}),  # scalar dist
           (distributions_py.Normal, {"mu": [1.0, 1.0], "sigma": [2.0, 2.0]})])
          [distributions_py.Normal(mu=1.0, sigma=2.0),  # scalar dist
           distributions_py.Normal(mu=[1.0, 1.0], sigma=[2.0, 2.0])])
    with self.assertRaisesWithPredicateMatch(ValueError, r"Could not infer"):
      cat_logits = tf.placeholder(shape=[1, None], dtype=tf.int32)
      distributions_py.Mixture(
          distributions_py.Categorical(cat_logits),
          [(distributions_py.Normal, {"mu": [1.0], "sigma": [2.0]})])
          [distributions_py.Normal(mu=[1.0], sigma=[2.0])])

  def testBrokenShapesDynamic(self):
    with self.test_session():
@ -138,8 +137,8 @@ class MixtureTest(tf.test.TestCase):
      d1_param = tf.placeholder(dtype=tf.float32)
      d = distributions_py.Mixture(
          distributions_py.Categorical([0.1, 0.2]),
          [(distributions_py.Normal, {"mu": d0_param, "sigma": d0_param}),
           (distributions_py.Normal, {"mu": d1_param, "sigma": d1_param})],
          [distributions_py.Normal(mu=d0_param, sigma=d0_param),
           distributions_py.Normal(mu=d1_param, sigma=d1_param)],
          validate_args=True)
      with self.assertRaisesOpError(r"batch shape must match"):
        d.sample().eval(feed_dict={d0_param: [2.0, 3.0], d1_param: [1.0]})
@ -150,42 +149,24 @@ class MixtureTest(tf.test.TestCase):
    with self.assertRaisesWithPredicateMatch(TypeError, "Categorical"):
      distributions_py.Mixture(None, [])
    cat = distributions_py.Categorical([0.3, 0.2])
    # components must be a list of tuples
    with self.assertRaisesWithPredicateMatch(TypeError, "tuples of the form"):
    # components must be a list of distributions
    with self.assertRaisesWithPredicateMatch(
        TypeError, "all .* must be Distribution instances"):
      distributions_py.Mixture(cat, [None])
    # components tuples must be size 2
    with self.assertRaisesWithPredicateMatch(TypeError, "tuples of the form"):
      distributions_py.Mixture(cat, [tuple()])
    # components tuples must be size 2
    with self.assertRaisesWithPredicateMatch(TypeError, "tuples of the form"):
      distributions_py.Mixture(cat, [(None)])
    # components tuples must be of the form (callable, dict)
    with self.assertRaisesWithPredicateMatch(TypeError, "tuples of the form"):
      distributions_py.Mixture(cat, [(None, None)])
    # components tuples must be size 2
    with self.assertRaisesWithPredicateMatch(TypeError, "tuples of the form"):
      distributions_py.Mixture(cat, [(None, None, None)])
    # components tuples must be of the form (callable, dict)
    with self.assertRaisesWithPredicateMatch(TypeError, "tuples of the form"):
      distributions_py.Mixture(cat, [(lambda x: x, None)])
    # components tuples must be of the form (callable, dict)
    with self.assertRaisesWithPredicateMatch(TypeError, "tuples of the form"):
      distributions_py.Mixture(cat, [(None, {})])
    with self.assertRaisesWithPredicateMatch(TypeError, "same dtype"):
      distributions_py.Mixture(
          cat,
          [(distributions_py.Normal, {"mu": [1.0], "sigma": [2.0]}),
           (distributions_py.Normal, {"mu": [np.float16(1.0)],
                                      "sigma": [np.float16(2.0)]})])
          [distributions_py.Normal(mu=[1.0], sigma=[2.0]),
           distributions_py.Normal(mu=[np.float16(1.0)],
                                   sigma=[np.float16(2.0)])])
    with self.assertRaisesWithPredicateMatch(ValueError, "non-empty list"):
      distributions_py.Mixture(distributions_py.Categorical([0.3, 0.2]), None)
    with self.assertRaisesWithPredicateMatch(TypeError,
                                             "either be continuous or not"):
      distributions_py.Mixture(
          cat,
          [(distributions_py.Normal, {"mu": [1.0], "sigma": [2.0]}),
           (functools.partial(distributions_py.Bernoulli, dtype=tf.float32),
            {"logits": [1.0]})])
          [distributions_py.Normal(mu=[1.0], sigma=[2.0]),
           distributions_py.Bernoulli(dtype=tf.float32, logits=[1.0])])

  def testMeanUnivariate(self):
    with self.test_session() as sess:
@ -196,7 +177,7 @@ class MixtureTest(tf.test.TestCase):
      self.assertEqual(batch_shape, mean.get_shape())

      cat_probs = tf.nn.softmax(dist.cat.logits)
      dist_means = [d.mean() for d in dist.distributions]
      dist_means = [d.mean() for d in dist.components]

      mean_value, cat_probs_value, dist_means_value = sess.run(
          [mean, cat_probs, dist_means])
@ -217,7 +198,7 @@ class MixtureTest(tf.test.TestCase):
      self.assertEqual(batch_shape + (4,), mean.get_shape())

      cat_probs = tf.nn.softmax(dist.cat.logits)
      dist_means = [d.mean() for d in dist.distributions]
      dist_means = [d.mean() for d in dist.components]

      mean_value, cat_probs_value, dist_means_value = sess.run(
          [mean, cat_probs, dist_means])
@ -243,7 +224,7 @@ class MixtureTest(tf.test.TestCase):

      self.assertEqual(x.shape, p_x.get_shape())
      cat_probs = tf.nn.softmax([dist.cat.logits])[0]
      dist_probs = [d.prob(x) for d in dist.distributions]
      dist_probs = [d.prob(x) for d in dist.components]

      p_x_value, cat_probs_value, dist_probs_value = sess.run(
          [p_x, cat_probs, dist_probs])
@ -269,7 +250,7 @@ class MixtureTest(tf.test.TestCase):
      self.assertEqual(x.shape[:-1], p_x.get_shape())

      cat_probs = tf.nn.softmax([dist.cat.logits])[0]
      dist_probs = [d.prob(x) for d in dist.distributions]
      dist_probs = [d.prob(x) for d in dist.components]

      p_x_value, cat_probs_value, dist_probs_value = sess.run(
          [p_x, cat_probs, dist_probs])
@ -292,7 +273,7 @@ class MixtureTest(tf.test.TestCase):
      self.assertEqual(x.shape, p_x.get_shape())

      cat_probs = tf.nn.softmax(dist.cat.logits)
      dist_probs = [d.prob(x) for d in dist.distributions]
      dist_probs = [d.prob(x) for d in dist.components]

      p_x_value, cat_probs_value, dist_probs_value = sess.run(
          [p_x, cat_probs, dist_probs])
@ -318,7 +299,7 @@ class MixtureTest(tf.test.TestCase):
      self.assertEqual(x.shape[:-1], p_x.get_shape())

      cat_probs = tf.nn.softmax(dist.cat.logits)
      dist_probs = [d.prob(x) for d in dist.distributions]
      dist_probs = [d.prob(x) for d in dist.components]

      p_x_value, cat_probs_value, dist_probs_value = sess.run(
          [p_x, cat_probs, dist_probs])
@ -430,7 +411,7 @@ class MixtureTest(tf.test.TestCase):
      self.assertEqual(batch_shape, entropy_lower_bound.get_shape())

      cat_probs = tf.nn.softmax(dist.cat.logits)
      dist_entropy = [d.entropy() for d in dist.distributions]
      dist_entropy = [d.entropy() for d in dist.components]

      entropy_lower_bound_value, cat_probs_value, dist_entropy_value = (
          sess.run([entropy_lower_bound, cat_probs, dist_entropy]))
@ -486,8 +467,7 @@ class MixtureBenchmark(tf.test.Benchmark):
        tf.Variable(np.random.rand(batch_size, num_features))
        for _ in range(num_components)]
    components = list(
        (distributions_py.MultivariateNormalDiag,
         {"mu": mu, "diag_stdev": sigma})
        distributions_py.MultivariateNormalDiag(mu=mu, diag_stdev=sigma)
        for (mu, sigma) in zip(mus, sigmas))
    return distributions_py.Mixture(cat, components)

@ -524,8 +504,7 @@ class MixtureBenchmark(tf.test.Benchmark):
        psd(np.random.rand(batch_size, num_features, num_features)))
        for _ in range(num_components)]
    components = list(
        (distributions_py.MultivariateNormalFull,
         {"mu": mu, "sigma": sigma})
        distributions_py.MultivariateNormalFull(mu=mu, sigma=sigma)
        for (mu, sigma) in zip(mus, sigmas))
    return distributions_py.Mixture(cat, components)

@ -33,7 +33,7 @@ class QuantizedDistributionTest(tf.test.TestCase):
self.assertTrue(np.isfinite(array).all())

def test_quantization_of_uniform_with_cutoffs_having_no_effect(self):
with self.test_session():
with self.test_session() as sess:
# The Quantized uniform with cutoffs == None divides the real line into:
# R = ...(-1, 0](0, 1](1, 2](2, 3](3, 4]...
# j = ... 0 1 2 3 4 ...
@ -60,34 +60,38 @@
b=3.0)

# pmf
pmf_n1, pmf_0, pmf_1, pmf_2, pmf_3, pmf_4, pmf_5 = sess.run(
    qdist.pmf([-1., 0., 1., 2., 3., 4., 5.]))
# uniform had no mass below -1.
self.assertAllClose(0., qdist.pmf(-1.).eval())
self.assertAllClose(0., pmf_n1)
# uniform had no mass below 0.
self.assertAllClose(0., qdist.pmf(0.).eval())
self.assertAllClose(0., pmf_0)
# uniform put 1/3 of its mass in each of (0, 1], (1, 2], (2, 3],
# which are the intervals j = 1, 2, 3.
self.assertAllClose(1 / 3, qdist.pmf(1.).eval())
self.assertAllClose(1 / 3, qdist.pmf(2.).eval())
self.assertAllClose(1 / 3, qdist.pmf(3.).eval())
self.assertAllClose(1 / 3, pmf_1)
self.assertAllClose(1 / 3, pmf_2)
self.assertAllClose(1 / 3, pmf_3)
# uniform had no mass in (3, 4] or (4, 5], which are j = 4, 5.
self.assertAllClose(0 / 3, qdist.pmf(4.).eval())
self.assertAllClose(0 / 3, qdist.pmf(5.).eval())
self.assertAllClose(0 / 3, pmf_4)
self.assertAllClose(0 / 3, pmf_5)

# cdf
self.assertAllClose(0., qdist.cdf(-1.).eval())
self.assertAllClose(0., qdist.cdf(0.).eval())
self.assertAllClose(1 / 3, qdist.cdf(1.).eval())
self.assertAllClose(2 / 3, qdist.cdf(2.).eval())
cdf_n1, cdf_0, cdf_1, cdf_2, cdf_2p5, cdf_3, cdf_4, cdf_5 = sess.run(
    qdist.cdf([-1., 0., 1., 2., 2.5, 3., 4., 5.]))
self.assertAllClose(0., cdf_n1)
self.assertAllClose(0., cdf_0)
self.assertAllClose(1 / 3, cdf_1)
self.assertAllClose(2 / 3, cdf_2)
# Note fractional values allowed for cdfs of discrete distributions.
# And adding 0.5 makes no difference because the quantized dist has
# mass only on the integers, never in between.
self.assertAllClose(2 / 3, qdist.cdf(2.5).eval())
self.assertAllClose(3 / 3, qdist.cdf(3.).eval())
self.assertAllClose(3 / 3, qdist.cdf(4.).eval())
self.assertAllClose(3 / 3, qdist.cdf(5.).eval())
self.assertAllClose(2 / 3, cdf_2p5)
self.assertAllClose(3 / 3, cdf_3)
self.assertAllClose(3 / 3, cdf_4)
self.assertAllClose(3 / 3, cdf_5)

def test_quantization_of_uniform_with_cutoffs_in_the_middle(self):
with self.test_session():
with self.test_session() as sess:
# The uniform is supported on [-3, 3]
# Consider partitions the real line in intervals
# ...(-3, -2](-2, -1](-1, 0](0, 1](1, 2](2, 3] ...
@ -103,25 +107,27 @@
b=3.0)

# pmf
cdf_n3, cdf_n2, cdf_n1, cdf_0, cdf_0p5, cdf_1, cdf_10 = sess.run(
    qdist.cdf([-3., -2., -1., 0., 0.5, 1.0, 10.0]))
# Uniform had no mass on (-4, -3] or (-3, -2]
self.assertAllClose(0., qdist.cdf(-3.).eval())
self.assertAllClose(0., qdist.cdf(-2.).eval())
self.assertAllClose(0., cdf_n3)
self.assertAllClose(0., cdf_n2)
# Uniform had 1/6 of its mass in each of (-3, -2], and (-2, -1], which
# were collapsed into (-infty, -1], which is now the "-1" interval.
self.assertAllClose(1 / 3, qdist.cdf(-1.).eval())
self.assertAllClose(1 / 3, cdf_n1)
# The j=0 interval contained mass from (-3, 0], which is 1/2 of the
# uniform's mass.
self.assertAllClose(1 / 2, qdist.cdf(0.).eval())
self.assertAllClose(1 / 2, cdf_0)
# Adding 0.5 makes no difference because the quantized dist has mass on
# the integers, not in between them.
self.assertAllClose(1 / 2, qdist.cdf(0.5).eval())
self.assertAllClose(1 / 2, cdf_0p5)
# After applying the cutoff, all mass was either in the interval
# (0, infty), or below. (0, infty) is the interval indexed by j=1,
# so pmf(1) should equal 1.
self.assertAllClose(1., qdist.cdf(1.0).eval())
self.assertAllClose(1., cdf_1)
# Since no mass of qdist is above 1,
# pmf(10) = P[Y <= 10] = P[Y <= 1] = pmf(1).
self.assertAllClose(1., qdist.cdf(10.0).eval())
self.assertAllClose(1., cdf_10)

def test_quantization_of_batch_of_uniforms(self):
batch_shape = (5, 5)
@ -231,10 +237,12 @@
# The smallest value the samples can take on is 1, which corresponds to
# the interval (0, 1]. Recall we use ceiling in the sampling definition.
self.assertLess(0.5, samps.min())
for x in range(1, 10):
x_vals = np.arange(1, 11).astype(np.float32)
pmf_vals = qdist.pmf(x_vals).eval()
for ii in range(10):
self.assertAllClose(
    qdist.pmf(float(x)).eval(),
    (samps == x).mean(),
    pmf_vals[ii],
    (samps == x_vals[ii]).mean(),
    atol=std_err_bound)

def test_normal_cdf_and_survival_function(self):

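Editor's note: the point of the hunks above is the fetch pattern, not the numbers — one `sess.run` over a vector replaces a long series of per-scalar `.eval()` calls. A minimal sketch of the idiom against the contrib API of this era; the `base_dist_cls`/`a`/`b` constructor form is an assumption reconstructed from the `b=3.0)` context line in the diff:

```python
import tensorflow as tf

distributions = tf.contrib.distributions

# Quantized uniform on (0, 3]; constructor form assumed, see note above.
qdist = distributions.QuantizedDistribution(
    base_dist_cls=distributions.Uniform,
    a=0.0,
    b=3.0)

with tf.Session() as sess:
  # One run() call fetches every value at once...
  pmf_vals = sess.run(qdist.pmf([-1., 0., 1., 2., 3., 4., 5.]))
  # ...where the old pattern paid one graph evaluation per scalar:
  #   qdist.pmf(-1.).eval(), qdist.pmf(0.).eval(), ...
  print(pmf_vals)  # approx [0, 0, 1/3, 1/3, 1/3, 0, 0]
```
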
@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.distributions.python.ops import distribution
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@ -32,8 +31,8 @@ def kl(dist_a, dist_b, allow_nan=False, name=None):
"""Get the KL-divergence KL(dist_a || dist_b).

Args:
dist_a: instance of distributions.Distribution.
dist_b: instance of distributions.Distribution.
dist_a: The first distribution.
dist_b: The second distribution.
allow_nan: If `False` (default), a runtime error is raised
if the KL returns NaN values for any batch entry of the given
distributions. If `True`, the KL may return a NaN for the given entry.
@ -43,18 +42,9 @@ def kl(dist_a, dist_b, allow_nan=False, name=None):
A Tensor with the batchwise KL-divergence between dist_a and dist_b.

Raises:
TypeError: If dist_a or dist_b is not an instance of Distribution.
NotImplementedError: If no KL method is defined for distribution types
of dist_a and dist_b.
"""
if not isinstance(dist_a, distribution.Distribution):
raise TypeError(
    "dist_a is not an instance of Distribution, received type: %s"
    % type(dist_a))
if not isinstance(dist_b, distribution.Distribution):
raise TypeError(
    "dist_b is not an instance of Distribution, received type: %s"
    % type(dist_b))
kl_fn = _DIVERGENCES.get((type(dist_a), type(dist_b)), None)
if kl_fn is None:
raise NotImplementedError(
@ -94,16 +84,7 @@ class RegisterKL(object):
Args:
dist_cls_a: the class of the first argument of the KL divergence.
dist_cls_b: the class of the second argument of the KL divergence.

Raises:
TypeError: if dist_cls_a or dist_cls_b are not subclasses of
Distribution.
"""

if not issubclass(dist_cls_a, distribution.Distribution):
raise TypeError("%s is not a subclass of Distribution" % dist_cls_a)
if not issubclass(dist_cls_b, distribution.Distribution):
raise TypeError("%s is not a subclass of Distribution" % dist_cls_b)
self._key = (dist_cls_a, dist_cls_b)

def __call__(self, kl_fn):

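Editor's note: the registry these hunks are loosening is a plain dict keyed on a `(type_a, type_b)` pair, populated by the `RegisterKL` decorator. A framework-free sketch of that mechanism — names mirror the diff, but the `Normal` class here is a stand-in for illustration, not TensorFlow's:

```python
import math

# Registry pattern as shown in the hunks above.
_DIVERGENCES = {}


class RegisterKL(object):
  """Decorator that registers a KL function for a pair of classes."""

  def __init__(self, dist_cls_a, dist_cls_b):
    self._key = (dist_cls_a, dist_cls_b)

  def __call__(self, kl_fn):
    _DIVERGENCES[self._key] = kl_fn
    return kl_fn


class Normal(object):  # stand-in distribution class for the sketch
  def __init__(self, mu, sigma):
    self.mu, self.sigma = mu, sigma


@RegisterKL(Normal, Normal)
def _kl_normal_normal(a, b):
  # Closed-form KL between two scalar Gaussians.
  return (math.log(b.sigma / a.sigma)
          + (a.sigma ** 2 + (a.mu - b.mu) ** 2) / (2 * b.sigma ** 2) - 0.5)


def kl(dist_a, dist_b):
  # Lookup mirrors the diff: keyed on the concrete types of both arguments.
  kl_fn = _DIVERGENCES.get((type(dist_a), type(dist_b)), None)
  if kl_fn is None:
    raise NotImplementedError("No KL registered for this type pair")
  return kl_fn(dist_a, dist_b)


print(kl(Normal(0.0, 1.0), Normal(1.0, 2.0)))  # ~0.4431
```
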
@ -56,43 +56,15 @@ class Mixture(distribution.Distribution):
all having matching dtype, batch shape, event shape, and continuity
properties (the components).

The user does not pass the list of distributions directly, but rather a
list of `(constructor, batch_tensor_params_dict)` pairs,
called `components`. The list of distributions is created via:

```python
distributions = [
  c(**params_dict) for (c, params_dict) in zip(*components)
]
```

This form allows for certain types of batch-shape optimizations within
this class.

An example of `components`:

```python
components = [
  (tf.contrib.distributions.Normal, {"mu": 3.0, "sigma": 1.0}),
  (functools.partial(tf.contrib.distributions.Normal, validate_args=False),
   {"mu": 3.0, "sigma": 2.0}),
  (tf.contrib.distributions.Normal.from_params,
   {"mu": 1.0, "sigma": -1.0})
]
```

The `num_classes` of `cat` must be possible to infer at graph construction
time and match `len(distributions)`.
time and match `len(components)`.

Args:
cat: A `Categorical` distribution instance, representing the probabilities
  of `distributions`.
components: A list or tuple of `(constructor, batch_tensor_params)`
  tuples. The `constructor` must be a callable, and `batch_tensor_params`
  must be a dict mapping constructor kwargs to batchwise parameters.
  Each `Distribution` instance created by calling
  `constructor(**batch_tensor_params)` must have the same type, be defined
  on the same domain, and have matching `event_shape` and `batch_shape`.
components: A list or tuple of `Distribution` instances.
  Each instance must have the same type, be defined on the same domain,
  and have matching `event_shape` and `batch_shape`.
validate_args: `Boolean`, default `False`. If `True`, raise a runtime
  error if batch or event ranks are inconsistent between cat and any of
  the distributions. This is only checked if the ranks cannot be
@ -106,16 +78,13 @@ class Mixture(distribution.Distribution):
Raises:
TypeError: If cat is not a `Categorical`, or `components` is not
  a list or tuple, or the elements of `components` are not
  tuples of the form `(callable, dict)`, or the objects resulting
  from calling `callable(**dict)` are not instances of `Distribution`, or
  the resulting instances of `Distribution` do not have matching
  continuity properties, or do not have matching `dtype`.
ValueError: If `components` is an empty list or tuple, or the
  distributions created from `components` do have a statically known event
  rank. If `cat.num_classes` cannot be inferred at graph creation time,
  instances of `Distribution`, or do not have matching `dtype`.
ValueError: If `components` is an empty list or tuple, or its
  elements do not have a statically known event rank.
  If `cat.num_classes` cannot be inferred at graph creation time,
  or the constant value of `cat.num_classes` is not equal to
  `len(distributions)`, or all `distributions` and `cat` do not have
  matching static batch shapes, or all components' distributions do not
  `len(components)`, or all `components` and `cat` do not have
  matching static batch shapes, or all components do not
  have matching static event shapes.
"""
if not isinstance(cat, categorical.Categorical):
@ -126,52 +95,29 @@ class Mixture(distribution.Distribution):
if not isinstance(components, (list, tuple)):
raise TypeError("components must be a list or tuple, but saw: %s" %
                components)
if not all(isinstance(c, tuple) and len(c) == 2 and
           callable(c[0]) and isinstance(c[1], dict)
           for c in components):
if not all(isinstance(c, distribution.Distribution) for c in components):
raise TypeError(
    "all entries in components must be tuples of the form "
    "(make, params), where make is callable and params is a dict,"
    "all entries in components must be Distribution instances"
    " but saw: %s" % components)

def _make_tensors(d):
return dict((k, ops.convert_to_tensor(v, name="tensor_%s" % k))
            for (k, v) in d.items())

with ops.name_scope(name, values=[cat.logits]):
components_tensor_params = list((make, _make_tensors(batch_params))
                                for (make, batch_params) in components)
distributions = [make(**batch_params)
                 for (make, batch_params) in components_tensor_params]

# Store components internally with their batch params having been
# converted to tensors.
# TODO(ebrevdo): Use self._components to optimize sampling.
self._components = components_tensor_params

if not all(isinstance(d, distribution.Distribution) for d in distributions):
dtype = components[0].dtype
if not all(d.dtype == dtype for d in components):
raise TypeError("All components must have the same dtype, but saw "
                "dtypes: %s" % [(d.name, d.dtype) for d in components])
is_continuous = components[0].is_continuous
if not all(d.is_continuous == is_continuous for d in components):
raise TypeError(
    "all entries in distributions must be instances of Distribution, "
    "but saw: %s" % distributions)

dtype = distributions[0].dtype
if not all(d.dtype == dtype for d in distributions):
raise TypeError("All distributions must have the same dtype, but saw "
                "dtypes: %s" % [(d.name, d.dtype) for d in distributions])
is_continuous = distributions[0].is_continuous
if not all(d.is_continuous == is_continuous for d in distributions):
raise TypeError(
    "All distributions must either be continuous or not, but continuity "
    "values are: %s" % [(d.name, d.is_continuous) for d in distributions])
static_event_shape = distributions[0].get_event_shape()
    "All components must either be continuous or not, but continuity "
    "values are: %s" % [(d.name, d.is_continuous) for d in components])
static_event_shape = components[0].get_event_shape()
static_batch_shape = cat.get_batch_shape()
for d in distributions:
for d in components:
static_event_shape = static_event_shape.merge_with(d.get_event_shape())
static_batch_shape = static_batch_shape.merge_with(d.get_batch_shape())
if static_event_shape.ndims is None:
raise ValueError(
    "Expected to know rank(event_shape) from distributions, but "
    "none of the distributions provide a static number of ndims")
    "Expected to know rank(event_shape) from components, but "
    "none of the components provide a static number of ndims")

# Ensure that all batch and event ndims are consistent.
with ops.name_scope(name, values=[cat.logits]):
@ -180,42 +126,42 @@ class Mixture(distribution.Distribution):
if static_num_components is None:
raise ValueError(
    "Could not infer number of classes from cat and unable "
    "to compare this value to the number of distributions passed in.")
    "to compare this value to the number of components passed in.")
# Possibly convert from numpy 0-D array.
static_num_components = int(static_num_components)
if static_num_components != len(distributions):
raise ValueError("cat.num_classes != len(distributions): %d vs. %d" %
                 (static_num_components, len(distributions)))
if static_num_components != len(components):
raise ValueError("cat.num_classes != len(components): %d vs. %d" %
                 (static_num_components, len(components)))

cat_batch_shape = cat.batch_shape()
cat_batch_rank = array_ops.size(cat_batch_shape)
if validate_args:
batch_shapes = [d.batch_shape() for d in distributions]
batch_shapes = [d.batch_shape() for d in components]
batch_ranks = [array_ops.size(bs) for bs in batch_shapes]
check_message = ("distributions[%d] batch shape must match cat "
check_message = ("components[%d] batch shape must match cat "
                 "batch shape")
self._assertions = [
    check_ops.assert_equal(
        cat_batch_rank, batch_ranks[di], message=check_message % di)
    for di in range(len(distributions))
    for di in range(len(components))
]
self._assertions += [
    check_ops.assert_equal(
        cat_batch_shape, batch_shapes[di], message=check_message % di)
    for di in range(len(distributions))
    for di in range(len(components))
]
else:
self._assertions = []

self._cat = cat
self._distributions = list(distributions)
self._components = list(components)
self._num_components = static_num_components
self._static_event_shape = static_event_shape
self._static_batch_shape = static_batch_shape

super(Mixture, self).__init__(
    dtype=dtype,
    parameters={"cat": self._cat, "distributions": self._distributions,
    parameters={"cat": self._cat, "components": self._components,
                "num_components": self._num_components},
    is_reparameterized=False,
    is_continuous=is_continuous,
@ -228,8 +174,8 @@ class Mixture(distribution.Distribution):
return self._cat

@property
def distributions(self):
return self._distributions
def components(self):
return self._components

@property
def num_components(self):
@ -242,14 +188,14 @@ class Mixture(distribution.Distribution):
return self._static_batch_shape

def _event_shape(self):
return self._distributions[0].event_shape()
return self._components[0].event_shape()

def _get_event_shape(self):
return self._static_event_shape

def _mean(self):
with ops.control_dependencies(self._assertions):
distribution_means = [d.mean() for d in self.distributions]
distribution_means = [d.mean() for d in self.components]
cat_probs = self._cat_probs(log_probs=False)
# This was checked to not be None at construction time.
static_event_rank = self.get_event_shape().ndims
@ -271,7 +217,7 @@ class Mixture(distribution.Distribution):
def _log_prob(self, x):
with ops.control_dependencies(self._assertions):
x = ops.convert_to_tensor(x, name="x")
distribution_log_probs = [d.log_prob(x) for d in self.distributions]
distribution_log_probs = [d.log_prob(x) for d in self.components]
cat_log_probs = self._cat_probs(log_probs=True)
final_log_probs = [
    cat_lp + d_lp
@ -351,7 +297,7 @@ class Mixture(distribution.Distribution):
samples_class = [None for _ in range(self.num_components)]
for c in range(self.num_components):
n_class = array_ops.size(partitioned_samples_indices[c])
samples_class_c = self.distributions[c].sample_n(n_class, seed=seed)
samples_class_c = self.components[c].sample_n(n_class, seed=seed)

# Pull out the correct batch entries from each index.
# To do this, we may have to flatten the batch shape.
@ -395,7 +341,7 @@ class Mixture(distribution.Distribution):
r"""A lower bound on the entropy of this mixture model.

The bound below is not always very tight, and its usefulness depends
on the mixture probabilities and the distributions in use.
on the mixture probabilities and the components in use.

A lower bound is useful for ELBO when the `Mixture` is the variational
distribution:
@ -432,7 +378,7 @@ class Mixture(distribution.Distribution):
"""
with self._name_scope(name, values=[self.cat.logits]):
with ops.control_dependencies(self._assertions):
distribution_entropies = [d.entropy() for d in self.distributions]
distribution_entropies = [d.entropy() for d in self.components]
cat_probs = self._cat_probs(log_probs=False)
partial_entropies = [
    c_p * m for (c_p, m) in zip(cat_probs, distribution_entropies)

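Editor's note: the net effect of the rewrite above is that `components` is now a plain list of already-constructed `Distribution` instances instead of `(constructor, batch_tensor_params)` pairs. A minimal usage sketch against the contrib API of this era, with `mu`/`sigma`/`logits` parameter names taken from the removed docstring example:

```python
import tensorflow as tf

distributions = tf.contrib.distributions

# Mixing weights for two components (unnormalized logits).
cat = distributions.Categorical(logits=[0.3, 0.7])

# With this change the caller builds the components up front...
components = [
    distributions.Normal(mu=3.0, sigma=1.0),
    distributions.Normal(mu=-3.0, sigma=2.0),
]
mixture = distributions.Mixture(cat, components)

# ...and the old (constructor, batch_tensor_params) form, e.g.
#   (tf.contrib.distributions.Normal, {"mu": 3.0, "sigma": 1.0}),
# is no longer accepted.
```
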
@ -35,7 +35,7 @@ namespace {
// The complete set of audio file formats that are supported by the op. These
// strings are defined by FFmpeg and documented here:
// https://www.ffmpeg.org/ffmpeg-formats.html
const char* kValidFileFormats[] = {"mp3", "ogg", "wav"};
const char* kValidFileFormats[] = {"mp3", "mp4", "ogg", "wav"};

// Writes binary data to a file.
Status WriteFile(const string& filename, tensorflow::StringPiece contents) {

@ -61,10 +61,30 @@ class DecodeAudioOpTest(tf.test.TestCase):
self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000, 1)
self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000, 2)

def testMonoMp4Mp3Codec(self):
# mp3 compressed audio streams in mp4 container.
self._loadFileAndTest('mono_16khz_mp3.mp4', 'mp4', 2.77, 20000, 1)
self._loadFileAndTest('mono_16khz_mp3.mp4', 'mp4', 2.77, 20000, 2)

def testMonoMp4AacCodec(self):
# aac compressed audio streams in mp4 container.
self._loadFileAndTest('mono_32khz_aac.mp4', 'mp4', 2.77, 20000, 1)
self._loadFileAndTest('mono_32khz_aac.mp4', 'mp4', 2.77, 20000, 2)

def testStereoMp3(self):
self._loadFileAndTest('stereo_48khz.mp3', 'mp3', 0.79, 50000, 1)
self._loadFileAndTest('stereo_48khz.mp3', 'mp3', 0.79, 20000, 2)

def testStereoMp4Mp3Codec(self):
# mp3 compressed audio streams in mp4 container.
self._loadFileAndTest('stereo_48khz_mp3.mp4', 'mp4', 0.79, 50000, 1)
self._loadFileAndTest('stereo_48khz_mp3.mp4', 'mp4', 0.79, 20000, 2)

def testStereoMp4AacCodec(self):
# aac compressed audio streams in mp4 container.
self._loadFileAndTest('stereo_48khz_aac.mp4', 'mp4', 0.79, 50000, 1)
self._loadFileAndTest('stereo_48khz_aac.mp4', 'mp4', 0.79, 20000, 2)

def testMonoWav(self):
self._loadFileAndTest('mono_10khz.wav', 'wav', 0.57, 5000, 1)
self._loadFileAndTest('mono_10khz.wav', 'wav', 0.57, 10000, 4)

@ -35,11 +35,14 @@ def decode_audio(contents, file_format=None, samples_per_second=None,
                 channel_count=None):
"""Create an op that decodes the contents of an audio file.

Note that ffmpeg is free to select the "best" audio track from an mp4.
https://trac.ffmpeg.org/wiki/Map

Args:
contents: The binary contents of the audio file to decode. This is a
    scalar.
file_format: A string specifying which format the contents will conform
    to. This can be mp3, ogg, or wav.
    to. This can be mp3, mp4, ogg, or wav.
samples_per_second: The number of samples per second that is assumed.
    In some cases, resampling will occur to generate the correct sample
    rate.

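Editor's note: with 'mp4' added to `kValidFileFormats` and to the docstring, decoding an AAC or MP3 stream out of an mp4 container looks like any other format. A sketch against the contrib module shown in the hunks; the file name below is one of the new test fixtures and otherwise illustrative:

```python
import tensorflow as tf
from tensorflow.contrib import ffmpeg

# Any mp4 with an audio stream should behave the same; ffmpeg picks the
# "best" audio track if the container holds several (see the Map wiki
# link in the docstring above).
binary = tf.read_file('mono_32khz_aac.mp4')
waveform = ffmpeg.decode_audio(
    binary, file_format='mp4', samples_per_second=32000, channel_count=1)

with tf.Session() as sess:
  print(sess.run(waveform).shape)  # (num_samples, 1)
```
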
BIN tensorflow/contrib/ffmpeg/testdata/mono_16khz_mp3.mp4 (new binary file, not shown)
BIN tensorflow/contrib/ffmpeg/testdata/mono_32khz_aac.mp4 (new binary file, not shown)
BIN tensorflow/contrib/ffmpeg/testdata/stereo_48khz_aac.mp4 (new binary file, not shown)
BIN tensorflow/contrib/ffmpeg/testdata/stereo_48khz_mp3.mp4 (new binary file, not shown)
@ -125,7 +125,8 @@ class ClassifierTest(tf.test.TestCase):
default_signature = signatures.default_signature
return default_signature

def testExportMonitorRegressionSignature(self):
# Disable this test case until b/31032996 is fixed.
def _testExportMonitorRegressionSignature(self):
iris = tf.contrib.learn.datasets.load_iris()
est = tf.contrib.learn.Classifier(model_fn=logistic_model_fn, n_classes=3)
export_dir = tempfile.mkdtemp() + 'export/'

@ -19,11 +19,269 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tempfile

from tensorflow.contrib import layers
from tensorflow.contrib import metrics as metrics_lib
from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework import deprecated_arg_values
from tensorflow.contrib.framework import list_variables
from tensorflow.contrib.framework import load_variable
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.layers.python.layers import optimizers
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import metric_spec
from tensorflow.contrib.learn.python.learn import session_run_hook
from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators import dnn_linear_combined
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.utils import checkpoints
from tensorflow.contrib.learn.python.learn.utils import export
from tensorflow.contrib.losses.python.losses import loss_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import training as train


class DNNClassifier(dnn_linear_combined.DNNLinearCombinedClassifier):
_CENTERED_BIAS = "centered_bias"
_CENTERED_BIAS_WEIGHT = "centered_bias_weight"
_CLASSES = "classes"
_LOGISTIC = "logistic"
_PROBABILITIES = "probabilities"

# The default learning rate of 0.05 is a historical artifact of the initial
# implementation, but seems a reasonable choice.
_LEARNING_RATE = 0.05


def _as_iterable(preds, output):
for pred in preds:
yield pred[output]


def _get_feature_dict(features):
if isinstance(features, dict):
return features
return {"": features}


def _get_optimizer(optimizer):
if callable(optimizer):
return optimizer()
else:
return optimizer


def _add_hidden_layer_summary(value, tag):
logging_ops.scalar_summary("%s:fraction_of_zero_values" % tag,
                           nn.zero_fraction(value))
logging_ops.histogram_summary("%s:activation" % tag, value)


def _centered_bias(num_label_columns):
centered_bias = variables.Variable(
    array_ops.zeros([num_label_columns]),
    collections=[_CENTERED_BIAS, ops.GraphKeys.VARIABLES],
    name=_CENTERED_BIAS_WEIGHT)
logging_ops.scalar_summary(
    ["centered_bias %d" % cb for cb in range(num_label_columns)],
    array_ops.reshape(centered_bias, [-1]))
return centered_bias


def _centered_bias_step(targets, loss_fn, num_label_columns):
centered_bias = ops.get_collection(_CENTERED_BIAS)
batch_size = array_ops.shape(targets)[0]
logits = array_ops.reshape(
    array_ops.tile(centered_bias[0], [batch_size]),
    [batch_size, num_label_columns])
loss = loss_fn(logits, targets)
return train.AdagradOptimizer(0.1).minimize(loss, var_list=centered_bias)


def _get_weight_tensor(features, weight_column_name):
"""Returns the weight tensor of shape [batch_size] or 1."""
if weight_column_name is None:
return 1.0
else:
return array_ops.reshape(
    math_ops.to_float(features[weight_column_name]),
    shape=(-1,))


def _rescale_eval_loss(loss, weights):
"""Rescales evaluation loss according to the given weights.

The rescaling is needed because in the training loss weights are not
considered in the denominator, whereas for the evaluation loss we should
divide by the sum of weights.

The rescaling factor is:
R = sum_{i} 1 / sum_{i} w_{i}

Args:
loss: the scalar weighted loss.
weights: weight coefficients. Either a scalar, or a `Tensor` of shape
    [batch_size].

Returns:
The given loss multiplied by the rescaling factor.
"""
rescaling_factor = math_ops.reduce_mean(weights)
return math_ops.div(loss, rescaling_factor)


def _predictions(logits, n_classes):
"""Returns predictions for the given logits and n_classes."""
predictions = {}
if n_classes == 2:
predictions[_LOGISTIC] = math_ops.sigmoid(logits)
logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
predictions[_PROBABILITIES] = nn.softmax(logits)
predictions[_CLASSES] = array_ops.reshape(
    math_ops.argmax(logits, 1), shape=(-1, 1))
return predictions


def _dnn_classifier_model_fn(features, targets, mode, params):
"""Deep Neural Net model_fn.

Args:
features: `Tensor` or dict of `Tensor` (depends on data passed to `fit`).
targets: `Tensor` of shape [batch_size, 1] or [batch_size] target labels of
    dtype `int32` or `int64` in the range `[0, n_classes)`.
mode: Defines whether this is training, evaluation or prediction.
    See `ModeKeys`.
params: A dict of hyperparameters.
    The following hyperparameters are expected:
    * hidden_units: List of hidden units per layer.
    * feature_columns: An iterable containing all the feature columns used by
        the model.
    * n_classes: number of target classes.
    * weight_column_name: A string defining the weight feature column, or
        None if there are no weights.
    * optimizer: string, `Optimizer` object, or callable that defines the
        optimizer to use for training.
    * activation_fn: Activation function applied to each layer. If `None`,
        will use `tf.nn.relu`.
    * dropout: When not `None`, the probability we will drop out a given
        coordinate.
    * gradient_clip_norm: A float > 0. If provided, gradients are
        clipped to their global norm with this clipping ratio.
    * enable_centered_bias: A bool. If True, estimator will learn a centered
        bias variable for each class. Rest of the model structure learns the
        residual after centered bias.
    * num_ps_replicas: The number of parameter server replicas.

Returns:
predictions: A dict of `Tensor` objects.
loss: A scalar containing the loss of the step.
train_op: The op for training.
"""
hidden_units = params["hidden_units"]
feature_columns = params["feature_columns"]
n_classes = params["n_classes"]
weight_column_name = params["weight_column_name"]
optimizer = params["optimizer"]
activation_fn = params["activation_fn"]
dropout = params["dropout"]
gradient_clip_norm = params["gradient_clip_norm"]
enable_centered_bias = params["enable_centered_bias"]
num_ps_replicas = params["num_ps_replicas"]

features = _get_feature_dict(features)
parent_scope = "dnn"
num_label_columns = 1 if n_classes == 2 else n_classes
if n_classes == 2:
loss_fn = loss_ops.sigmoid_cross_entropy
else:
loss_fn = loss_ops.sparse_softmax_cross_entropy

input_layer_partitioner = (
    partitioned_variables.min_max_variable_partitioner(
        max_partitions=num_ps_replicas,
        min_slice_size=64 << 20))
with variable_scope.variable_scope(
    parent_scope + "/input_from_feature_columns",
    values=features.values(),
    partitioner=input_layer_partitioner) as scope:
net = layers.input_from_feature_columns(
    columns_to_tensors=features,
    feature_columns=feature_columns,
    weight_collections=[parent_scope],
    scope=scope)

hidden_layer_partitioner = (
    partitioned_variables.min_max_variable_partitioner(
        max_partitions=num_ps_replicas))
for layer_id, num_hidden_units in enumerate(hidden_units):
with variable_scope.variable_scope(
    parent_scope + "/hiddenlayer_%d" % layer_id,
    values=[net],
    partitioner=hidden_layer_partitioner) as scope:
net = layers.fully_connected(
    net,
    num_hidden_units,
    activation_fn=activation_fn,
    variables_collections=[parent_scope],
    scope=scope)
if dropout is not None and mode == estimator.ModeKeys.TRAIN:
net = layers.dropout(
    net,
    keep_prob=(1.0 - dropout))
_add_hidden_layer_summary(net, scope.name)

with variable_scope.variable_scope(
    parent_scope + "/logits",
    values=[net],
    partitioner=hidden_layer_partitioner) as scope:
logits = layers.fully_connected(
    net,
    num_label_columns,
    activation_fn=None,
    variables_collections=[parent_scope],
    scope=scope)
_add_hidden_layer_summary(logits, scope.name)

if enable_centered_bias:
logits = nn.bias_add(logits, _centered_bias(num_label_columns))

if mode == estimator.ModeKeys.TRAIN:
loss = loss_fn(logits, targets,
               weight=_get_weight_tensor(features, weight_column_name))

train_ops = [optimizers.optimize_loss(
    loss=loss, global_step=contrib_variables.get_global_step(),
    learning_rate=_LEARNING_RATE, optimizer=_get_optimizer(optimizer),
    clip_gradients=gradient_clip_norm, name=parent_scope)]
if enable_centered_bias:
train_ops.append(_centered_bias_step(targets, loss_fn, num_label_columns))

return None, loss, control_flow_ops.group(*train_ops)

elif mode == estimator.ModeKeys.EVAL:
predictions = _predictions(logits=logits, n_classes=n_classes)

weight = _get_weight_tensor(features, weight_column_name)
training_loss = loss_fn(logits, targets, weight=weight)
loss = _rescale_eval_loss(training_loss, weight)

return predictions, loss, []

else: # mode == estimator.ModeKeys.INFER:
predictions = _predictions(logits=logits, n_classes=n_classes)

return predictions, None, []


class DNNClassifier(evaluable.Evaluable, trainable.Trainable):
"""A classifier for TensorFlow DNN models.

Example:
@ -124,36 +382,211 @@ class DNNClassifier(dnn_linear_combined.DNNLinearCombinedClassifier):

Returns:
A `DNNClassifier` estimator.

Raises:
ValueError: If `n_classes` < 2.
"""
if enable_centered_bias is None:
enable_centered_bias = True
dnn_linear_combined._changing_default_center_bias() # pylint: disable=protected-access
super(DNNClassifier, self).__init__(
    model_dir=model_dir,
    n_classes=n_classes,
    weight_column_name=weight_column_name,
    dnn_feature_columns=feature_columns,
    dnn_optimizer=optimizer,
    dnn_hidden_units=hidden_units,
    dnn_activation_fn=activation_fn,
    dnn_dropout=dropout,
    gradient_clip_norm=gradient_clip_norm,
    enable_centered_bias=enable_centered_bias,
    config=config)
self.feature_columns = feature_columns
self.optimizer = optimizer
self.activation_fn = activation_fn
self.dropout = dropout
self.hidden_units = hidden_units
self._feature_columns_inferred = False
self._hidden_units = hidden_units
self._feature_columns = feature_columns
self._model_dir = model_dir or tempfile.mkdtemp()
if n_classes <= 1:
raise ValueError(
    "Classification requires n_classes >= 2. Given: {}".format(n_classes))
self._n_classes = n_classes
self._weight_column_name = weight_column_name
optimizer = optimizer or "Adagrad"
num_ps_replicas = config.num_ps_replicas if config else 0

self._estimator = estimator.Estimator(
    model_fn=_dnn_classifier_model_fn,
    model_dir=self._model_dir,
    config=config,
    params={
        "hidden_units": hidden_units,
        "feature_columns": feature_columns,
        "n_classes": n_classes,
        "weight_column_name": weight_column_name,
        "optimizer": optimizer,
        "activation_fn": activation_fn,
        "dropout": dropout,
        "gradient_clip_norm": gradient_clip_norm,
        "enable_centered_bias": enable_centered_bias,
        "num_ps_replicas": num_ps_replicas,
    })

def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
        monitors=None, max_steps=None):
"""See trainable.Trainable."""
# TODO(roumposg): Remove when deprecated monitors are removed.
if monitors is not None:
deprecated_monitors = [
    m for m in monitors
    if not isinstance(m, session_run_hook.SessionRunHook)
]
for monitor in deprecated_monitors:
monitor.set_estimator(self)
monitor._lock_estimator() # pylint: disable=protected-access

result = self._estimator.fit(x=x, y=y, input_fn=input_fn, steps=steps,
                             batch_size=batch_size, monitors=monitors,
                             max_steps=max_steps)

if monitors is not None:
for monitor in deprecated_monitors:
monitor._unlock_estimator() # pylint: disable=protected-access

return result

def evaluate(self, x=None, y=None, input_fn=None, feed_fn=None,
             batch_size=None, steps=None, metrics=None, name=None):
"""See evaluable.Evaluable."""
if metrics is None:
metrics = {}
metrics.update({
    "accuracy": metric_spec.MetricSpec(
        metric_fn=metrics_lib.streaming_accuracy,
        prediction_key=_CLASSES,
        weight_key=self._weight_column_name)})
if self._n_classes == 2:
metrics.update({
    "auc": metric_spec.MetricSpec(
        metric_fn=metrics_lib.streaming_auc,
        prediction_key=_LOGISTIC,
        weight_key=self._weight_column_name)})
return self._estimator.evaluate(
    x=x, y=y, input_fn=input_fn, feed_fn=feed_fn, batch_size=batch_size,
    steps=steps, metrics=metrics, name=name)

@deprecated_arg_values(
    estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
    as_iterable=False)
def predict(self, x=None, input_fn=None, batch_size=None, as_iterable=False):
"""Returns predicted classes for given features.

Args:
x: features.
input_fn: Input function. If set, x must be None.
batch_size: Override default batch size.
as_iterable: If True, return an iterable which keeps yielding predictions
    for each example until inputs are exhausted. Note: The inputs must
    terminate if you want the iterable to terminate (e.g. be sure to pass
    num_epochs=1 if you are using something like read_batch_features).

Returns:
Numpy array of predicted classes (or an iterable of predicted classes if
as_iterable is True).
"""
preds = self._estimator.predict(x=x, input_fn=input_fn,
                                batch_size=batch_size, outputs=[_CLASSES],
                                as_iterable=as_iterable)
if as_iterable:
return _as_iterable(preds, output=_CLASSES)
return preds[_CLASSES].reshape(-1)

@deprecated_arg_values(
    estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
    as_iterable=False)
def predict_proba(
    self, x=None, input_fn=None, batch_size=None, as_iterable=False):
"""Returns prediction probabilities for given features.

Args:
x: features.
input_fn: Input function. If set, x and y must be None.
batch_size: Override default batch size.
as_iterable: If True, return an iterable which keeps yielding predictions
    for each example until inputs are exhausted. Note: The inputs must
    terminate if you want the iterable to terminate (e.g. be sure to pass
    num_epochs=1 if you are using something like read_batch_features).

Returns:
Numpy array of predicted probabilities (or an iterable of predicted
probabilities if as_iterable is True).
"""
preds = self._estimator.predict(x=x, input_fn=input_fn,
                                batch_size=batch_size,
                                outputs=[_PROBABILITIES],
                                as_iterable=as_iterable)
if as_iterable:
return _as_iterable(preds, output=_PROBABILITIES)
return preds[_PROBABILITIES]

def get_variable_names(self):
"""Returns list of all variable names in this model.

Returns:
List of names.
"""
return [name for name, _ in list_variables(self._model_dir)]

def get_variable_value(self, name):
"""Returns value of the variable given by name.

Args:
name: string, name of the tensor.

Returns:
`Tensor` object.
"""
return load_variable(self._model_dir, name)

def export(self,
           export_dir,
           input_fn=None,
           input_feature_key=None,
           use_deprecated_input_fn=True,
           signature_fn=None,
           default_batch_size=1,
           exports_to_keep=None):
"""See BasEstimator.export."""
def default_input_fn(unused_estimator, examples):
return layers.parse_feature_columns_from_examples(
    examples, self._feature_columns)
self._estimator.export(
    export_dir=export_dir,
    input_fn=input_fn or default_input_fn,
    input_feature_key=input_feature_key,
    use_deprecated_input_fn=use_deprecated_input_fn,
    signature_fn=(
        signature_fn or export.classification_signature_fn_with_prob),
    prediction_key=_PROBABILITIES,
    default_batch_size=default_batch_size,
    exports_to_keep=exports_to_keep)

@property
def model_dir(self):
return self._model_dir

@property
@deprecated("2016-10-13", "This method inspects the private state of the "
            "object, and should not be used")
def weights_(self):
return self.dnn_weights_
hiddenlayer_weights = [checkpoints.load_variable(
    self._model_dir, name=("dnn/hiddenlayer_%d/weights" % i))
    for i, _ in enumerate(self._hidden_units)]
logits_weights = [checkpoints.load_variable(
    self._model_dir, name="dnn/logits/weights")]
return hiddenlayer_weights + logits_weights

@property
@deprecated("2016-10-13", "This method inspects the private state of the "
            "object, and should not be used")
def bias_(self):
return self.dnn_bias_
hiddenlayer_bias = [checkpoints.load_variable(
    self._model_dir, name=("dnn/hiddenlayer_%d/biases" % i))
    for i, _ in enumerate(self._hidden_units)]
logits_bias = [checkpoints.load_variable(
    self._model_dir, name="dnn/logits/biases")]
centered_bias = [checkpoints.load_variable(
    self._model_dir, name=_CENTERED_BIAS_WEIGHT)]
return hiddenlayer_bias + logits_bias + centered_bias

@property
def config(self):
return self._estimator.config


class DNNRegressor(dnn_linear_combined.DNNLinearCombinedRegressor):

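Editor's note: since DNNClassifier now wraps its own `Estimator` around `_dnn_classifier_model_fn` instead of inheriting from `DNNLinearCombinedClassifier`, the public surface stays the familiar fit/evaluate/predict trio. A usage sketch; the feature-column setup is illustrative (it mirrors the iris-based tests below), and the `as_iterable` flag is the one documented above:

```python
import numpy as np
import tensorflow as tf

# Four real-valued features, as in the iris-based tests.
feature_columns = [tf.contrib.layers.real_valued_column('', dimension=4)]

classifier = tf.contrib.learn.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[10, 10],
    n_classes=3)

x = np.random.rand(30, 4).astype(np.float32)
y = np.random.randint(0, 3, size=30)

classifier.fit(x=x, y=y, steps=100)
print(classifier.evaluate(x=x, y=y, steps=1)['accuracy'])

# as_iterable=True streams one prediction per example instead of
# materializing the whole numpy array at once.
for pred in classifier.predict(x=x, as_iterable=True):
  print(pred)
```
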
@ -27,13 +27,8 @@ import tensorflow as tf

from tensorflow.contrib.learn.python.learn.estimators import _sklearn
from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils

# pylint: disable=g-import-not-at-top
try:
from sklearn.cross_validation import cross_val_score
HAS_SKLEARN = True
except ImportError:
HAS_SKLEARN = False
from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec
from tensorflow.python.ops import math_ops


def _prepare_iris_data_for_logistic_regression():
@ -350,6 +345,7 @@ class DNNClassifierTest(tf.test.TestCase):
# For the case of binary classification, the 2nd column of "predictions"
# denotes the model predictions.
predictions = tf.slice(predictions, [0, 1], [-1, 1])
targets = math_ops.cast(targets, predictions.dtype)
return tf.reduce_sum(tf.mul(predictions, targets))

classifier = tf.contrib.learn.DNNClassifier(
@ -362,9 +358,15 @@ class DNNClassifierTest(tf.test.TestCase):
input_fn=_input_fn_train,
steps=100,
metrics={
    'my_accuracy': tf.contrib.metrics.streaming_accuracy,
    ('my_precision', 'classes'): tf.contrib.metrics.streaming_precision,
    ('my_metric', 'probabilities'): _my_metric_op
    'my_accuracy': MetricSpec(
        metric_fn=tf.contrib.metrics.streaming_accuracy,
        prediction_key='classes'),
    'my_precision': MetricSpec(
        metric_fn=tf.contrib.metrics.streaming_precision,
        prediction_key='classes'),
    'my_metric': MetricSpec(
        metric_fn=_my_metric_op,
        prediction_key='probabilities')
})
self.assertTrue(
    set(['loss', 'my_accuracy', 'my_precision', 'my_metric'
@ -375,21 +377,14 @@ class DNNClassifierTest(tf.test.TestCase):

# Test the case where the 2nd element of the key is neither "classes" nor
# "probabilities".
with self.assertRaises(ValueError):
classifier.evaluate(
    input_fn=_input_fn_train,
    steps=100,
    metrics={('bad_name', 'bad_type'): tf.contrib.metrics.streaming_auc})

# Test the case where the tuple of the key doesn't have 2 elements.
with self.assertRaises(ValueError):
with self.assertRaisesRegexp(KeyError, 'bad_type'):
classifier.evaluate(
    input_fn=_input_fn_train,
    steps=100,
    metrics={
        ('bad_length_name', 'classes', 'bad_length'):
            tf.contrib.metrics.streaming_accuracy
    })
        'bad_name': MetricSpec(
            metric_fn=tf.contrib.metrics.streaming_auc,
            prediction_key='bad_type')})

def testTrainSaveLoad(self):
"""Tests that insures you can save and reload a trained model."""
@ -466,6 +461,31 @@ class DNNClassifierTest(tf.test.TestCase):
self.assertGreater(scores['accuracy'], 0.9)
self.assertLess(scores['loss'], 0.3)

def testExport(self):
"""Tests export model for servo."""

def input_fn():
return {
    'age': tf.constant([1]),
    'language': tf.SparseTensor(values=['english'],
                                indices=[[0, 0]],
                                shape=[1, 1])
}, tf.constant([[1]])

language = tf.contrib.layers.sparse_column_with_hash_bucket('language', 100)
feature_columns = [
    tf.contrib.layers.real_valued_column('age'),
    tf.contrib.layers.embedding_column(language, dimension=1)
]

classifier = tf.contrib.learn.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[3, 3])
classifier.fit(input_fn=input_fn, steps=100)

export_dir = tempfile.mkdtemp()
classifier.export(export_dir)

def testDisableCenteredBias(self):
"""Tests that we can disable centered bias."""
cont_features = [
@ -484,32 +504,6 @@ class DNNClassifierTest(tf.test.TestCase):
self.assertGreater(scores['accuracy'], 0.8)
self.assertLess(scores['loss'], 0.3)

def testSklearnCompatibility(self):
"""Tests compatibility with sklearn"""
if not HAS_SKLEARN:
return
iris = tf.contrib.learn.datasets.load_iris()

cont_features = [
    tf.contrib.layers.real_valued_column('', dimension=4)]
kwargs = {
    'n_classes': 3,
    'feature_columns': cont_features,
    'optimizer' : 'Adam',
    'hidden_units' : [3, 4]
}

classifier = tf.contrib.learn.DNNClassifier(**kwargs)

scores = cross_val_score(
    classifier,
    iris.data[1:5],
    iris.target[1:5],
    scoring='accuracy',
    fit_params={'steps': 100}
)
self.assertAllClose(scores, [1, 1, 1])


class DNNRegressorTest(tf.test.TestCase):

@ -234,7 +234,7 @@ def read_keyed_batch_features(file_pattern,
                              queue_capacity=10000,
                              reader_num_threads=1,
                              feature_queue_capacity=100,
                              num_enqueue_threads=2,
                              num_queue_runners=2,
                              parser_num_threads=None,
                              parse_fn=None,
                              name=None):
@ -266,8 +266,8 @@ def read_keyed_batch_features(file_pattern,
queue_capacity: Capacity for input queue.
reader_num_threads: The number of threads to read examples.
feature_queue_capacity: Capacity of the parsed features queue.
num_enqueue_threads: Number of threads to enqueue the parsed example queue.
    Using multiple threads to enqueue the parsed example queue helps maintain
num_queue_runners: Number of queue runners to start for the feature queue,
    Adding multiple queue runners for the parsed example queue helps maintain
    a full queue when the subsequent computations overall are cheaper than
    parsing.
parser_num_threads: (Deprecated) The number of threads to parse examples.
@ -300,14 +300,14 @@ def read_keyed_batch_features(file_pattern,
    feature_map,
    keys=keys,
    feature_queue_capacity=feature_queue_capacity,
    num_enqueue_threads=num_enqueue_threads,
    num_queue_runners=num_queue_runners,
    name=scope)


def queue_parsed_features(parsed_features,
                          keys=None,
                          feature_queue_capacity=100,
                          num_enqueue_threads=2,
                          num_queue_runners=2,
                          name=None):
"""Speeds up parsing by using queues to do it asynchronously.

@ -326,8 +326,8 @@ def queue_parsed_features(parsed_features,
parsed_features: A dict of string key to `Tensor` or `SparseTensor` objects.
keys: `Tensor` of string keys.
feature_queue_capacity: Capacity of the parsed features queue.
num_enqueue_threads: Number of threads to enqueue the parsed example queue.
    Using multiple thrads to enqueue the parsed example queue helps maintain
num_queue_runners: Number of queue runners to start for the feature queue,
    Adding multiple queue runners for the parsed example queue helps maintain
    a full queue when the subsequent computations overall are cheaper than
    parsing.
name: Name of resulting op.
@ -374,14 +374,14 @@ def queue_parsed_features(parsed_features,
    math_ops.cast(input_queue.size(), dtypes.float32)
    * (1. / feature_queue_capacity))

# Use multiple threads to enqueue so the queue is always full. Adding more
# than two threads may hog the cpu on the worker to fill up the queue.
enqueue_ops = [input_queue.enqueue(tensors_to_enqueue)
               for _ in range(num_enqueue_threads)]
queue_runner.add_queue_runner(queue_runner.QueueRunner(
    input_queue, enqueue_ops,
    queue_closed_exception_types=(errors.OutOfRangeError,
                                  errors.CancelledError)))
# Add multiple queue runners so that the queue is always full. Adding more
# than two queue-runners may hog the cpu on the worker to fill up the queue.
for _ in range(num_queue_runners):
queue_runner.add_queue_runner(
    queue_runner.QueueRunner(
        input_queue, [input_queue.enqueue(tensors_to_enqueue)],
        queue_closed_exception_types=(errors.OutOfRangeError,
                                      errors.CancelledError)))

dequeued_tensors = input_queue.dequeue()

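Editor's note: the rename from `num_enqueue_threads` back to `num_queue_runners` changes how callers tune feature-queue parallelism — one queue runner (each with a single enqueue op) per unit, rather than one shared runner with several enqueue ops. A call sketch; the file pattern and feature spec are illustrative, and the surrounding required arguments are assumptions from this vintage of the API:

```python
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.learn_io import graph_io

# Illustrative feature spec; the num_queue_runners kwarg is the point.
features = {"age": tf.FixedLenFeature([1], dtype=tf.int64)}

keys, examples = graph_io.read_keyed_batch_features(
    file_pattern="/tmp/data/*.tfrecord",
    batch_size=32,
    features=features,
    reader=tf.TFRecordReader,
    num_queue_runners=2)  # renamed from num_enqueue_threads in this commit
```
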
@ -83,8 +83,6 @@ def run(experiment_fn, output_dir, schedule=None):
# Get the schedule
config = experiment.estimator.config
schedule = schedule or _get_default_schedule(config)
if not schedule:
raise ValueError('Must specify a schedule')

# Execute the schedule
if not hasattr(experiment, schedule):
@ -107,19 +105,36 @@ def run(experiment_fn, output_dir, schedule=None):
return task()


def _is_distributed(config):
"""Returns true if this is a distributed job."""
if not config.cluster_spec:
return False

# This is considered a distributed job if there is more than one task
# in the cluster spec.
task_count = 0
for job in config.cluster_spec.jobs:
for _ in config.cluster_spec.job_tasks(job):
task_count += 1

return task_count > 1


def _get_default_schedule(config):
"""Returns the default schedule for the provided RunConfig."""
if not config or not config.job_name:
return None
if not config or not _is_distributed(config):
return 'local_run'

if not config.job_name or config.job_name == 'master':
# TODO(rhaertel): handle the case there are more
# than one masters or explicitly disallow.
if not config.job_name:
raise ValueError('Must specify a schedule')

if config.job_name == 'master':
# TODO(rhaertel): handle the case where there is more than one master
# or explicitly disallow such a case.
return 'local_run'
elif config.job_name == 'ps':
return 'run_std_server'
elif config.job_name == 'worker':
return 'train'

return ValueError('No default schedule for task type: %s' %
                  (config.job_name,))
raise ValueError('No default schedule for task type: %s' % (config.job_name,))

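Editor's note: the schedule logic now keys off whether the RunConfig's cluster spec actually names more than one task, rather than off the mere presence of a job name. In short: non-distributed (or no config) maps to 'local_run'; a distributed 'master' maps to 'local_run'; 'ps' to 'run_std_server'; 'worker' to 'train'; any other job name raises. A sketch of how a caller lands on a given branch — the cluster layout mirrors the test helper in the hunks below:

```python
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators import run_config

cluster = tf.train.ClusterSpec(
    {"ps": ["localhost:1234"],
     "worker": ["localhost:1236"],
     "master": ["localhost:1238"]})

# More than one task in the spec => _is_distributed(config) is True,
# so the job name picks the schedule: 'worker' maps to 'train'.
config = run_config.RunConfig(job_name="worker", cluster_spec=cluster)

# A single-task spec (or cluster_spec=None) now falls back to 'local_run'
# instead of raising "Must specify a schedule".
local_config = run_config.RunConfig(
    cluster_spec=tf.train.ClusterSpec({"foo": ["localhost:1234"]}))
```
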
@ -335,7 +335,7 @@ def get_rnn_model(rnn_size, cell_type, num_layers, input_op_fn, bidirectional,
    fw_cell, attn_length=attn_length, attn_size=attn_size,
    attn_vec_size=attn_vec_size, state_is_tuple=False)
bw_cell = contrib_rnn.AttentionCellWrapper(
    fw_cell, attn_length=attn_length, attn_size=attn_size,
    bw_cell, attn_length=attn_length, attn_size=attn_size,
    attn_vec_size=attn_vec_size, state_is_tuple=False)
rnn_fw_cell = nn.rnn_cell.MultiRNNCell([fw_cell] * num_layers,
                                       state_is_tuple=False)

@ -39,7 +39,7 @@ class TestExperiment(tf.contrib.learn.Experiment):
|
||||
return Estimator()
|
||||
|
||||
def local_run(self):
|
||||
return "train_and_evaluate"
|
||||
return "local_run"
|
||||
|
||||
def train(self):
|
||||
return "train"
|
||||
@ -62,6 +62,18 @@ def build_non_experiment(output_dir):
|
||||
# pylint: enable=unused-argument
|
||||
|
||||
|
||||
def build_distributed_cluster_spec():
|
||||
return tf.train.ClusterSpec(
|
||||
{"ps": ["localhost:1234", "localhost:1235"],
|
||||
"worker": ["localhost:1236", "localhost:1237"],
|
||||
"master": ["localhost:1238"],
|
||||
"foo_has_no_default_schedule": ["localhost:1239"]})
|
||||
|
||||
|
||||
def build_non_distributed_cluster_spec():
|
||||
return tf.train.ClusterSpec({"foo": ["localhost:1234"]})
|
||||
|
||||
|
||||
class MainTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -76,7 +88,9 @@ class MainTest(tf.test.TestCase):
|
||||
schedule="simple_task"))
|
||||
|
||||
def test_schedule_from_tf_config(self):
|
||||
os.environ["TF_CONFIG"] = json.dumps({"task": {"type": "worker"}})
|
||||
os.environ["TF_CONFIG"] = json.dumps(
|
||||
{"cluster": build_distributed_cluster_spec().as_dict(),
|
||||
"task": {"type": "worker"}})
|
||||
# RunConfig constructuor will set job_name from TF_CONFIG.
|
||||
config = run_config.RunConfig()
|
||||
self.assertEqual(
|
||||
@ -85,28 +99,35 @@ class MainTest(tf.test.TestCase):
                         output_dir="/tmp"))

  def test_schedule_from_manually_specified_job_name(self):
    config = run_config.RunConfig(job_name="worker")
    config = run_config.RunConfig(
        job_name="worker", cluster_spec=build_distributed_cluster_spec())
    self.assertEqual(
        "train",
        learn_runner.run(lambda output_dir: TestExperiment(config=config),
                         output_dir="/tmp"))

  def test_schedule_from_config_runs_train_and_evaluate_on_master(self):
    config = run_config.RunConfig(job_name="master", task=0, is_chief=True)
  def test_schedule_from_config_runs_local_run_on_master(self):
    config = run_config.RunConfig(
        job_name="master",
        cluster_spec=build_distributed_cluster_spec(),
        task=0,
        is_chief=True)
    self.assertEqual(
        "train_and_evaluate",
        "local_run",
        learn_runner.run(lambda output_dir: TestExperiment(config=config),
                         output_dir="/tmp"))

  def test_schedule_from_config_runs_serve_on_ps(self):
    config = run_config.RunConfig(job_name="ps")
    config = run_config.RunConfig(
        job_name="ps", cluster_spec=build_distributed_cluster_spec())
    self.assertEqual(
        "run_std_server",
        learn_runner.run(lambda output_dir: TestExperiment(config=config),
                         output_dir="/tmp"))

  def test_schedule_from_config_runs_train_on_worker(self):
    config = run_config.RunConfig(job_name="worker")
    config = run_config.RunConfig(
        job_name="worker", cluster_spec=build_distributed_cluster_spec())
    self.assertEqual(
        "train",
        learn_runner.run(lambda output_dir: TestExperiment(config=config),
@ -117,13 +138,27 @@ class MainTest(tf.test.TestCase):
                            learn_runner.run, build_experiment, "",
                            "simple_task")

  def test_fail_no_schedule_and_no_config(self):
    self.assertRaisesRegexp(ValueError, "Must specify a schedule",
                            learn_runner.run, build_experiment, "/tmp")
  def test_no_schedule_and_no_config_runs_local_run(self):
    self.assertEqual(
        "local_run",
        learn_runner.run(build_experiment,
                         output_dir="/tmp"))

  def test_no_schedule_and_non_distributed_runs_local_run(self):
    config = run_config.RunConfig(
        cluster_spec=build_non_distributed_cluster_spec())
    self.assertEqual(
        "local_run",
        learn_runner.run(lambda output_dir: TestExperiment(config=config),
                         output_dir="/tmp"))

  def test_fail_job_name_with_no_default_schedule(self):
    self.assertRaisesRegexp(ValueError, "Must specify a schedule",
                            learn_runner.run, build_experiment, "/tmp")
    config = run_config.RunConfig(
        job_name="foo_has_no_default_schedule",
        cluster_spec=build_distributed_cluster_spec())
    create_experiment_fn = lambda output_dir: TestExperiment(config=config)
    self.assertRaisesRegexp(ValueError, "No default schedule",
                            learn_runner.run, create_experiment_fn, "/tmp")

  def test_fail_non_callable(self):
    self.assertRaisesRegexp(TypeError, "Experiment builder .* is not callable",

@ -148,7 +183,8 @@ class MainTest(tf.test.TestCase):
        "default")

  def test_fail_schedule_from_config_with_no_job_name(self):
    config = run_config.RunConfig(job_name=None)
    config = run_config.RunConfig(
        job_name=None, cluster_spec=build_distributed_cluster_spec())
    self.assertRaisesRegexp(
        ValueError,
        "Must specify a schedule",
@ -77,13 +77,6 @@ Node* Ones(Graph* const g, const int n) {
  return test::graph::Constant(g, data);
}

Node* StringIota(Graph* const g, const int n) {
  Tensor data(DT_STRING, TensorShape({n}));
  test::FillFn<string>(
      &data, [](const int i) { return strings::StrCat(strings::Hex(i)); });
  return test::graph::Constant(g, data);
}

Node* SparseExampleIndices(Graph* const g, const int sparse_features_per_group,
                           const int num_examples) {
  const int x_size = num_examples * 4;
@ -60,9 +60,6 @@ HOST_OBJDIR := $(MAKEFILE_DIR)/gen/host_obj/
HOST_BINDIR := $(MAKEFILE_DIR)/gen/host_bin/
HOST_GENDIR := $(MAKEFILE_DIR)/gen/host_obj/

# Find the current Eigen version from the Bazel configuration
EIGEN_VERSION := $(shell grep eigen_version tensorflow/workspace.bzl | head -1 | sed -e 's/.*eigen_version.*=.*"\(.*\)"/\1/')

# Settings for the host compiler.
HOST_CXX := $(CC_PREFIX) gcc
HOST_CXXFLAGS := --std=c++11

@ -75,7 +72,7 @@ HOST_LDOPTS += -L/usr/local/lib
HOST_INCLUDES := \
-I. \
-I$(MAKEFILE_DIR)/downloads/ \
-I$(MAKEFILE_DIR)/downloads/eigen-eigen-$(EIGEN_VERSION) \
-I$(MAKEFILE_DIR)/downloads/eigen \
-I$(HOST_GENDIR)
ifeq ($(HAS_GEN_HOST_PROTOC),true)
HOST_INCLUDES += -I$(MAKEFILE_DIR)/gen/protobuf-host/include

@ -148,7 +145,7 @@ DEPFLAGS = -MT $@ -MMD -MP -MF $(DEPDIR)/$*.Td
INCLUDES := \
-I. \
-I$(MAKEFILE_DIR)/downloads/ \
-I$(MAKEFILE_DIR)/downloads/eigen-eigen-$(EIGEN_VERSION) \
-I$(MAKEFILE_DIR)/downloads/eigen \
-I$(PROTOGENDIR) \
-I$(PBTGENDIR)
ifeq ($(HAS_GEN_HOST_PROTOC),true)

@ -240,7 +237,7 @@ ifeq ($(TARGET),ANDROID)
-I$(NDK_ROOT)/sources/cxx-stl/gnu-libstdc++/4.9/libs/armeabi/include \
-I. \
-I$(MAKEFILE_DIR)/downloads/ \
-I$(MAKEFILE_DIR)/downloads/eigen-eigen-$(EIGEN_VERSION) \
-I$(MAKEFILE_DIR)/downloads/eigen \
-I$(MAKEFILE_DIR)/gen/protobuf/include \
-I$(PROTOGENDIR) \
-I$(PBTGENDIR)

@ -570,6 +567,12 @@ clean:
	rm -rf $(MAKEFILE_DIR)/gen
	rm -rf tensorflow/core/util/version_info.cc

# Gets rid of all generated files except protobuf libs generated
# before calling make. This allows users to avoid recompiling the proto libs every time.
clean_except_protobuf_libs:
	find $(MAKEFILE_DIR)/gen -mindepth 1 -maxdepth 1 ! -name "protobuf" ! -name "protobuf-host" -exec rm -r "{}" \;
	rm -rf tensorflow/core/util/version_info.cc

# Gets rid of target files only, leaving the host alone. Also leaves the lib
# directory untouched deliberately, so we can persist multiple architectures
# across builds for iOS.
@ -46,18 +46,19 @@ shift $((OPTIND - 1))
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd ${SCRIPT_DIR}/../../../

# Remove any old files first.
make -f tensorflow/contrib/makefile/Makefile clean

if [[ "${ONLY_MAKE_TENSORFLOW}" != "true" ]]; then
  # Remove any old files first.
  make -f tensorflow/contrib/makefile/Makefile clean
  rm -rf tensorflow/contrib/makefile/downloads
  # Pull down the required versions of the frameworks we need.
  tensorflow/contrib/makefile/download_dependencies.sh
fi

# Compile protobuf for the target Android device architectures.
  # Compile protobuf for the target Android device architectures.
  CC_PREFIX="${CC_PREFIX}" NDK_ROOT="${NDK_ROOT}" \
  tensorflow/contrib/makefile/compile_android_protobuf.sh -c
else
  # Only clean files generated by make
  make -f tensorflow/contrib/makefile/Makefile clean_except_protobuf_libs
fi

if [[ "${USE_HEXAGON}" == "true" ]]; then
  HEXAGON_PARENT_DIR=$(cd ../hexagon && pwd)
@ -1,4 +1,4 @@
#!/bin/bash -ex
#!/bin/bash
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -14,57 +14,52 @@
# limitations under the License.
# ==============================================================================

set -e

DOWNLOADS_DIR=tensorflow/contrib/makefile/downloads
BZL_FILE_PATH=tensorflow/workspace.bzl

mkdir -p ${DOWNLOADS_DIR}
EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/.*tar\.gz' "${BZL_FILE_PATH}")"
GEMMLOWP_URL="$(grep -o 'http.*github.com/google/gemmlowp/.*tar\.gz' "${BZL_FILE_PATH}")"
GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz"
PROTOBUF_URL="$(grep -o 'http.*github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}")"
RE2_URL="$(grep -o 'http.*github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}")"

# Grab the current Eigen version name from the Bazel build file
EIGEN_HASH=$(cat "${BZL_FILE_PATH}" | egrep "eigen_version.*=.*\".*\"" | awk '{ print $3 }')
# Trim trailing and preceding double quotes
EIGEN_HASH="${EIGEN_HASH%\"}"
EIGEN_HASH="${EIGEN_HASH#\"}"

if [[ -z "${EIGEN_HASH}" ]]; then
  echo >&2 "Eigen hash does not exist."
  exit 1
else
  echo "Eigen hash = ${EIGEN_HASH}"
fi

curl "https://bitbucket.org/eigen/eigen/get/${EIGEN_HASH}.tar.gz" \
  -o /tmp/eigen-${EIGEN_HASH}.tar.gz
tar xzf /tmp/eigen-${EIGEN_HASH}.tar.gz -C ${DOWNLOADS_DIR}

# Link to the downloaded Eigen library from a permanent directory name, since
# the downloaded name changes with every version.
cd ${DOWNLOADS_DIR}
rm -rf eigen-latest
ln -s eigen-eigen-${EIGEN_HASH} eigen-latest

# TODO(petewarden) - Some new code in Eigen triggers a clang bug with iOS arm64,
# so work around it by patching the source.
function replace_by_sed() {
# TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64,
# so work around it by patching the source.
replace_by_sed() {
  local regex="${1}"
  shift
  if echo "${OSTYPE}" | grep -q darwin; then
    sed -e "$1" -i '' "$2"
    sed -i '' -e "${regex}" "$@"
  else
    sed -e "$1" -i "$2"
    sed -i -e "${regex}" "$@"
  fi
}

download_and_extract() {
  local usage="Usage: download_and_extract URL DIR"
  local url="${1:?${usage}}"
  local dir="${2:?${usage}}"
  echo "downloading ${url}" >&2
  mkdir -p "${dir}"
  tar -C "${dir}" --strip-components=1 -xz < <(curl -Ls "${url}")
}

download_and_extract "${EIGEN_URL}" "${DOWNLOADS_DIR}/eigen"
download_and_extract "${GEMMLOWP_URL}" "${DOWNLOADS_DIR}/gemmlowp"
download_and_extract "${GOOGLETEST_URL}" "${DOWNLOADS_DIR}/googletest"
download_and_extract "${PROTOBUF_URL}" "${DOWNLOADS_DIR}/protobuf"
download_and_extract "${RE2_URL}" "${DOWNLOADS_DIR}/re2"

replace_by_sed 's#static uint32x4_t p4ui_CONJ_XOR = vld1q_u32( conj_XOR_DATA );#static uint32x4_t p4ui_CONJ_XOR; // = vld1q_u32( conj_XOR_DATA ); - Removed by script#' \
  eigen-latest/Eigen/src/Core/arch/NEON/Complex.h
  "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h"
replace_by_sed 's#static uint32x2_t p2ui_CONJ_XOR = vld1_u32( conj_XOR_DATA );#static uint32x2_t p2ui_CONJ_XOR;// = vld1_u32( conj_XOR_DATA ); - Removed by scripts#' \
  eigen-latest/Eigen/src/Core/arch/NEON/Complex.h
  "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h"
replace_by_sed 's#static uint64x2_t p2ul_CONJ_XOR = vld1q_u64( p2ul_conj_XOR_DATA );#static uint64x2_t p2ul_CONJ_XOR;// = vld1q_u64( p2ul_conj_XOR_DATA ); - Removed by script#' \
  eigen-latest/Eigen/src/Core/arch/NEON/Complex.h

git clone https://github.com/google/re2.git re2
git clone https://github.com/google/gemmlowp.git gemmlowp
git clone https://github.com/google/protobuf.git protobuf
git clone https://github.com/google/googletest.git googletest

  "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h"
# TODO(satok): Remove this once protobuf/autogen.sh is fixed.
replace_by_sed 's#https://googlemock.googlecode.com/files/gmock-1.7.0.zip#http://download.tensorflow.org/deps/gmock-1.7.0.zip#' \
  protobuf/autogen.sh
  "${DOWNLOADS_DIR}/protobuf/autogen.sh"

echo "download_dependencies.sh completed successfully."
echo "download_dependencies.sh completed successfully." >&2
@ -35,7 +35,6 @@ tensorflow/core/lib/io/record_writer.cc
tensorflow/core/lib/io/record_reader.cc
tensorflow/core/lib/io/random_inputstream.cc
tensorflow/core/lib/io/path.cc
tensorflow/core/lib/io/match.cc
tensorflow/core/lib/io/iterator.cc
tensorflow/core/lib/io/inputstream_interface.cc
tensorflow/core/lib/io/inputbuffer.cc

@ -8,6 +8,7 @@ tensorflow/core/protobuf/queue_runner.pb.h
tensorflow/core/protobuf/named_tensor.pb.h
tensorflow/core/protobuf/meta_graph.pb.h
tensorflow/core/protobuf/config.pb.h
tensorflow/core/protobuf/tensor_bundle.pb.h
tensorflow/core/lib/core/error_codes.pb.h
tensorflow/core/framework/versions.pb.h
tensorflow/core/framework/variable.pb.h

@ -2,6 +2,7 @@ tensorflow/core/util/saved_tensor_slice.pb_text.cc
tensorflow/core/util/memmapped_file_system.pb_text.cc
tensorflow/core/protobuf/saver.pb_text.cc
tensorflow/core/protobuf/config.pb_text.cc
tensorflow/core/protobuf/tensor_bundle.pb_text.cc
tensorflow/core/lib/core/error_codes.pb_text.cc
tensorflow/core/framework/versions.pb_text.cc
tensorflow/core/framework/types.pb_text.cc

@ -8,6 +8,7 @@ tensorflow/core/protobuf/queue_runner.proto
tensorflow/core/protobuf/named_tensor.proto
tensorflow/core/protobuf/meta_graph.proto
tensorflow/core/protobuf/config.proto
tensorflow/core/protobuf/tensor_bundle.proto
tensorflow/core/lib/core/error_codes.proto
tensorflow/core/framework/versions.proto
tensorflow/core/framework/variable.proto
@ -118,6 +118,7 @@ time.
@@streaming_mean_cosine_distance
@@streaming_percentage_less
@@streaming_sensitivity_at_specificity
@@streaming_sparse_average_precision_at_k
@@streaming_sparse_precision_at_k
@@streaming_sparse_recall_at_k
@@streaming_specificity_at_sensitivity

@ -167,6 +168,7 @@ from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_recall_at
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_recall_at_thresholds
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_root_mean_squared_error
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_sensitivity_at_specificity
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_sparse_average_precision_at_k
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_sparse_precision_at_k
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_sparse_recall_at_k
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_specificity_at_sensitivity
File diff suppressed because it is too large
File diff suppressed because it is too large
@ -46,7 +46,7 @@ class QuantizedMatMulOpForHexagonTest : public OpsTestBase {
              << ", hexagon binary version = "
              << hexagon_gemm_wrapper_GetHexagonBinaryVersion() << ")";
    LOG(INFO) << "Cpu frequency = "
              << profile_utils::CpuUtils::GetCpuFrequency();
              << profile_utils::CpuUtils::GetCycleCounterFrequency();
#else
    LOG(WARNING) << "Hexagon libs are not linked.";
#endif
@ -663,7 +663,7 @@ def train(train_op,
      raise ValueError('Cannot provide trace_every_n_steps because '
                       'logdir=None')

  if sync_optimizer and startup_delay_steps > 0:
  if sync_optimizer is not None and startup_delay_steps > 0:
    raise ValueError(
        'startup_delay_steps must be zero when sync_optimizer is supplied.')

@ -697,7 +697,7 @@ def train(train_op,

  cleanup_op = None

  if is_chief and sync_optimizer:
  if is_chief and sync_optimizer is not None:
    if not isinstance(sync_optimizer,
                      sync_replicas_optimizer.SyncReplicasOptimizer):
      raise ValueError(

@ -761,7 +761,7 @@ def train(train_op,
                             number_of_steps or sys.maxint))
    sv.start_queue_runners(sess)
    logging.info('Starting Queues.')
    if is_chief and sync_optimizer:
    if is_chief and sync_optimizer is not None:
      sv.start_queue_runners(sess, [chief_queue_runner])
    try:
      while not sv.should_stop():
68
tensorflow/contrib/tensorboard/BUILD
Normal file
@ -0,0 +1,68 @@
# Description:
#   TensorBoard module containing volatile or experimental code.

package(default_visibility = ["//tensorflow:internal"])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

# For platform specific build config
load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")

tf_proto_library(
    name = "protos_all",
    srcs = glob(["**/*.proto"]),
    go_api_version = 2,
    visibility = ["//visibility:public"],
)

# API methods in `tf.contrib.tensorboard` package.
py_library(
    name = "tensorboard",
    srcs = ["__init__.py"],
    srcs_version = "PY2AND3",
    deps = [":plugins"],
)

# API methods in `tf.contrib.tensorboard.plugins` package.
py_library(
    name = "plugins",
    srcs = ["plugins/__init__.py"],
    srcs_version = "PY2AND3",
    deps = [":projector"],
)

# API methods and protos in `tf.contrib.tensorboard.plugins.projector` package.
py_library(
    name = "projector",
    srcs = ["plugins/projector/__init__.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":protos_all_py",
        "//tensorflow/python:lib",
    ],
)

py_test(
    name = "projector_api_test",
    size = "small",
    srcs = ["plugins/projector/projector_api_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":projector",
        "//tensorflow:tensorflow_py",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
22
tensorflow/contrib/tensorboard/__init__.py
Normal file
@ -0,0 +1,22 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""tensorboard module containing volatile or experimental code."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Add projects here, they will show up under tf.contrib.tensorboard.
from tensorflow.contrib.tensorboard import plugins
22
tensorflow/contrib/tensorboard/plugins/__init__.py
Normal file
@ -0,0 +1,22 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""tensorboard plugins module containing volatile or experimental code."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Add projects here, they will show up under tf.contrib.tensorboard.plugins
from tensorflow.contrib.tensorboard.plugins import projector
54
tensorflow/contrib/tensorboard/plugins/projector/__init__.py
Normal file
@ -0,0 +1,54 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Public API for the Embedding Projector."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from google.protobuf import text_format
from tensorflow.contrib.tensorboard.plugins.projector.projector_config_pb2 import EmbeddingInfo
from tensorflow.contrib.tensorboard.plugins.projector.projector_config_pb2 import ProjectorConfig
from tensorflow.python.lib.io import file_io

PROJECTOR_FILENAME = 'projector_config.pbtxt'


def visualize_embeddings(summary_writer, config):
  """Stores a config file used by the embedding projector.

  Args:
    summary_writer: The summary writer used for writing events.
    config: `tf.contrib.tensorboard.plugins.projector.ProjectorConfig`
      proto that holds the configuration for the projector, such as paths to
      checkpoint files and metadata files for the embeddings. If
      `config.model_checkpoint_path` is None, it defaults to the
      `logdir` used by the summary_writer.

  Raises:
    ValueError: If the summary writer does not have a `logdir`.
  """
  logdir = summary_writer.get_logdir()

  # Sanity checks.
  if logdir is None:
    raise ValueError('Summary writer must have a logdir')

  # Saving the config file in the logdir.
  config_pbtxt = text_format.MessageToString(config)
  file_io.write_string_to_file(
      os.path.join(logdir, PROJECTOR_FILENAME), config_pbtxt)
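For reference, a minimal usage sketch of the API added above (the log directory, variable name, and metadata path are illustrative, not part of this commit):

```python
import tensorflow as tf
from tensorflow.contrib.tensorboard.plugins import projector

# Writes <logdir>/projector_config.pbtxt for TensorBoard's embedding projector.
summary_writer = tf.train.SummaryWriter('/tmp/logdir')
config = projector.ProjectorConfig()
embedding = config.embedding.add()
embedding.tensor_name = 'word_embedding'              # hypothetical embedding variable
embedding.metadata_path = '/tmp/logdir/metadata.tsv'  # hypothetical metadata file
projector.visualize_embeddings(summary_writer, config)
```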
@ -0,0 +1,49 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""API tests for the projector plugin in TensorBoard."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil
import tensorflow as tf

from google.protobuf import text_format


class ProjectorApiTest(tf.test.TestCase):

  def testVisualizeEmbeddings(self):
    # Create a dummy configuration.
    config = tf.contrib.tensorboard.plugins.projector.ProjectorConfig()
    config.model_checkpoint_path = 'test'
    emb1 = config.embedding.add()
    emb1.tensor_name = 'tensor1'
    emb1.metadata_path = 'metadata1'

    # Call the API method to save the configuration to a temporary dir.
    temp_dir = self.get_temp_dir()
    self.addCleanup(shutil.rmtree, temp_dir)
    writer = tf.train.SummaryWriter(temp_dir)
    tf.contrib.tensorboard.plugins.projector.visualize_embeddings(writer,
                                                                  config)

    # Read the configuration from disk and make sure it matches the original.
    with tf.gfile.GFile(os.path.join(temp_dir, 'projector_config.pbtxt')) as f:
      config2 = tf.contrib.tensorboard.plugins.projector.ProjectorConfig()
      text_format.Parse(f.read(), config2)
      self.assertEqual(config, config2)
@ -13,19 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/platform/file_system.h"
syntax = "proto3";

#include "tensorflow/core/platform/test.h"
package tensorflow;

namespace tensorflow {
namespace {

TEST(FileSystemTest, GetNameFromURI) {
  EXPECT_EQ("foo", GetNameFromURI("file://foo"));
  EXPECT_EQ("file:/", GetNameFromURI("file:/"));
  EXPECT_EQ("file:", GetNameFromURI("file:"));
  EXPECT_EQ("bar", GetNameFromURI("bar"));
message EmbeddingInfo {
  string tensor_name = 1;
  string metadata_path = 2;
}

}  // namespace
}  // namespace tensorflow
message ProjectorConfig {
  string model_checkpoint_path = 1;
  repeated EmbeddingInfo embedding = 2;
}
453
tensorflow/contrib/tfprof/README.md
Normal file
@ -0,0 +1,453 @@
# tfprof: A Profiling Tool for TensorFlow Models

go/tfprof

Author: Xin Pan (xpan@google.com, github: panyx0718)

Consultants: Jon Shlens (shlens@google.com), Pete Warden (petewarden@google.com)

[TOC]

## Introduction

tfprof is a profiling tool for TensorFlow that analyzes model architectures
and measures system performance.

### Major Features

1.  Measure model parameters, float operations, and tensor shapes.
2.  Measure op execution times, requested memory size, and device placement.
3.  Inspect checkpoint tensors' shapes and their values.
4.  Explore the model based on name scope or graph structure.
5.  Selectively group, filter, account, and order ops.

### Interfaces

[CLI Tutorials](#cli-tutorials):
The CLI supports an interactive mode for exploration and a single-shot mode for
scripts. Outputs can be dumped to files or printed to the terminal.

Python API Tutorials: the Python API is not released yet.

## CLI Tutorials

The tutorials are based on a 32-layer ResNet.
TODO(xpan): Provide graph.pbtxt, model.ckpt, tfprof_log and run_meta download.

### Examples

1) Start the `tfprof` command line tool

```shell
# Build the tool.
bazel build -c opt tensorflow/contrib/tfprof/...

# Help information, including detailed 'option' instructions.
bazel-bin/tensorflow/contrib/tfprof/tools/tfprof/tfprof help
#
# The following commands will start tfprof interactive mode.
#
# Profile model shapes and parameters only.
bazel-bin/tensorflow/contrib/tfprof/tools/tfprof/tfprof \
    --graph_path=/graph.pbtxt
#
# Additionally profile checkpoint statistics and values.
# Use '-account_type_regexes _checkpoint_variables' to select
# checkpoint tensors.
bazel-bin/tensorflow/contrib/tfprof/tools/tfprof/tfprof \
    --graph_path=graph.pbtxt \
    --checkpoint_path=model.ckpt
#
# Additionally profile ops' requested memory and timing.
# See the CLI Input Files section on generating the run_meta file.
bazel-bin/tensorflow/contrib/tfprof/tools/tfprof/tfprof \
    --graph_path=graph.pbtxt \
    --run_meta_path=run_meta \
    --checkpoint_path=model.ckpt
#
# tfprof_log is used to define customized op types and float ops.
# Use tfprof_logger.write_op_log() to create tfprof_log.
# See 11) in the Examples section on generating the tfprof_log file.
bazel-bin/tensorflow/contrib/tfprof/tools/tfprof/tfprof \
    --graph_path=graph.pbtxt \
    --run_meta_path=run_meta \
    --op_log_path=tfprof_log \
    --checkpoint_path=model.ckpt
```
Note that `graph.pbtxt` is an ASCII text format.

2) Press enter to show the default options

```shell
tfprof>
tfprof>
-max_depth                  4
-min_bytes                  0
-min_micros                 0
-min_params                 0
-min_float_ops              0
-device_regexes             .*
-order_by                   name
-account_type_regexes       Variable
-start_name_regexes         .*
-trim_name_regexes
-show_name_regexes          .*
-hide_name_regexes          IsVariableInitialized_[0-9]+,save\/.*,^zeros[0-9_]*
-account_displayed_op_only  false
# supported select fields. Availability depends on --[run_meta|checkpoint|op_log]_path.
# [bytes|micros|params|float_ops|num_hidden_ops|tensor_value|device|op_types]
-select                     params
-viz                        false
-dump_to_file
```

3) I want to see the `BatchNorm`'s gamma value in the checkpoint.

```shell
# Requires --graph_path, --checkpoint_path.
tfprof> scope -show_name_regexes unit_1_0.*gamma -select tensor_value -max_depth 5
_TFProfRoot ()
unit_1_0/shared_activation/init_bn/gamma ()
[1.80 2.10 2.06 1.91 2.26 1.86 1.81 1.37 1.78 1.85 1.96 1.54 2.04 2.34 2.22 1.99 ],
unit_1_0/sub2/bn2/gamma ()
[1.57 1.83 1.30 1.25 1.59 1.14 1.26 0.82 1.19 1.10 1.48 1.01 0.82 1.23 1.21 1.14 ],
```

4) I want to see my checkpoint tensors' shapes and numbers of parameters.

```shell
# Requires --graph_path, --checkpoint_path.
# Increase -max_depth to see all tensors.
tfprof> scope -account_type_regexes _checkpoint_variables -select params -max_depth 4
_TFProfRoot (--/930.58k params)
global_step (0/0 params)
init/init_conv/DW (3x3x3x16, 432/864 params)
pool_logit/DW (64x10, 640/1.28k params)
pool_logit/DW/Momentum (64x10, 640/640 params)
pool_logit/biases (10, 10/20 params)
pool_logit/biases/Momentum (10, 10/10 params)
unit_last/final_bn/beta (64, 64/128 params)
unit_last/final_bn/gamma (64, 64/128 params)
unit_last/final_bn/moving_mean (64, 64/64 params)
unit_last/final_bn/moving_variance (64, 64/64 params)
```
5) I defined an op named 'cost' to calculate the loss. I want to know which ops
it depends on take a long time to run. Hint: use the 'graph' command to explore
graph dependencies.

```shell
# Requires --graph_path, --run_meta_path.
tfprof> graph -start_name_regexes cost.* -max_depth 100 -min_micros 10000 -select micros -account_type_regexes .*
_TFProfRoot (0us/3.61sec)
init/init_conv/Conv2D (11.75ms/3.10sec)
random_shuffle_queue_DequeueMany (3.09sec/3.09sec)
unit_1_0/sub2/conv2/Conv2D (74.14ms/3.19sec)
unit_1_3/sub2/conv2/Conv2D (60.75ms/3.34sec)
unit_2_4/sub2/conv2/Conv2D (73.58ms/3.54sec)
unit_3_3/sub2/conv2/Conv2D (10.26ms/3.60sec)
```

6) I want to know the expensive operations during back propagation.
Hint: TensorFlow prepends 'gradient' to your defined name scopes. Use the
'scope' command to explore based on name scope hierarchies.

```shell
# Requires --graph_path, --run_meta_path.
tfprof> scope -start_name_regexes gradient.* -max_depth 100 -min_micros 20000 -select micros -account_type_regexes .*
_TFProfRoot (0us/2.29sec)
gradients/unit_1_0/sub1/conv1/Conv2D_grad/Conv2DBackpropFilter (54.96ms/54.96ms)
gradients/unit_1_0/sub2/conv2/Conv2D_grad/Conv2DBackpropFilter (83.63ms/83.63ms)
gradients/unit_1_1/sub1/conv1/Conv2D_grad/Conv2DBackpropFilter (99.25ms/99.25ms)
gradients/unit_1_2/sub1/conv1/Conv2D_grad/Conv2DBackpropFilter (95.40ms/95.40ms)
gradients/unit_1_2/sub2/conv2/Conv2D_grad/Conv2DBackpropFilter (99.83ms/99.83ms)
gradients/unit_1_3/sub1/conv1/Conv2D_grad/Conv2DBackpropFilter (95.39ms/95.39ms)
...
```

7) Show the number of float operations in the model.
Note: the float operation count depends on
1) op.RegisterStatistics: if an op doesn't have RegisterStatistics defined,
its float operations cannot be counted; and
2) a fully defined shape, which is also necessary to calculate flops.
The float operation count is provided by the tensorflow::tfprof::OpLog logged
from the Python API.

```shell
# Requires --graph_path, --op_log_path.
tfprof> scope -min_float_ops 1 -max_depth 10 -select float_ops -account_type_regexes .*
_TFProfRoot (0/17.63b flops)
gradients/pool_logit/xw_plus_b/MatMul_grad/MatMul (163.84k/163.84k flops)
gradients/pool_logit/xw_plus_b/MatMul_grad/MatMul_1 (163.84k/163.84k flops)
init/init_conv/Conv2D (113.25m/113.25m flops)
pool_logit/xw_plus_b (1.28k/165.12k flops)
pool_logit/xw_plus_b/MatMul (163.84k/163.84k flops)
unit_1_0/sub1/conv1/Conv2D (603.98m/603.98m flops)
unit_1_0/sub2/conv2/Conv2D (603.98m/603.98m flops)
unit_1_1/sub1/conv1/Conv2D (603.98m/603.98m flops)
unit_1_1/sub2/conv2/Conv2D (603.98m/603.98m flops)
...
```
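As a rough illustration of point 1) above, a flops statistics function is registered per op type. The sketch below assumes a hypothetical op type `MyCustomOp` and uses the internal `ops.RegisterStatistics` decorator from this codebase; it is a sketch, not a definitive recipe:

```python
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import ops


@ops.RegisterStatistics('MyCustomOp', 'flops')  # 'MyCustomOp' is hypothetical
def _calc_my_custom_op_flops(graph, node):
  """Returns the 'flops' statistic that tfprof aggregates for this op type."""
  in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
  # Counting fails unless the input shape is fully defined.
  in_shape.assert_is_fully_defined()
  # Assume one multiply and one add per input element.
  return ops.OpStats('flops', in_shape.num_elements() * 2)
```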
8) Show the number of parameters of all `tf.trainable_variables()` in the model.

```shell
# Requires --graph_path --op_log_path.
# Store the option for future commands.
tfprof> set -account_type_regexes _trainable_variables
tfprof> scope -max_depth 4 -select params
_TFProfRoot (--/464.15k params)
init/init_conv/DW (3x3x3x16, 432/432 params)
pool_logit/DW (64x10, 640/640 params)
pool_logit/biases (10, 10/10 params)
unit_last/final_bn/beta (64, 64/64 params)
unit_last/final_bn/gamma (64, 64/64 params)
```

Where does "_trainable_variables" come from? It is from the OpLog file
generated by the write_op_log() Python API. write_op_log() helps users create
some common op types implicitly. Users can define their own op types and log
them through the write_op_log() API.

9) What if I'm lazy and don't want to define op types? I have given my ops
well-defined names in my model's code and want to use those names to select a
group of ops. Let's try it!

```shell
tfprof> set -account_type_regexes .*
tfprof> scope -show_name_regexes unit_2_1.*DW -max_depth 100 -account_displayed_op_only
_TFProfRoot (0/18.43k params)
unit_2_1/sub1/conv1/DW (3x3x32x32, 9.22k/9.22k params)
unit_2_1/sub2/conv2/DW (3x3x32x32, 9.22k/9.22k params)
```

The above command allows you to filter ops that match specific names.
`-account_displayed_op_only` asks tfprof to only account ops displayed
in the terminal. Otherwise, tfprof accounts all ops matched by
`-account_type_regexes` recursively, even if they are hidden due to some
options such as -max_depth.

10) TensorFlow has built-in op types. For example, the built-in op type
`Variable` seems to include the `Variable`s created by your model. However, be
careful when depending on it, because TensorFlow creates extra `Variable` ops
implicitly, and the implicitly created ops can have the same prefix as the
`Variable`s you defined.

In the following example, extra `Variables` are created and "/Momentum" is
appended to their names. This might cause your "model capacity" calculation
to be wrong.

```shell
tfprof> scope -account_type_regexes Variable -max_depth 4 -select params
_TFProfRoot (--/930.58k params)
global_step (1/1 params)
init/init_conv/DW (3x3x3x16, 432/864 params)
pool_logit/DW (64x10, 640/1.28k params)
pool_logit/DW/Momentum (64x10, 640/640 params)
pool_logit/biases (10, 10/20 params)
pool_logit/biases/Momentum (10, 10/10 params)
unit_last/final_bn/beta (64, 64/128 params)
unit_last/final_bn/gamma (64, 64/128 params)
unit_last/final_bn/moving_mean (64, 64/64 params)
unit_last/final_bn/moving_variance (64, 64/64 params)
```


11) An example of defining an extra op type for ops using `OpLog`

First, in Python code, create an `OpLog` proto and add op type
information to it:

```python
op_log = tfprof_log_pb2.OpLog()
entry = op_log.log_entries.add()
entry.name = 'pool_logit/DW'
entry.types.append('pool_logit')
entry = op_log.log_entries.add()
entry.name = 'pool_logit/biases'
# Alternatively:
#   var = tf.get_variable(xxx)
#   entry.name = var.op.name
entry.types.append('pool_logit')
```

Second, call write_op_log to write the OpLog proto.

```python
tfprof_logger.write_op_log(sess.graph, '/tmp/my_op_log_dir', op_log)
```

Third, when starting the tfprof tool, specify
"--op_log_path /tmp/my_op_log_dir/op_log"

```shell
tfprof> scope -account_type_regexes pool_logit -max_depth 4 -select params
_TFProfRoot (--/650 params)
pool_logit/DW (64x10, 640/640 params)
pool_logit/biases (10, 10/10 params)
```

Note that when you call
`tfprof_logger.write_op_log(...)`, the tool adds all `Variables` inside
`tf.trainable_variables()` to `_trainable_variables`.

12) Run tfprof in one-shot mode and dump the result to a file.

```shell
# Printed to stdout if --dump_to_file is not set.
tfprof scope --graph_path /cns/ij-d/home/xpan/tfprof/graph.pbtxt \
    --max_depth 3 \
    --dump_to_file "/tmp/dump"
Reading Files...
Parsing GraphDef...
Preparing Views...

cat /tmp/dump
_TFProfRoot (--/930.58k params)
global_step (0/0 params)
pool_logit/DW (64x10, 640/1.28k params)
pool_logit/biases (10, 10/20 params)
```

13) Analyze how balanced the Variables are across parameter servers.

In this tutorial, I'm going to use a seq2seq model, which is split
across several GPUs at the workers and several parameter servers.

In tfprof, 'device' is an op_type. For example, if op1 and op2 are placed on
gpu0, they share an op_type called 'gpu0'.

```shell
bazel-bin/tensorflow/contrib/tfprof/tools/tfprof/tfprof \
    --graph_path ~/tfprof/textsum/graph.pbtxt \
    --run_meta_path ~/tfprof/textsum/run_meta

# Looks like ps task 1 is holding twice as many parameters as task 0.
tfprof> scope -select device,params -account_type_regexes .*ps.*task:0.* -max_depth 1
_TFProfRoot (--/25.81m params)
tfprof> scope -select device,params -account_type_regexes .*ps.*task:1.* -max_depth 1
_TFProfRoot (--/58.84m params)
```
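To make the device op_type concrete, here is a small sketch (the job and task names are illustrative) of how placement in model code produces the device strings that `-account_type_regexes .*ps.*task:0.*` matches:

```python
import tensorflow as tf

with tf.device('/job:ps/task:0'):
  # This variable's ops carry the device string '/job:ps/task:0',
  # which tfprof exposes as an op_type for accounting.
  weights = tf.Variable(tf.zeros([1024, 1024]), name='weights')
```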
### CLI Input Files

The tfprof command line interface (CLI) loads dumped files from a TensorFlow
model and converts them into in-memory data structures. To use it, users need
to specify the locations of the dumped files. The following are the dumped
files loaded by tfprof:

<b>--graph_path:</b> GraphDef text file (required). Used to build the in-memory
representation of the model. For example, graph.pbtxt written by tf.Supervisor
is a candidate. If you are not using tf.Supervisor, you can easily get a
GraphDef using tf.Graph.as_graph_def() or other APIs.
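For example, a minimal sketch (the paths are illustrative) that dumps the current graph in the ASCII format tfprof expects:

```python
import tensorflow as tf

with tf.Session() as sess:
  # ... build the model ...
  # as_text defaults to True, producing an ASCII graph.pbtxt.
  tf.train.write_graph(sess.graph_def, '/tmp/tfprof', 'graph.pbtxt')
```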
<b>--run_meta_path:</b> tensorflow::RunMetadata.
Used to get the memory and time consumption of
each op of the model. Users need to enable it. For example, the following code
snippet writes a RunMetadata file:

```python
run_options = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
run_metadata = config_pb2.RunMetadata()
# Every once in a while, run with these options to collect the RunMetadata.
_ = self._sess.run(..., options=run_options, run_metadata=run_metadata)
with gfile.Open(os.path.join(output_dir, "run_meta"), "w") as f:
  f.write(run_metadata.SerializeToString())
```

<b>--op_log_path:</b>
tensorflow::tfprof::OpLog. A proto used to provide extra op information
for ops. By giving a group of ops a type name, users can easily aggregate the
statistics for those ops without accidentally missing or including extra ops.
tfprof exposes the following Python API to add op information and logging.

```python
def write_op_log(graph, log_dir, op_log=None)
```
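A minimal usage sketch (the log directory is illustrative); the call merges the default `_trainable_variables` and `flops` entries with any caller-provided `op_log` before writing the `tfprof_log` file:

```python
import tensorflow as tf
from tensorflow.contrib.tfprof.python.tools.tfprof import tfprof_logger

# Writes <log_dir>/tfprof_log, later passed to tfprof via --op_log_path.
tfprof_logger.write_op_log(tf.get_default_graph(), '/tmp/my_op_log_dir')
```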
<b>--checkpoint_path:</b>
TensorFlow checkpoint. It defines the _checkpoint_variables op type. It also
provides checkpointed tensors' values.


## Design


### In-memory representation

<b>Scope:</b> This representation organizes ops based on the name scope
hierarchy, similar to a filesystem hierarchy. Hence, it is essentially a tree
data structure. For example, op1 with the name "name1/name2" is a child of op2
with the name "name1".
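A small sketch of how op names encode this tree (the names are illustrative):

```python
import tensorflow as tf

with tf.name_scope('name1'):
  child = tf.constant(1.0, name='name2')

# In the scope view this op is a child of the 'name1' node,
# much like a file inside a directory.
print(child.op.name)  # -> name1/name2
```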
<b>Graph:</b> This representation organizes ops based on op inputs. Hence it is
a graph structure. The graph is a "directed acyclic graph" (hopefully), with
the direction going from "output to input". The direction is designed this way
so that users can trace from a "result" to its "sources".

### Command line options

tfprof's major goals are to measure system performance and quickly analyze
model architectures. Hence, its commands and options should allow users to
achieve these two goals easily.

<b>graph:</b> It is expected that users will mostly use the graph
representation to debug system performance. Hence, tfprof supports the graph
command, which pulls up the graph in-memory representation described above.

<b>scope:</b> It is expected that some users might want to explore their model
statistics using the name scope information they defined in their Python code.
Hence, tfprof supports the "scope" command, which pulls up the tree in-memory
representation.

<b>set:</b> It is used to store options so that the user doesn't need to
re-type the same options again and again in follow-up commands. Note that
tfprof has traditional terminal history and auto-complete support.

<b>help:</b> Prints help information.

<b>Options:</b> Run "tfprof help" to get detailed explanations.

```python
"-max_depth",
"-min_bytes",
"-min_micros",
"-min_params",
"-min_float_ops",
"-order_by",
"-account_type_regexes",
"-start_name_regexes",
"-trim_name_regexes",
"-show_name_regexes",
"-hide_name_regexes",
"-account_displayed_op_only",
"-select",
"-viz",  # Only supported for the graph command.
"-dump_to_file",
```

A key design point is that stats are aggregated from descendants up to
ancestors. `-account_type_regexes` is used to decide which ops' stats are
accounted; it makes the decision based on op type. Usually set it to `.*` if no
extra type information is added to the ops using OpLog. Intuitively, only
accounted ops are displayed. The `-min/max` and `-show/hide/trim/start` options
are only used to optionally display or hide ops based on their names and stats.
However, they don't prevent tfprof from accounting the stats of hidden ops.
Hence, the stats of an op can be aggregated into its parent even if the op is
hidden. `-account_displayed_op_only` is an option to break this rule: when it
is set, only displayed ops are accounted.

Regexes are all comma-separated, for example `-show_name_regexes
regex1.*,regex2.*`. It is designed this way because it is convenient and a
comma is not expected to show up in op names.

`-order_by` is used to order the displayed ops. Displayed ops at the same
hierarchy level (notice the printed indent) are sorted according to order_by.

## Future Work

* Load SummaryWriter event logs so that it can show the latest summary value.

* Better sorting and aggregation of outputs. Easier comprehension.

* Currently, shape information is based on `graph.pbtxt`. When the shape
information is incomplete, tfprof ignores it. See if it can use `RunMetadata`
and `Checkpoint` to complete the shape information.
31
tensorflow/contrib/tfprof/python/tools/tfprof/BUILD
Normal file
@ -0,0 +1,31 @@
package(
    default_visibility = ["//visibility:public"],
)

licenses(["notice"])  # Apache 2.0

py_library(
    name = "tfprof_logger",
    srcs = ["tfprof_logger.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow:tensorflow_py",
        "//tensorflow/contrib/tfprof/tools/tfprof:protos_all_py",
        "//tensorflow/python:framework_for_generated_wrappers",
    ],
)

# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
114
tensorflow/contrib/tfprof/python/tools/tfprof/tfprof_logger.py
Normal file
@ -0,0 +1,114 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logging tensorflow::tfprof::OpLog.

OpLog is used to add extra model information for offline analysis by tfprof.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import tensorflow as tf
from tensorflow.contrib.tfprof.python.tools.tfprof import tfprof_log_pb2
from tensorflow.python.framework import ops

TRAINABLE_VARIABLES = '_trainable_variables'
REGISTERED_FLOP_STATS = 'flops'


def _get_logged_ops(graph):
  """Extract trainable model parameters and FLOPs for ops from a Graph.

  Args:
    graph: tf.Graph.
  Returns:
    logged_ops: dict mapping from op_name to OpLogEntry.
  """
  logged_ops = {}

  graph_def = graph.as_graph_def()
  for node in graph_def.node:
    try:
      stats = ops.get_stats_for_node_def(graph, node, REGISTERED_FLOP_STATS)
    except ValueError:
      # Catch the ValueError raised when the shape is incomplete, and skip it.
      stats = None

    if not stats or not stats.value:
      continue
    if node.name not in logged_ops:
      entry = tfprof_log_pb2.OpLogEntry()
      entry.name = node.name
      entry.float_ops = stats.value
      logged_ops[entry.name] = entry

  for v in graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
    if v.op.name not in logged_ops:
      entry = tfprof_log_pb2.OpLogEntry()
      entry.name = v.op.name
      entry.types.append(TRAINABLE_VARIABLES)
      logged_ops[entry.name] = entry
    else:
      logged_ops[v.op.name].types.append(TRAINABLE_VARIABLES)
  return logged_ops


def _merge_default_with_oplog(graph, op_log=None):
  """Merge the tfprof default extra info with the caller's op_log.

  Args:
    graph: tf.Graph.
    op_log: OpLog proto.
  Returns:
    tmp_op_log: Merged OpLog proto.
  """
  tmp_op_log = tfprof_log_pb2.OpLog()
  logged_ops = _get_logged_ops(graph)
  if not op_log:
    tmp_op_log.log_entries.extend(logged_ops.values())
  else:
    all_ops = dict()
    for entry in op_log.log_entries:
      all_ops[entry.name] = entry
    for op_name, entry in logged_ops.items():  # items() works on both PY2 and PY3
      if op_name in all_ops:
        all_ops[op_name].types.extend(entry.types)
        if entry.float_ops > 0 and all_ops[op_name].float_ops == 0:
          all_ops[op_name].float_ops = entry.float_ops
      else:
        all_ops[op_name] = entry
    tmp_op_log.log_entries.extend(all_ops.values())
  return tmp_op_log


def write_op_log(graph, log_dir, op_log=None):
  """Log the provided 'op_log', and add additional model information below.

  The API also assigns ops in tf.trainable_variables() an op type called
  '_trainable_variables'.
  The API also logs 'flops' statistics for ops with op.RegisterStatistics()
  defined.

  Args:
    graph: tf.Graph.
    log_dir: directory to write the log file to.
    op_log: OpLog proto.
  """
  op_log = _merge_default_with_oplog(graph, op_log)

  with tf.gfile.Open(os.path.join(log_dir, 'tfprof_log'), 'w') as log:
    log.write(op_log.SerializeToString())
52
tensorflow/contrib/tfprof/tools/tfprof/BUILD
Normal file
@ -0,0 +1,52 @@
package(
    default_visibility = ["//visibility:public"],
)

licenses(["notice"])  # Apache 2.0

# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)

cc_binary(
    name = "tfprof",
    srcs = ["tfprof_main.cc"],
    deps = [
        ":protos_all_cc",
        "//tensorflow/c:c_api",
        "//tensorflow/c:checkpoint_reader",
        "//tensorflow/contrib/tfprof/tools/tfprof/internal:tfprof_options",
        "//tensorflow/contrib/tfprof/tools/tfprof/internal:tfprof_stats",
        "//tensorflow/contrib/tfprof/tools/tfprof/internal:tfprof_utils",
        "//tensorflow/core:framework_headers_lib",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "@linenoise//:linenoise",
    ],
)

load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")

tf_proto_library(
    name = "protos_all",
    srcs = glob(
        ["**/*.proto"],
    ),
    cc_api_version = 2,
    cc_libs = ["//tensorflow/core:protos_all_cc"],
    go_api_version = 2,
    java_api_version = 2,
    visibility = ["//visibility:public"],
)
227
tensorflow/contrib/tfprof/tools/tfprof/internal/BUILD
Normal file
@ -0,0 +1,227 @@
package(
    default_visibility = ["//tensorflow:__subpackages__"],
)

licenses(["notice"])  # Apache 2.0

load("//tensorflow:tensorflow.bzl", "tf_cc_test")

cc_library(
    name = "tfprof_stats",
    srcs = ["tfprof_stats.cc"],
    hdrs = ["tfprof_stats.h"],
    deps = [
        ":tfprof_graph",
        ":tfprof_node",
        ":tfprof_options",
        ":tfprof_scope",
        ":tfprof_show",
        ":tfprof_utils",
        "//tensorflow/c:checkpoint_reader",
        "//tensorflow/contrib/tfprof/tools/tfprof:protos_all_cc",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
    ],
)

cc_library(
    name = "tfprof_node",
    srcs = ["tfprof_node.cc"],
    hdrs = ["tfprof_node.h"],
    deps = [
        ":tfprof_options",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
    ],
)

cc_library(
    name = "tfprof_scope",
    srcs = ["tfprof_scope.cc"],
    hdrs = ["tfprof_scope.h"],
    deps = [
        ":tfprof_constants",
        ":tfprof_node",
        ":tfprof_options",
        ":tfprof_show",
        ":tfprof_tensor",
        ":tfprof_utils",
        "//tensorflow/c:c_api",
        "//tensorflow/c:checkpoint_reader",
        "//tensorflow/contrib/tfprof/tools/tfprof:protos_all_cc",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
    ],
)

cc_library(
    name = "tfprof_graph",
    srcs = ["tfprof_graph.cc"],
    hdrs = ["tfprof_graph.h"],
    deps = [
        ":tfprof_constants",
        ":tfprof_node",
        ":tfprof_options",
        ":tfprof_show",
        ":tfprof_tensor",
        ":tfprof_utils",
        "//tensorflow/c:checkpoint_reader",
        "//tensorflow/contrib/tfprof/tools/tfprof:protos_all_cc",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
    ],
)

cc_library(
    name = "tfprof_show",
    srcs = ["tfprof_show.cc"],
    hdrs = ["tfprof_show.h"],
    deps = [
        ":tfprof_constants",
        ":tfprof_node",
        ":tfprof_options",
        ":tfprof_tensor",
        ":tfprof_utils",
        "//tensorflow/c:checkpoint_reader",
        "//tensorflow/contrib/tfprof/tools/tfprof:protos_all_cc",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
    ],
)

tf_cc_test(
    name = "tfprof_show_test",
    srcs = ["tfprof_show_test.cc"],
    data = [
        "testdata/ckpt",
        "testdata/graph.pbtxt",
        "testdata/run_meta",
        "testdata/tfprof_log",
    ],
    deps = [
        ":tfprof_constants",
        ":tfprof_options",
        ":tfprof_stats",
        ":tfprof_utils",
        "//tensorflow/c:checkpoint_reader",
        "//tensorflow/contrib/tfprof/tools/tfprof:protos_all_cc",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)

cc_library(
    name = "tfprof_utils",
    srcs = ["tfprof_utils.cc"],
    hdrs = ["tfprof_utils.h"],
    deps = [
        ":tfprof_options",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
    ],
)

cc_library(
    name = "tfprof_options",
    srcs = ["tfprof_options.cc"],
    hdrs = ["tfprof_options.h"],
    deps = [
        "//tensorflow/core:framework_headers_lib",
        "//tensorflow/core:lib",
    ],
)

cc_library(
    name = "print_model_analysis",
    srcs = ["print_model_analysis.cc"],
    hdrs = ["print_model_analysis.h"],
    deps = [
        ":tfprof_options",
        ":tfprof_stats",
        "//tensorflow/c:checkpoint_reader",
        "//tensorflow/contrib/tfprof/tools/tfprof:protos_all_cc",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
    ],
)

tf_cc_test(
    name = "tfprof_stats_test",
    srcs = ["tfprof_stats_test.cc"],
    data = [
        "testdata/ckpt",
        "testdata/graph.pbtxt",
        "testdata/run_meta",
        "testdata/tfprof_log",
    ],
    deps = [
        ":tfprof_constants",
        ":tfprof_options",
        ":tfprof_stats",
        ":tfprof_utils",
        "//tensorflow/c:checkpoint_reader",
        "//tensorflow/contrib/tfprof/tools/tfprof:protos_all_cc",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)

cc_library(
    name = "tfprof_tensor",
    srcs = ["tfprof_tensor.cc"],
    hdrs = ["tfprof_tensor.h"],
    deps = [
        "//tensorflow/contrib/tfprof/tools/tfprof:protos_all_cc",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
    ],
)

tf_cc_test(
    name = "tfprof_tensor_test",
    srcs = ["tfprof_tensor_test.cc"],
    data = [
        "testdata/ckpt",
        "testdata/graph.pbtxt",
    ],
    deps = [
        ":tfprof_options",
        ":tfprof_stats",
        ":tfprof_utils",
        "//tensorflow/c:checkpoint_reader",
        "//tensorflow/contrib/tfprof/tools/tfprof:protos_all_cc",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)

cc_library(
    name = "tfprof_constants",
    hdrs = ["tfprof_constants.h"],
    deps = [
    ],
)

# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
65 tensorflow/contrib/tfprof/tools/tfprof/internal/print_model_analysis.cc Normal file
@@ -0,0 +1,65 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/tfprof/tools/tfprof/internal/print_model_analysis.h"

#include <stdio.h>
#include <memory>
#include <utility>

#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_stats.h"

namespace tensorflow {
namespace tfprof {
string PrintModelAnalysis(const string* graph, const string* run_meta,
                          const string* op_log, const string* command,
                          const Options* options) {
  CHECK(graph) << "graph mustn't be null";
  CHECK(command) << "command mustn't be null";
  CHECK(options) << "options mustn't be null";
  std::unique_ptr<GraphDef> graph_ptr(new GraphDef());
  graph_ptr->ParseFromString(*graph);

  std::unique_ptr<RunMetadata> run_meta_ptr;
  if (run_meta) {
    run_meta_ptr.reset(new RunMetadata());
    run_meta_ptr->ParseFromString(*run_meta);
  }

  std::unique_ptr<OpLog> op_log_ptr;
  if (op_log) {
    op_log_ptr.reset(new OpLog());
    op_log_ptr->ParseFromString(*op_log);
  }

  std::unique_ptr<checkpoint::CheckpointReader> ckpt_reader;

  TFStats tf_stats(std::move(graph_ptr), std::move(run_meta_ptr),
                   std::move(op_log_ptr), std::move(ckpt_reader));

  if (options->dump_to_file.empty()) {
    printf("\n=========================Options=============================\n");
    printf("%s", options->ToString().c_str());
    printf("\n==================Model Analysis Report======================\n");
    TFProfNode root(tf_stats.PrintGraph(*command, *options));
    printf("\n======================End of Report==========================\n");
    fflush(stdout);
    return root.SerializeAsString();
  }
  return tf_stats.PrintGraph(*command, *options).SerializeAsString();
}
}  // namespace tfprof
}  // namespace tensorflow
45 tensorflow/contrib/tfprof/tools/tfprof/internal/print_model_analysis.h Normal file
@@ -0,0 +1,45 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_PRINT_MODEL_ANALYSIS_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_PRINT_MODEL_ANALYSIS_H_

#include <string>

#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_options.h"
#include "tensorflow/contrib/tfprof/tools/tfprof/tfprof_log.pb.h"
#include "tensorflow/contrib/tfprof/tools/tfprof/tfprof_output.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/protobuf/config.pb.h"

namespace tensorflow {
namespace tfprof {

// ***This API is only for swig.***
//
// Interface defined for the Python API swig. Calls the tfprof core API.
// 'graph', 'run_meta', 'op_log' are serialized GraphDef, RunMetadata and
// OpLog strings, respectively.
// 'graph', 'command' and 'options' are required. The others can be nullptr
// if not available.
string PrintModelAnalysis(const string* graph, const string* run_meta,
                          const string* op_log, const string* command,
                          const Options* options);

}  // namespace tfprof
}  // namespace tensorflow

#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_PRINT_MODEL_ANALYSIS_H_
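
For orientation, here is a minimal sketch of how this swig-facing entry point could be driven from C++. It is not part of this commit: the wrapper function name is hypothetical and only the signature declared above is assumed; run_meta and op_log are passed as nullptr per the documented contract.

#include <string>

#include "tensorflow/contrib/tfprof/tools/tfprof/internal/print_model_analysis.h"

namespace tensorflow {
namespace tfprof {

// Hypothetical wrapper (illustrative only): run the "scope" view over a
// serialized GraphDef, with no RunMetadata or OpLog attached.
string RunScopeAnalysis(const string& graph_str, const Options& opts) {
  const string command = "scope";  // One of kCmds: scope, graph, set, help.
  // run_meta and op_log are optional and may be nullptr.
  return PrintModelAnalysis(&graph_str, /*run_meta=*/nullptr,
                            /*op_log=*/nullptr, &command, &opts);
}

}  // namespace tfprof
}  // namespace tensorflow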
BIN tensorflow/contrib/tfprof/tools/tfprof/internal/testdata/ckpt vendored Normal file
Binary file not shown.

636 tensorflow/contrib/tfprof/tools/tfprof/internal/testdata/graph.pbtxt vendored Normal file
@@ -0,0 +1,636 @@
node {
  name: "zeros"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
          dim {
            size: 2
          }
          dim {
            size: 6
          }
          dim {
            size: 6
          }
          dim {
            size: 3
          }
        }
        float_val: 0.0
      }
    }
  }
}
node {
  name: "DW"
  op: "Variable"
  attr {
    key: "container"
    value {
      s: ""
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: 3
        }
        dim {
          size: 3
        }
        dim {
          size: 3
        }
        dim {
          size: 6
        }
      }
    }
  }
  attr {
    key: "shared_name"
    value {
      s: ""
    }
  }
}
node {
  name: "DW/Initializer/random_normal/shape"
  op: "Const"
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@DW"
      }
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 4
          }
        }
        tensor_content: "\003\000\000\000\003\000\000\000\003\000\000\000\006\000\000\000"
      }
    }
  }
}
node {
  name: "DW/Initializer/random_normal/mean"
  op: "Const"
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@DW"
      }
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 0.0
      }
    }
  }
}
node {
  name: "DW/Initializer/random_normal/stddev"
  op: "Const"
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@DW"
      }
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 0.0010000000475
      }
    }
  }
}
node {
  name: "DW/Initializer/random_normal/RandomStandardNormal"
  op: "RandomStandardNormal"
  input: "DW/Initializer/random_normal/shape"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@DW"
      }
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "seed"
    value {
      i: 87654321
    }
  }
  attr {
    key: "seed2"
    value {
      i: 5
    }
  }
}
node {
  name: "DW/Initializer/random_normal/mul"
  op: "Mul"
  input: "DW/Initializer/random_normal/RandomStandardNormal"
  input: "DW/Initializer/random_normal/stddev"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@DW"
      }
    }
  }
}
node {
  name: "DW/Initializer/random_normal"
  op: "Add"
  input: "DW/Initializer/random_normal/mul"
  input: "DW/Initializer/random_normal/mean"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@DW"
      }
    }
  }
}
node {
  name: "DW/Assign"
  op: "Assign"
  input: "DW"
  input: "DW/Initializer/random_normal"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@DW"
      }
    }
  }
  attr {
    key: "use_locking"
    value {
      b: true
    }
  }
  attr {
    key: "validate_shape"
    value {
      b: true
    }
  }
}
node {
  name: "DW/read"
  op: "Identity"
  input: "DW"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@DW"
      }
    }
  }
}
node {
  name: "Conv2D"
  op: "Conv2D"
  input: "zeros"
  input: "DW/read"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "data_format"
    value {
      s: "NHWC"
    }
  }
  attr {
    key: "padding"
    value {
      s: "SAME"
    }
  }
  attr {
    key: "strides"
    value {
      list {
        i: 1
        i: 2
        i: 2
        i: 1
      }
    }
  }
  attr {
    key: "use_cudnn_on_gpu"
    value {
      b: true
    }
  }
}
node {
  name: "DW2"
  op: "Variable"
  attr {
    key: "container"
    value {
      s: ""
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: 2
        }
        dim {
          size: 2
        }
        dim {
          size: 6
        }
        dim {
          size: 12
        }
      }
    }
  }
  attr {
    key: "shared_name"
    value {
      s: ""
    }
  }
}
node {
  name: "DW2/Initializer/random_normal/shape"
  op: "Const"
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@DW2"
      }
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 4
          }
        }
        tensor_content: "\002\000\000\000\002\000\000\000\006\000\000\000\014\000\000\000"
      }
    }
  }
}
node {
  name: "DW2/Initializer/random_normal/mean"
  op: "Const"
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@DW2"
      }
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 0.0
      }
    }
  }
}
node {
  name: "DW2/Initializer/random_normal/stddev"
  op: "Const"
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@DW2"
      }
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 0.0010000000475
      }
    }
  }
}
node {
  name: "DW2/Initializer/random_normal/RandomStandardNormal"
  op: "RandomStandardNormal"
  input: "DW2/Initializer/random_normal/shape"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@DW2"
      }
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "seed"
    value {
      i: 87654321
    }
  }
  attr {
    key: "seed2"
    value {
      i: 15
    }
  }
}
node {
  name: "DW2/Initializer/random_normal/mul"
  op: "Mul"
  input: "DW2/Initializer/random_normal/RandomStandardNormal"
  input: "DW2/Initializer/random_normal/stddev"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@DW2"
      }
    }
  }
}
node {
  name: "DW2/Initializer/random_normal"
  op: "Add"
  input: "DW2/Initializer/random_normal/mul"
  input: "DW2/Initializer/random_normal/mean"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@DW2"
      }
    }
  }
}
node {
  name: "DW2/Assign"
  op: "Assign"
  input: "DW2"
  input: "DW2/Initializer/random_normal"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@DW2"
      }
    }
  }
  attr {
    key: "use_locking"
    value {
      b: true
    }
  }
  attr {
    key: "validate_shape"
    value {
      b: true
    }
  }
}
node {
  name: "DW2/read"
  op: "Identity"
  input: "DW2"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@DW2"
      }
    }
  }
}
node {
  name: "Conv2D_1"
  op: "Conv2D"
  input: "Conv2D"
  input: "DW2/read"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "data_format"
    value {
      s: "NHWC"
    }
  }
  attr {
    key: "padding"
    value {
      s: "SAME"
    }
  }
  attr {
    key: "strides"
    value {
      list {
        i: 1
        i: 2
        i: 2
        i: 1
      }
    }
  }
  attr {
    key: "use_cudnn_on_gpu"
    value {
      b: true
    }
  }
}
versions {
  producer: 13
}
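
The pbtxt above is the text form of a GraphDef. A hedged sketch of how a test might load it (not from this commit; the helper name is illustrative, and ReadFileToString/TextFormat are assumed to be available through the usual TensorFlow headers):

#include <string>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"

// Illustrative helper: read testdata/graph.pbtxt and parse it into the
// GraphDef form expected by TFStats/PrintModelAnalysis.
tensorflow::GraphDef LoadTestGraph(const std::string& testdata_dir) {
  tensorflow::GraphDef graph_def;
  std::string pbtxt;
  TF_CHECK_OK(tensorflow::ReadFileToString(
      tensorflow::Env::Default(),
      tensorflow::io::JoinPath(testdata_dir, "graph.pbtxt"), &pbtxt));
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(pbtxt, &graph_def));
  return graph_def;
}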
22 tensorflow/contrib/tfprof/tools/tfprof/internal/testdata/run_meta vendored Normal file
Binary file not shown.
9 tensorflow/contrib/tfprof/tools/tfprof/internal/testdata/tfprof_log vendored Normal file
Binary file not shown.
37 tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_constants.h Normal file
@@ -0,0 +1,37 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_TFPROF_CONSTANTS_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_TFPROF_CONSTANTS_H_

namespace tensorflow {
namespace tfprof {

// Op name of the root of everything. Aggregates all stats.
static const char* const kTFProfRoot = "_TFProfRoot";
// Op type for nodes that don't represent a physical node in the
// TensorFlow model. They exist only as placeholders to aggregate children.
// For example, kTFProfRoot belongs to this type.
static const char* const kTFGraphParent = "_TFGraphParent";
static const char* const kTFScopeParent = "_kTFScopeParent";
// Op type for tf.trainable_variables().
static const char* const kTrainableVarType = "_trainable_variables";
// Op type for tensors in the checkpoint file.
static const char* const kCkptVarType = "_checkpoint_variables";

}  // namespace tfprof
}  // namespace tensorflow

#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_TFPROF_CONSTANTS_H_
222 tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_graph.cc Normal file
@@ -0,0 +1,222 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_graph.h"

#include <stdio.h>
#include <utility>

#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_constants.h"
#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_tensor.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/regexp.h"

namespace tensorflow {
namespace tfprof {
GraphNode* TFGraph::CreateParentNode(const string& name) {
  node_defs_.push_back(std::unique_ptr<NodeDef>(new NodeDef()));
  node_defs_.back()->set_name(name);
  node_defs_.back()->set_op(kTFGraphParent);
  parent_nodes_[name] =
      std::unique_ptr<TFNode>(new TFNode(node_defs_.back().get()));
  nodes_map_[name] =
      std::unique_ptr<GraphNode>(new GraphNode(parent_nodes_[name].get()));
  return nodes_map_[name].get();
}

void TFGraph::AddNode(TFNode* node) {
  string name = node->node_def()->name();
  nodes_map_[name] = std::unique_ptr<GraphNode>(new GraphNode(node));
}

void TFGraph::Build() {
  if (!roots_.empty()) return;

  std::set<string> nonroots;
  // Find the root nodes: nodes that are not an input of any other node.
  for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) {
    GraphNode* node = it->second.get();
    const std::map<string, TFNode*>& inputs = node->node->inputs();
    for (auto inputs_it = inputs.cbegin(); inputs_it != inputs.cend();
         inputs_it++) {
      nonroots.insert(inputs_it->first);
      auto child_it = nodes_map_.find(inputs_it->first);
      if (child_it != nodes_map_.end()) {
        node->children.push_back(child_it->second.get());
      }
    }
  }
  for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) {
    if (nonroots.find(it->first) == nonroots.end()) {
      roots_.push_back(it->second.get());
    }
  }
}

const ShowNode* TFGraph::ShowInternal(const Options& opts) {
  // Search the nodes to start from.
  std::vector<GraphNode*> roots = roots_;
  if (opts.start_name_regexes.size() != 1 ||
      opts.start_name_regexes[0] != ".*") {
    std::set<string> visited;
    roots = SearchRoot(roots, opts.start_name_regexes, &visited);
  }

  GraphNode* root = CreateParentNode(kTFProfRoot);
  root->children.assign(roots.begin(), roots.end());

  std::map<string, int64> account_visits;
  Account({root}, opts, &account_visits);

  if (opts.viz) {
    printf("Visualizing feature disabled...\n");
  }
  std::set<string> visits;
  return PrintGraph({root}, opts, 1, 0, 0, &visits)[0];
}

std::vector<GraphNode*> TFGraph::SearchRoot(
    const std::vector<GraphNode*>& roots, const std::vector<string>& regexes,
    std::set<string>* visited) {
  std::vector<GraphNode*> res;
  if (roots.empty()) {
    return res;
  }
  for (GraphNode* root : roots) {
    if (visited->find(root->name()) != visited->end()) continue;
    visited->insert(root->name());
    // If the parent is a start point, don't search its children.
    // Note that its children can still be added as start nodes through
    // another route.
    bool match_start_node = false;
    for (const string& regex : regexes) {
      if (RE2::FullMatch(root->name(), regex)) {
        res.push_back(root);
        match_start_node = true;
        break;
      }
    }
    if (match_start_node) {
      continue;
    }
    std::vector<GraphNode*> nroot =
        SearchRoot(root->children, regexes, visited);
    res.insert(res.end(), nroot.begin(), nroot.end());
  }
  return res;
}

std::vector<GraphNode*> TFGraph::PrintGraph(const std::vector<GraphNode*> roots,
                                            const Options& opts, int depth,
                                            int hidden, int last_ident,
                                            std::set<string>* visits) {
  std::vector<GraphNode*> show_nodes;

  for (GraphNode* node : roots) {
    if (visits->find(node->name()) != visits->end()) continue;
    visits->insert(node->name());

    int nhidden = hidden;
    int nlast_ident = last_ident;
    bool show = ShouldShow(node, opts, depth);
    if (show) {
      node->formatted_str.clear();
      if (opts.account_displayed_op_only) {
        node->ResetTotalStats();
        node->AddSelfToTotalStats();
      }
      nhidden = 0;
      nlast_ident = (hidden && opts.select.find(kShown[4]) != opts.select.end()
                         ? last_ident + 4
                         : last_ident + 2);
    } else {
      ++nhidden;
    }

    std::vector<GraphNode*> show_cnodes;
    if (!ShouldTrim(node, opts.trim_name_regexes)) {
      show_cnodes = PrintGraph(node->children, opts, depth + 1, nhidden,
                               nlast_ident, visits);
    }
    if (show) {
      show_cnodes = SortNodes(show_cnodes, opts);
      string children_str;
      for (GraphNode* sc : show_cnodes) {
        children_str += sc->formatted_str;
        node->mutable_proto()->add_children()->MergeFrom(sc->proto());
        if (opts.account_displayed_op_only) {
          node->AggregateTotalStats(sc);
        }
      }
      if (hidden && opts.select.find(kShown[4]) != opts.select.end()) {
        node->formatted_str = strings::Printf(
            "%s...hidden %d...\n", string(last_ident, ' ').c_str(), hidden);
        node->formatted_str +=
            strings::Printf(" %s%s\n", string(last_ident, ' ').c_str(),
                            node->Format(opts).c_str());
      } else {
        node->formatted_str =
            strings::Printf("%s%s\n", string(last_ident, ' ').c_str(),
                            node->Format(opts).c_str());
      }
      if (opts.select.find(kShown[5]) != opts.select.end()) {
        std::unique_ptr<TFProfTensor> tfprof_tensor;
        if (LookUpCheckPoint(node->name(), &tfprof_tensor)) {
          string value_str;
          tfprof_tensor->Display(&value_str,
                                 node->mutable_proto()->mutable_tensor_value());
          node->formatted_str += value_str;
        }
      }

      node->formatted_str += children_str;
      show_nodes.push_back(node);
    } else {
      show_nodes.insert(show_nodes.end(), show_cnodes.begin(),
                        show_cnodes.end());
    }
  }
  return show_nodes;
}

void TFGraph::Account(const std::vector<GraphNode*>& roots, const Options& opts,
                      std::map<string, int64>* visits) {
  if (roots.empty()) return;

  for (GraphNode* node : roots) {
    if (visits->find(node->name()) != visits->end()) continue;
    (*visits)[node->name()] = 1;
    node->ResetTotalStats();
    // Depth-first.
    Account(node->children, opts, visits);

    node->account = ShouldAccount(node, opts);
    if (node->account) {
      node->AddSelfToTotalStats();
    }
    // Aggregate its children's stats.
    for (GraphNode* c : node->children) {
      // A node can be visited from multiple parents. Only account once.
      // "visits == 1" means the node was visited through the depth-first
      // search itself.
      (*visits)[c->name()] += 1;
      if ((*visits)[c->name()] > 2) continue;

      node->AggregateTotalStats(c);
    }
  }
}
}  // namespace tfprof
}  // namespace tensorflow
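
TFGraph::Build() above derives roots by elimination: collect every name that appears as an input, then keep the nodes that never do. The same rule in isolation, over a simplified name-to-inputs map standing in for nodes_map_ (illustrative, not part of this commit):

#include <map>
#include <set>
#include <string>
#include <vector>

// A node is a root iff it never appears as an input of another node.
// Every node is keyed in the map; leaves map to empty input lists.
std::vector<std::string> FindRoots(
    const std::map<std::string, std::vector<std::string>>& inputs_of) {
  std::set<std::string> nonroots;
  for (const auto& entry : inputs_of) {
    for (const std::string& input : entry.second) nonroots.insert(input);
  }
  std::vector<std::string> roots;
  for (const auto& entry : inputs_of) {
    if (nonroots.count(entry.first) == 0) roots.push_back(entry.first);
  }
  return roots;
}
// E.g. with Conv2D -> {zeros, DW/read}, Conv2D_1 -> {Conv2D, DW2/read}, and
// zeros, DW/read, DW2/read mapping to empty lists, Conv2D_1 is the only root.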
116 tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_graph.h Normal file
@@ -0,0 +1,116 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Build a graph structure based on op inputs/outputs. The graph is a directed
// acyclic graph pointing *from outputs to inputs*.

#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_TFPROF_GRAPH_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_TFPROF_GRAPH_H_

#include <deque>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_node.h"
#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_options.h"
#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_show.h"
#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_utils.h"
#include "tensorflow/contrib/tfprof/tools/tfprof/tfprof_output.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace tfprof {
class GraphNode : public ShowNode {
 public:
  explicit GraphNode(TFNode* node) : ShowNode(node) {
    mutable_proto()->set_inputs(node->inputs().size());
    mutable_proto()->set_total_inputs(0);
  }

  void AggregateTotalStats(GraphNode* node) {
    ShowNode::AggregateTotalStats(node);
    mutable_proto()->set_total_inputs(proto().total_inputs() +
                                      node->proto().total_inputs() + 1);
  }

  void AddSelfToTotalStats() {
    ShowNode::AddSelfToTotalStats();
    mutable_proto()->set_total_inputs(proto().total_inputs() +
                                      proto().inputs());
  }

  void ResetTotalStats() {
    ShowNode::ResetTotalStats();
    mutable_proto()->set_total_inputs(0);
  }

  std::vector<GraphNode*> children;
};

// Organize tensorflow ops in a graph structure, pointing from output ops
// to input ops.
class TFGraph : public TFShow {
 public:
  explicit TFGraph(checkpoint::CheckpointReader* ckpt_reader)
      : TFShow(ckpt_reader) {}
  ~TFGraph() override {}

  void AddNode(TFNode* node) override;

  void Build() override;

 private:
  const ShowNode* ShowInternal(const Options& opts) override;

  bool ShouldShowIfExtra(ShowNode* node, const Options& opts,
                         int depth) override {
    return true;
  }

  GraphNode* CreateParentNode(const string& name);

  std::vector<GraphNode*> SearchRoot(const std::vector<GraphNode*>& roots,
                                     const std::vector<string>& regexes,
                                     std::set<string>* visited);

  std::vector<GraphNode*> PrintGraph(const std::vector<GraphNode*> roots,
                                     const Options& opts, int depth, int hidden,
                                     int last_ident, std::set<string>* visits);

  void VisualizeGraph(GraphNode* root, const Options& opts);

  std::vector<GraphNode*> GenerateGraphDot(
      GraphNode* root, GraphNode* last_shown, const Options& opts, int depth,
      int hidden, std::set<string>* declared_nodes,
      std::set<string>* declared_edges, TFProfNode* parent);

  void Account(const std::vector<GraphNode*>& roots, const Options& opts,
               std::map<string, int64>* visits);

  std::vector<GraphNode*> roots_;
  std::vector<std::unique_ptr<NodeDef>> node_defs_;
  std::map<string, std::unique_ptr<TFNode>> parent_nodes_;
  std::map<string, std::unique_ptr<GraphNode>> nodes_map_;
};

}  // namespace tfprof
}  // namespace tensorflow

#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_TFPROF_GRAPH_H_
47 tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_node.cc Normal file
@@ -0,0 +1,47 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_node.h"

#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/tensor_description.pb.h"

namespace tensorflow {
namespace tfprof {
void TFNode::AddStepStat(const string& device, const NodeExecStats* step_stat) {
  if (!device.empty()) {
    // This might override the device from the GraphDef.
    device_ = device;
  }
  step_stat_ = step_stat;

  op_start_micros_ = step_stat_->all_start_micros();
  if (step_stat_->op_end_rel_micros() && step_stat_->op_start_rel_micros()) {
    op_exec_micros_ =
        step_stat_->op_end_rel_micros() - step_stat_->op_start_rel_micros();
  }
  all_spent_micros_ = step_stat_->all_end_rel_micros();

  for (const auto& output : step_stat_->output()) {
    if (output.has_tensor_description() &&
        output.tensor_description().has_allocation_description()) {
      requested_bytes_ += output.tensor_description()
                              .allocation_description()
                              .requested_bytes();
    }
  }
}
}  // namespace tfprof
}  // namespace tensorflow
106 tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_node.h Normal file
@@ -0,0 +1,106 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_TFPROF_NODE_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_TFPROF_NODE_H_

#include <map>
#include <set>
#include <string>
#include <vector>

#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_options.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace tfprof {

class TFNode {
 public:
  TFNode(const NodeDef* node)
      : node_(node),
        step_stat_(nullptr),
        op_start_micros_(0),
        op_exec_micros_(0),
        all_spent_micros_(0),
        requested_bytes_(0),
        float_ops_(0) {
    if (!node) return;

    for (const auto& attr : node->attr()) {
      // TODO(xpan): Also consider _output_shapes.
      if (attr.first != "shape" || !attr.second.has_shape()) continue;
      if (!shape_.empty()) {
        fprintf(stderr, "Found duplicated shapes!\n");
        continue;
      }
      std::vector<int64> shape_vec;
      for (const auto& d : attr.second.shape().dim()) {
        shape_vec.push_back(d.size());
      }
      update_shape(shape_vec);
    }
    op_types_.insert(node->op());
    device_ = node->device();
  }

  TFNode() : TFNode(nullptr) {}

  void AddInput(TFNode* input) { inputs_[input->node_def()->name()] = input; }

  void AddOpType(const string& op_type) { op_types_.insert(op_type); }

  void AddStepStat(const string& device, const NodeExecStats* step_stat);

  void AddFloatOps(int64 float_ops) { float_ops_ = float_ops; }

  const NodeDef* node_def() { return node_; }
  const std::map<string, TFNode*>& inputs() { return inputs_; }
  int64 op_start_micros() { return op_start_micros_; }
  int64 op_exec_micros() { return op_exec_micros_; }
  int64 all_spent_micros() { return all_spent_micros_; }
  int64 requested_bytes() { return requested_bytes_; }
  int64 float_ops() { return float_ops_; }
  string device() { return device_; }
  const std::set<string>& op_types() { return op_types_; }

  const std::vector<int64>& shape() { return shape_; }
  void update_shape(const std::vector<int64>& shape) { shape_ = shape; }

 private:
  std::map<string, TFNode*> inputs_;
  const NodeDef* node_;
  const NodeExecStats* step_stat_;

  std::vector<int64> shape_;
  std::set<string> op_types_;
  string device_;
  int64 op_start_micros_;
  int64 op_exec_micros_;
  int64 all_spent_micros_;
  int64 requested_bytes_;
  int64 float_ops_;
};

}  // namespace tfprof
}  // namespace tensorflow

#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_TFPROF_NODE_H_
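
TFNode extracts a variable's shape from the NodeDef "shape" attr. A consumer of TFNode::shape() would typically reduce it to a parameter count as the product of the dimensions; a small illustrative helper (not part of this commit):

#include <vector>

#include "tensorflow/core/framework/types.h"

// Illustrative: the parameters contributed by a variable are the product of
// its shape dimensions, e.g. DW above has shape [3, 3, 3, 6] -> 162 params.
tensorflow::int64 NumParameters(const std::vector<tensorflow::int64>& shape) {
  if (shape.empty()) return 0;  // Unknown shape: contribute nothing.
  tensorflow::int64 params = 1;
  for (tensorflow::int64 d : shape) params *= d;
  return params;
}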
57 tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_options.cc Normal file
@@ -0,0 +1,57 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_options.h"

#include "tensorflow/core/lib/strings/stringprintf.h"

namespace tensorflow {
namespace tfprof {

string Options::ToString() const {
  const string s = strings::Printf(
      "%-28s%d\n"
      "%-28s%lld\n"
      "%-28s%lld\n"
      "%-28s%lld\n"
      "%-28s%lld\n"
      "%-28s%s\n"
      "%-28s%s\n"
      "%-28s%s\n"
      "%-28s%s\n"
      "%-28s%s\n"
      "%-28s%s\n"
      "%-28s%s\n"
      "%-28s%s\n"
      "%-28s%s\n"
      "%-28s%s\n"
      "%-28s%s\n",
      kOptions[0], max_depth, kOptions[1], min_bytes, kOptions[2], min_micros,
      kOptions[3], min_params, kOptions[4], min_float_ops, kOptions[5],
      str_util::Join(device_regexes, ",").c_str(), kOptions[6],
      order_by.c_str(), kOptions[7],
      str_util::Join(account_type_regexes, ",").c_str(), kOptions[8],
      str_util::Join(start_name_regexes, ",").c_str(), kOptions[9],
      str_util::Join(trim_name_regexes, ",").c_str(), kOptions[10],
      str_util::Join(show_name_regexes, ",").c_str(), kOptions[11],
      str_util::Join(hide_name_regexes, ",").c_str(), kOptions[12],
      (account_displayed_op_only ? "true" : "false"), kOptions[13],
      str_util::Join(select, ",").c_str(), kOptions[14],
      (viz ? "true" : "false"), kOptions[15], dump_to_file.c_str());
  return s;
}

}  // namespace tfprof
}  // namespace tensorflow
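
A usage sketch for ToString(): each "%-28s" row pairs one kOptions flag name with its current value, so dumping an Options instance yields an aligned flag/value listing. The construction below is illustrative only; the argument order follows the constructor declared in tfprof_options.h.

#include <cstdio>

#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_options.h"

// Illustrative only: build an Options instance and print the aligned dump.
void DumpExampleOptions() {
  tensorflow::tfprof::Options opts(
      /*max_depth=*/4, /*min_bytes=*/0, /*min_micros=*/0, /*min_params=*/0,
      /*min_float_ops=*/0, /*device_regexes=*/{".*"}, /*order_by=*/"name",
      /*account_type_regexes=*/{".*"}, /*start_name_regexes=*/{".*"},
      /*trim_name_regexes=*/{}, /*show_name_regexes=*/{".*"},
      /*hide_name_regexes=*/{}, /*account_displayed_op_only=*/false,
      /*select=*/{"params"}, /*viz=*/false);
  printf("%s", opts.ToString().c_str());  // e.g. "-max_depth" padded, then 4.
}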
119 tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_options.h Normal file
@@ -0,0 +1,119 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_TFPROF_OPTIONS_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_TFPROF_OPTIONS_H_

#include <set>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"

namespace tensorflow {
namespace tfprof {
static const char* const kOptions[] = {
    "-max_depth",
    "-min_bytes",
    "-min_micros",
    "-min_params",
    "-min_float_ops",
    "-device_regexes",
    "-order_by",
    "-account_type_regexes",
    "-start_name_regexes",
    "-trim_name_regexes",
    "-show_name_regexes",
    "-hide_name_regexes",
    "-account_displayed_op_only",
    "-select",
    "-viz",
    "-dump_to_file",
};

static const char* const kOrderBy[] = {
    "name", "bytes", "micros", "params", "float_ops",
};

// Append Only.
static const char* const kShown[] = {
    "bytes", "micros", "params", "float_ops",
    "num_hidden_ops", "tensor_value", "device", "op_types",
};

static const char* const kCmds[] = {
    "scope", "graph", "set", "help",
};

struct Options {
 public:
  virtual ~Options() {}
  Options(int max_depth, tensorflow::int64 min_bytes,
          tensorflow::int64 min_micros, tensorflow::int64 min_params,
          tensorflow::int64 min_float_ops,
          const std::vector<string>& device_regexes, const string& order_by,
          const std::vector<string>& account_type_regexes,
          const std::vector<string>& start_name_regexes,
          const std::vector<string>& trim_name_regexes,
          const std::vector<string>& show_name_regexes,
          const std::vector<string>& hide_name_regexes,
          bool account_displayed_op_only, const std::vector<string>& select,
          bool viz, const string& dump_to_file = "")
      : max_depth(max_depth),
        min_bytes(min_bytes),
        min_micros(min_micros),
        min_params(min_params),
        min_float_ops(min_float_ops),
        device_regexes(device_regexes),
        order_by(order_by),
        account_type_regexes(account_type_regexes),
        start_name_regexes(start_name_regexes),
        trim_name_regexes(trim_name_regexes),
        show_name_regexes(show_name_regexes),
        hide_name_regexes(hide_name_regexes),
        account_displayed_op_only(account_displayed_op_only),
        select(select.begin(), select.end()),
        viz(viz),
        dump_to_file(dump_to_file) {}

  string ToString() const;

  int max_depth;
  tensorflow::int64 min_bytes;
  tensorflow::int64 min_micros;
  tensorflow::int64 min_params;
  tensorflow::int64 min_float_ops;
  std::vector<string> device_regexes;
  string order_by;

  std::vector<string> account_type_regexes;
  std::vector<string> start_name_regexes;
  std::vector<string> trim_name_regexes;
  std::vector<string> show_name_regexes;
  std::vector<string> hide_name_regexes;
  bool account_displayed_op_only;

  std::set<string> select;
  bool viz;
  string dump_to_file;
};

}  // namespace tfprof
}  // namespace tensorflow

#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_TFPROF_OPTIONS_H_
191 tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_scope.cc Normal file
@@ -0,0 +1,191 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_scope.h"

#include <stdio.h>
#include <utility>

#include "tensorflow/c/c_api.h"
#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_constants.h"
#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_tensor.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/regexp.h"

namespace tensorflow {
namespace tfprof {
ScopeNode* TFScope::CreateParentNode(const string& name) {
  if (nodes_map_.find(name) != nodes_map_.end()) {
    return nodes_map_[name].get();
  }
  node_defs_.push_back(std::unique_ptr<NodeDef>(new NodeDef()));
  node_defs_.back()->set_name(name);
  node_defs_.back()->set_op(kTFScopeParent);
  parent_nodes_[name] =
      std::unique_ptr<TFNode>(new TFNode(node_defs_.back().get()));
  nodes_map_[name] =
      std::unique_ptr<ScopeNode>(new ScopeNode(parent_nodes_[name].get()));
  return nodes_map_[name].get();
}

void TFScope::AddNode(TFNode* node) {
  string name = node->node_def()->name();
  if (nodes_map_.find(node->node_def()->name()) == nodes_map_.end()) {
    nodes_map_[name] = std::unique_ptr<ScopeNode>(new ScopeNode(node));
  }

  auto last_slash = name.find_last_of("/");
  while (last_slash != name.npos) {
    name = name.substr(0, last_slash);
    if (nodes_map_.find(name) == nodes_map_.end()) {
      CHECK(CreateParentNode(name));
    }
    last_slash = name.find_last_of("/");
  }
}

void TFScope::Build() {
  if (!roots_.empty()) return;
  // Find the roots: nodes whose names contain no "/".
  for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) {
    ScopeNode* node = it->second.get();
    auto last_slash = node->name().find_last_of("/");
    if (last_slash == string::npos) {
      roots_.push_back(node);
    } else {
      const string prefix = node->name().substr(0, last_slash);
      nodes_map_[prefix]->children.push_back(node);
    }
  }
}

const ShowNode* TFScope::ShowInternal(const Options& opts) {
  // Search from roots recursively to find the start nodes, if
  // start_name_regexes is specified.
  std::vector<ScopeNode*> roots = roots_;
  if (opts.start_name_regexes.size() != 1 ||
      opts.start_name_regexes[0] != ".*") {
    roots = SearchRoot(roots, opts.start_name_regexes);
  }

  ScopeNode* root = CreateParentNode(kTFProfRoot);
  root->children.assign(roots.begin(), roots.end());
  Account({root}, opts);

  root = PrintScope({root}, opts, 1, 0)[0];
  return root;
}

std::vector<ScopeNode*> TFScope::SearchRoot(
    std::vector<ScopeNode*> roots, const std::vector<string>& regexes) {
  std::vector<ScopeNode*> res;
  if (roots.empty()) {
    return res;
  }
  for (ScopeNode* root : roots) {
    bool match_start_node = false;
    for (const string& regex : regexes) {
      if (RE2::FullMatch(root->name(), regex)) {
        res.push_back(root);
        match_start_node = true;
        break;
      }
    }
    if (match_start_node) {
      // Found a start node at this branch, no need to continue.
      continue;
    }
    std::vector<ScopeNode*> nroots = SearchRoot(root->children, regexes);
    res.insert(res.end(), nroots.begin(), nroots.end());
  }
  return res;
}

std::vector<ScopeNode*> TFScope::PrintScope(const std::vector<ScopeNode*> roots,
                                            const Options& opts, int depth,
                                            int last_ident) {
  std::vector<ScopeNode*> show_nodes;

  for (ScopeNode* node : roots) {
    int nlast_ident = last_ident;
    bool show = ShouldShow(node, opts, depth);
    if (show) {
      node->formatted_str.clear();
      if (opts.account_displayed_op_only) {
        node->ResetTotalStats();
        node->AddSelfToTotalStats();
      }
      nlast_ident += 2;
    }

    std::vector<ScopeNode*> show_cnodes;
    if (!ShouldTrim(node, opts.trim_name_regexes)) {
      show_cnodes = PrintScope(node->children, opts, depth + 1, nlast_ident);
    }
    if (show) {
      show_cnodes = SortNodes(show_cnodes, opts);
      string children_str;
      for (ScopeNode* sc : show_cnodes) {
        children_str += sc->formatted_str;
        node->mutable_proto()->add_children()->MergeFrom(sc->proto());
        if (opts.account_displayed_op_only) {
          node->AggregateTotalStats(sc);
        }
      }

      node->formatted_str =
          strings::Printf("%s%s\n", string(last_ident, ' ').c_str(),
                          node->Format(opts).c_str());

      if (opts.select.find(kShown[5]) != opts.select.end()) {
        std::unique_ptr<TFProfTensor> tfprof_tensor;
        if (LookUpCheckPoint(node->name(), &tfprof_tensor)) {
          string value_str;
          tfprof_tensor->Display(&value_str,
                                 node->mutable_proto()->mutable_tensor_value());
          node->formatted_str += value_str;
        }
      }

      node->formatted_str += children_str;
      show_nodes.push_back(node);
    } else {
      show_nodes.insert(show_nodes.end(), show_cnodes.begin(),
                        show_cnodes.end());
    }
  }
  return show_nodes;
}

void TFScope::Account(const std::vector<ScopeNode*>& roots,
                      const Options& opts) {
  if (roots.empty()) return;

  for (ScopeNode* node : roots) {
    node->ResetTotalStats();
    Account(node->children, opts);

    node->account = ShouldAccount(node, opts);
    if (node->account) {
      node->AddSelfToTotalStats();
    }
    for (ScopeNode* c : node->children) {
      node->AggregateTotalStats(c);
    }
  }
}
}  // namespace tfprof
}  // namespace tensorflow
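
TFScope::AddNode() above synthesizes parent scopes by repeatedly trimming the name at its last "/". The same rule in isolation, enumerating the scope ancestors of a node name (illustrative, not part of this commit):

#include <string>
#include <vector>

// Each prefix up to a '/' is a parent scope. For example,
// "DW/Initializer/random_normal/mul" yields
// {"DW/Initializer/random_normal", "DW/Initializer", "DW"}.
std::vector<std::string> ScopeAncestors(std::string name) {
  std::vector<std::string> ancestors;
  for (auto slash = name.find_last_of('/'); slash != std::string::npos;
       slash = name.find_last_of('/')) {
    name = name.substr(0, slash);  // Trim one scope level per iteration.
    ancestors.push_back(name);
  }
  return ancestors;
}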