Merge and fix workspace.bzl for CMake and Makefile use.

commit b992ff69e2
Martin Wicke, 2016-09-23 10:01:34 -07:00
399 changed files with 20,749 additions and 4,425 deletions
Changed files shown below include: RELEASE.md, avro.BUILD, boost.BUILD, bzip2.BUILD, eigen.BUILD, farmhash.BUILD, gif.BUILD, gmock.BUILD, grpc.BUILD, jpeg.BUILD, jsoncpp.BUILD, linenoise.BUILD, nanopb.BUILD, png.BUILD, six.BUILD, and files under tensorflow/ (BUILD, c, cc, contrib).


@ -17,6 +17,9 @@
instead of graph.proto.
* ops.NoGradient was renamed ops.NotDifferentiable. ops.NoGradient will
be removed soon (see the sketch after this list).
* dot.h / DotGraph was removed (an early analysis tool that predates
TensorBoard and is no longer very useful). It remains in the repository
history should someone find the code useful.
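A minimal sketch of the renamed registration call (the op type name here is hypothetical):

```python
from tensorflow.python.framework import ops

# New spelling: register an op type as non-differentiable. The old
# ops.NoGradient alias still works but is slated for removal.
ops.NotDifferentiable("MyCustomOp")  # "MyCustomOp" is illustrative only
```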
# Release 0.10.0


@ -2,21 +2,19 @@ package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
prefix_dir = "avro-cpp-1.8.0"
cc_library(
name = "avrocpp",
srcs = glob(
[
prefix_dir + "/impl/**/*.cc",
prefix_dir + "/impl/**/*.hh",
"impl/**/*.cc",
"impl/**/*.hh",
],
exclude = [
prefix_dir + "/impl/avrogencpp.cc",
"impl/avrogencpp.cc",
],
),
hdrs = glob([prefix_dir + "/api/**/*.hh"]),
includes = [prefix_dir + "/api"],
hdrs = glob(["api/**/*.hh"]),
includes = ["api"],
deps = [
"@boost_archive//:boost",
"@boost_archive//:filesystem",
@ -27,7 +25,7 @@ cc_library(
cc_binary(
name = "avrogencpp",
srcs = [prefix_dir + "/impl/avrogencpp.cc"],
srcs = ["impl/avrogencpp.cc"],
deps = [
":avrocpp",
"@boost_archive//:program_options",


@ -10,21 +10,19 @@ package(default_visibility = ["@avro_archive//:__subpackages__"])
licenses(["notice"]) # Boost software license
prefix_dir = "boost_1_61_0"
cc_library(
name = "boost",
hdrs = glob([
prefix_dir + "/boost/**/*.hpp",
prefix_dir + "/boost/**/*.h",
prefix_dir + "/boost/**/*.ipp",
"boost/**/*.hpp",
"boost/**/*.h",
"boost/**/*.ipp",
]),
includes = [prefix_dir],
includes = ["."],
)
cc_library(
name = "filesystem",
srcs = glob([prefix_dir + "/libs/filesystem/src/*.cpp"]),
srcs = glob(["libs/filesystem/src/*.cpp"]),
deps = [
":boost",
":system",
@ -33,7 +31,7 @@ cc_library(
cc_library(
name = "iostreams",
srcs = glob([prefix_dir + "/libs/iostreams/src/*.cpp"]),
srcs = glob(["libs/iostreams/src/*.cpp"]),
deps = [
":boost",
"@bzip2_archive//:bz2lib",
@ -43,16 +41,12 @@ cc_library(
cc_library(
name = "program_options",
srcs = glob([prefix_dir + "/libs/program_options/src/*.cpp"]),
deps = [
":boost",
],
srcs = glob(["libs/program_options/src/*.cpp"]),
deps = [":boost"],
)
cc_library(
name = "system",
srcs = glob([prefix_dir + "/libs/system/src/*.cpp"]),
deps = [
":boost",
],
srcs = glob(["libs/system/src/*.cpp"]),
deps = [":boost"],
)


@ -2,35 +2,27 @@ package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # BSD derivative
prefix_dir = "bzip2-1.0.6"
BZ2LIB_SRCS = [
# these are in the same order as their corresponding .o files are in OBJS in
# Makefile (rather than lexicographic order) for easy comparison (that they
# are identical).
"blocksort.c",
"huffman.c",
"crctable.c",
"randtable.c",
"compress.c",
"decompress.c",
"bzlib.c",
]
cc_library(
name = "bz2lib",
srcs = [prefix_dir + "/" + source for source in BZ2LIB_SRCS] +
[prefix_dir + "/bzlib_private.h"],
hdrs = [prefix_dir + "/bzlib.h"],
includes = [prefix_dir],
srcs = [
# These are in the same order as their corresponding .o files are in
# OBJS in Makefile (rather than lexicographic order) for easy
# comparison (that they are identical).
"blocksort.c",
"huffman.c",
"crctable.c",
"randtable.c",
"compress.c",
"decompress.c",
"bzlib.c",
"bzlib_private.h",
],
hdrs = ["bzlib.h"],
includes = ["."],
)
cc_binary(
name = "bzip2",
srcs = [
"bzip2.c",
],
deps = [
":bz2lib",
],
srcs = ["bzip2.c"],
deps = [":bz2lib"],
)


@ -1,8 +1,70 @@
package(default_visibility = ["//visibility:public"])
# Description:
# Eigen is a C++ template library for linear algebra: vectors,
# matrices, and related algorithms.
licenses([
# Note: Eigen is an MPL2 library that includes GPL v3 and LGPL v2.1+ code.
# We've taken special care to not reference any restricted code.
"reciprocal", # MPL2
"notice", # Portions BSD
])
# License-restricted (i.e. not reciprocal or notice) files inside Eigen/...
EIGEN_RESTRICTED_FILES = [
"Eigen/src/OrderingMethods/Amd.h",
"Eigen/src/SparseCholesky/**",
]
# Notable transitive dependencies of restricted files inside Eigen/...
EIGEN_RESTRICTED_DEPS = [
"Eigen/Eigen",
"Eigen/IterativeLinearSolvers",
"Eigen/MetisSupport",
"Eigen/Sparse",
"Eigen/SparseCholesky",
"Eigen/SparseLU",
]
# Note: unsupported/Eigen is unsupported and might go away at any time.
EIGEN_FILES = [
"Eigen/**",
"unsupported/Eigen/CXX11/**",
"unsupported/Eigen/FFT",
"unsupported/Eigen/KroneckerProduct",
"unsupported/Eigen/src/FFT/**",
"unsupported/Eigen/src/KroneckerProduct/**",
"unsupported/Eigen/MatrixFunctions",
"unsupported/Eigen/SpecialFunctions",
"unsupported/Eigen/src/SpecialFunctions/**",
]
# List of files picked up by glob but actually part of another target.
EIGEN_EXCLUDE_FILES = [
"Eigen/src/Core/arch/AVX/PacketMathGoogleTest.cc",
]
# Files known to be under MPL2 license.
EIGEN_MPL2_HEADER_FILES = glob(
EIGEN_FILES,
exclude = EIGEN_EXCLUDE_FILES +
EIGEN_RESTRICTED_FILES +
EIGEN_RESTRICTED_DEPS + [
# Guarantees any file missed by excludes above will not compile.
"Eigen/src/Core/util/NonMPL2.h",
"Eigen/**/CMakeLists.txt",
],
)
cc_library(
name = "eigen",
hdrs = glob(["**/*.h", "unsupported/Eigen/*", "unsupported/Eigen/CXX11/*", "Eigen/*"]),
includes = [ '.' ],
hdrs = EIGEN_MPL2_HEADER_FILES,
defines = [
# This define (mostly) guarantees we don't link any problematic
# code. We use it, but we do not rely on it, as evidenced above.
"EIGEN_MPL2_ONLY",
# TODO(jart): Use EIGEN_USE_NONBLOCKING_THREAD_POOL but first add an
# eigen_initialize.cc file and alwayslink=1.
],
includes = ["."],
visibility = ["//visibility:public"],
)


@ -1,21 +1,9 @@
package(default_visibility = ["//visibility:public"])
prefix_dir = "farmhash-34c13ddfab0e35422f4c3979f360635a8c050260"
genrule(
name = "configure",
srcs = glob(
["**/*"],
exclude = [prefix_dir + "/config.h"],
),
outs = [prefix_dir + "/config.h"],
cmd = "pushd external/farmhash_archive/%s; workdir=$$(mktemp -d -t tmp.XXXXXXXXXX); cp -a * $$workdir; pushd $$workdir; ./configure; popd; popd; cp $$workdir/config.h $(@D); rm -rf $$workdir;" % prefix_dir,
)
licenses(["notice"]) # MIT
cc_library(
name = "farmhash",
srcs = [prefix_dir + "/src/farmhash.cc"],
hdrs = [prefix_dir + "/src/farmhash.h"] + [":configure"],
includes = [prefix_dir],
visibility = ["//visibility:public"]
srcs = ["farmhash.cc"],
hdrs = ["farmhash.h"],
includes = ["."],
visibility = ["//visibility:public"],
)


@ -1,65 +1,44 @@
SOURCES = [
"dgif_lib.c",
"egif_lib.c",
"gif_font.c",
"gif_hash.c",
"gifalloc.c",
"openbsd-reallocarray.c",
"gif_err.c",
"quantize.c",
]
# Description:
# A library for decoding and encoding GIF images
HEADERS = [
"gif_hash.h",
"gif_lib.h",
"gif_lib_private.h",
]
config_setting(
name = "windows",
values = {
"cpu": "x64_windows_msvc",
},
visibility = ["//visibility:public"],
)
prefix_dir = "giflib-5.1.4/lib"
prefix_dir_windows = "windows/giflib-5.1.4/lib"
genrule(
name = "srcs_without_unistd",
srcs = [prefix_dir + "/" + source for source in SOURCES],
outs = [prefix_dir_windows + "/" + source for source in SOURCES],
cmd = "for f in $(SRCS); do " +
" sed 's/#include <unistd.h>//g' $$f > $(@D)/%s/$$(basename $$f);" % prefix_dir_windows +
"done",
)
genrule(
name = "hdrs_without_unistd",
srcs = [prefix_dir + "/" + hdrs for hdrs in HEADERS],
outs = [prefix_dir_windows + "/" + hdrs for hdrs in HEADERS],
cmd = "for f in $(SRCS); do " +
" sed 's/#include <unistd.h>//g' $$f > $(@D)/%s/$$(basename $$f);" % prefix_dir_windows +
"done",
)
licenses(["notice"]) # MIT
cc_library(
name = "gif",
srcs = select({
"//conditions:default" : [prefix_dir + "/" + source for source in SOURCES],
":windows" : [":srcs_without_unistd"],
}),
hdrs = select({
"//conditions:default" : [prefix_dir + "/" + hdrs for hdrs in HEADERS],
":windows" : [":hdrs_without_unistd"],
}),
includes = select({
"//conditions:default" : [prefix_dir],
":windows" : [prefix_dir_windows],
}),
defines = [
"HAVE_CONFIG_H",
srcs = [
"dgif_lib.c",
"egif_lib.c",
"gif_err.c",
"gif_font.c",
"gif_hash.c",
"gif_hash.h",
"gif_lib_private.h",
"gifalloc.c",
"openbsd-reallocarray.c",
"quantize.c",
],
hdrs = ["gif_lib.h"],
includes = ["."],
visibility = ["//visibility:public"],
deps = select({
":windows": [":windows_polyfill"],
"//conditions:default": [],
}),
)
cc_library(
name = "windows_polyfill",
hdrs = ["windows/unistd.h"],
includes = ["windows"],
)
genrule(
name = "windows_unistd_h",
outs = ["windows/unistd.h"],
cmd = "touch $@",
)
config_setting(
name = "windows",
values = {"cpu": "x64_windows_msvc"},
)


@ -1,19 +1,25 @@
# Description:
# Google C++ Mocking Framework, a library for creating and using C++
# mock classes.
licenses(["notice"]) # 3-clause BSD
cc_library(
name = "gtest",
srcs = [
"gmock-1.7.0/gtest/src/gtest-all.cc",
"gmock-1.7.0/src/gmock-all.cc",
"gtest/src/gtest-all.cc",
"src/gmock-all.cc",
],
hdrs = glob([
"gmock-1.7.0/**/*.h",
"gmock-1.7.0/gtest/src/*.cc",
"gmock-1.7.0/src/*.cc",
"**/*.h",
"gtest/src/*.cc",
"src/*.cc",
]),
includes = [
"gmock-1.7.0",
"gmock-1.7.0/gtest",
"gmock-1.7.0/gtest/include",
"gmock-1.7.0/include",
".",
"gtest",
"gtest/include",
"include",
],
linkopts = ["-pthread"],
visibility = ["//visibility:public"],
@ -21,7 +27,7 @@ cc_library(
cc_library(
name = "gtest_main",
srcs = ["gmock-1.7.0/src/gmock_main.cc"],
srcs = ["src/gmock_main.cc"],
linkopts = ["-pthread"],
visibility = ["//visibility:public"],
deps = [":gtest"],


@ -3,6 +3,7 @@
# ...with small modifications to fix the build rules for :grpc++_unsecure.
#
# TODO(mrry): Upstream these fixes back to the gRPC repository.
# TODO(jart): Fix nanopb's BUILD file. Fix grpc BUILD file.
# GRPC Bazel BUILD file.
# This currently builds C, C++ and Objective-C code.
@ -44,9 +45,26 @@ licenses(["notice"]) # 3-clause BSD
package(default_visibility = ["//visibility:public"])
genrule(
name = "pb_h",
outs = ["third_party/nanopb/pb.h"],
cmd = "echo '#include <pb.h>' >$@",
visibility = ["//visibility:private"],
)
genrule(
name = "pb_decode_h",
outs = ["third_party/nanopb/pb_decode.h"],
cmd = "echo '#include <pb_decode.h>' >$@",
visibility = ["//visibility:private"],
)
genrule(
name = "pb_encode_h",
outs = ["third_party/nanopb/pb_encode.h"],
cmd = "echo '#include <pb_encode.h>' >$@",
visibility = ["//visibility:private"],
)
cc_library(
name = "gpr",
@ -499,6 +517,9 @@ cc_library(
"src/core/ext/census/placeholders.c",
"src/core/ext/census/tracing.c",
"src/core/plugin_registry/grpc_plugin_registry.c",
"third_party/nanopb/pb.h",
"third_party/nanopb/pb_decode.h",
"third_party/nanopb/pb_encode.h",
],
hdrs = [
"include/grpc/byte_buffer.h",
@ -856,6 +877,9 @@ cc_library(
"src/core/lib/tsi/ssl_transport_security.c",
"src/core/lib/tsi/transport_security.c",
"src/core/plugin_registry/grpc_cronet_plugin_registry.c",
"third_party/nanopb/pb.h",
"third_party/nanopb/pb_decode.h",
"third_party/nanopb/pb_encode.h",
],
hdrs = [
"include/grpc/byte_buffer.h",
@ -1185,6 +1209,9 @@ cc_library(
"src/core/ext/census/placeholders.c",
"src/core/ext/census/tracing.c",
"src/core/plugin_registry/grpc_unsecure_plugin_registry.c",
"third_party/nanopb/pb.h",
"third_party/nanopb/pb_decode.h",
"third_party/nanopb/pb_encode.h",
],
hdrs = [
"include/grpc/byte_buffer.h",
@ -2313,6 +2340,9 @@ objc_library(
"src/core/ext/census/grpc_filter.h",
"src/core/ext/census/mlog.h",
"src/core/ext/census/rpc_metric_id.h",
"third_party/nanopb/pb.h",
"third_party/nanopb/pb_decode.h",
"third_party/nanopb/pb_encode.h",
],
includes = [
"include",


@ -1,83 +1,89 @@
SOURCES = [
"jaricom.c",
"jcapimin.c",
"jcapistd.c",
"jcarith.c",
"jccoefct.c",
"jccolor.c",
"jcdctmgr.c",
"jchuff.c",
"jcinit.c",
"jcmainct.c",
"jcmarker.c",
"jcmaster.c",
"jcomapi.c",
"jcparam.c",
"jcprepct.c",
"jcsample.c",
"jctrans.c",
"jdarith.c",
"jdapimin.c",
"jdapistd.c",
"jdatadst.c",
"jdatasrc.c",
"jdcoefct.c",
"jdcolor.c",
"jddctmgr.c",
"jdhuff.c",
"jdinput.c",
"jdmainct.c",
"jdmarker.c",
"jdmaster.c",
"jdmerge.c",
"jdpostct.c",
"jdsample.c",
"jdtrans.c",
"jerror.c",
"jfdctflt.c",
"jfdctfst.c",
"jfdctint.c",
"jidctflt.c",
"jidctfst.c",
"jidctint.c",
"jmemmgr.c",
"jmemnobs.c",
"jquant1.c",
"jquant2.c",
"jutils.c",
]
# Description:
# The Independent JPEG Group's JPEG runtime library.
HEADERS = [
"cderror.h",
"cdjpeg.h",
"jconfig.h",
"jdct.h",
"jerror.h",
"jinclude.h",
"jmemsys.h",
"jmorecfg.h",
"jpegint.h",
"jpeglib.h",
"jversion.h",
"transupp.h",
]
prefix_dir = "jpeg-9a"
genrule(
name = "configure",
srcs = glob(
["**/*"],
exclude = [prefix_dir + "/jconfig.h"],
),
outs = [prefix_dir + "/jconfig.h"],
cmd = "pushd external/jpeg_archive/%s; workdir=$$(mktemp -d -t tmp.XXXXXXXXXX); cp -a * $$workdir; pushd $$workdir; ./configure; popd; popd; cp $$workdir/jconfig.h $(@D); rm -rf $$workdir;" % prefix_dir,
)
licenses(["notice"]) # custom notice-style license, see LICENSE
cc_library(
name = "jpeg",
srcs = [prefix_dir + "/" + source for source in SOURCES],
hdrs = glob(["**/*.h"]) + [":configure"],
includes = [prefix_dir],
srcs = [
"cderror.h",
"cdjpeg.h",
"jaricom.c",
"jcapimin.c",
"jcapistd.c",
"jcarith.c",
"jccoefct.c",
"jccolor.c",
"jcdctmgr.c",
"jchuff.c",
"jcinit.c",
"jcmainct.c",
"jcmarker.c",
"jcmaster.c",
"jcomapi.c",
"jconfig.h",
"jcparam.c",
"jcprepct.c",
"jcsample.c",
"jctrans.c",
"jdapimin.c",
"jdapistd.c",
"jdarith.c",
"jdatadst.c",
"jdatasrc.c",
"jdcoefct.c",
"jdcolor.c",
"jdct.h",
"jddctmgr.c",
"jdhuff.c",
"jdinput.c",
"jdmainct.c",
"jdmarker.c",
"jdmaster.c",
"jdmerge.c",
"jdpostct.c",
"jdsample.c",
"jdtrans.c",
"jerror.c",
"jfdctflt.c",
"jfdctfst.c",
"jfdctint.c",
"jidctflt.c",
"jidctfst.c",
"jidctint.c",
"jinclude.h",
"jmemmgr.c",
"jmemnobs.c",
"jmemsys.h",
"jmorecfg.h",
"jquant1.c",
"jquant2.c",
"jutils.c",
"jversion.h",
"transupp.h",
],
hdrs = [
"jerror.h",
"jpegint.h",
"jpeglib.h",
],
includes = ["."],
visibility = ["//visibility:public"],
)
genrule(
name = "configure",
outs = ["jconfig.h"],
cmd = "cat <<EOF >$@\n" +
"#define HAVE_PROTOTYPES 1\n" +
"#define HAVE_UNSIGNED_CHAR 1\n" +
"#define HAVE_UNSIGNED_SHORT 1\n" +
"#define HAVE_STDDEF_H 1\n" +
"#define HAVE_STDLIB_H 1\n" +
"#ifdef WIN32\n" +
"#define INLINE __inline\n" +
"#else\n" +
"#define INLINE __inline__\n" +
"#endif\n" +
"EOF\n",
)


@ -1,34 +1,31 @@
licenses(["notice"]) # MIT
JSON_HEADERS = [
"include/json/assertions.h",
"include/json/autolink.h",
"include/json/config.h",
"include/json/features.h",
"include/json/forwards.h",
"include/json/json.h",
"src/lib_json/json_batchallocator.h",
"include/json/reader.h",
"include/json/value.h",
"include/json/writer.h",
]
JSON_SOURCES = [
"src/lib_json/json_reader.cpp",
"src/lib_json/json_value.cpp",
"src/lib_json/json_writer.cpp",
"src/lib_json/json_tool.h",
]
INLINE_SOURCES = [
"src/lib_json/json_valueiterator.inl",
]
licenses(["unencumbered"]) # Public Domain or MIT
cc_library(
name = "jsoncpp",
srcs = JSON_SOURCES,
hdrs = JSON_HEADERS,
srcs = [
"include/json/assertions.h",
"src/lib_json/json_batchallocator.h",
"src/lib_json/json_reader.cpp",
"src/lib_json/json_tool.h",
"src/lib_json/json_value.cpp",
"src/lib_json/json_writer.cpp",
],
hdrs = [
"include/json/autolink.h",
"include/json/config.h",
"include/json/features.h",
"include/json/forwards.h",
"include/json/json.h",
"include/json/reader.h",
"include/json/value.h",
"include/json/writer.h",
],
includes = ["include"],
textual_hdrs = INLINE_SOURCES,
visibility = ["//visibility:public"],
deps = [":private"],
)
cc_library(
name = "private",
textual_hdrs = ["src/lib_json/json_valueiterator.inl"],
)

linenoise.BUILD (new file, 13 lines added)

@ -0,0 +1,13 @@
licenses(["notice"]) # 2-clause BSD
exports_files(["LICENSE"])
package(
default_visibility = ["//visibility:public"],
)
cc_library(
name = "linenoise",
srcs = ["linenoise.c"],
hdrs = ["linenoise.h"],
)


@ -1,19 +1,21 @@
SOURCES = [
"pb_common.c",
"pb_decode.c",
"pb_encode.c",
]
# Description:
# Nanopb, a tiny ANSI C protobuf implementation for use on embedded devices.
HEADERS = [
"pb.h",
"pb_common.h",
"pb_decode.h",
"pb_encode.h",
]
licenses(["notice"]) # zlib license
cc_library(
name = "nanopb",
srcs = SOURCES,
hdrs = HEADERS,
srcs = [
"pb_common.c",
"pb_decode.c",
"pb_encode.c",
],
hdrs = [
"pb.h",
"pb_common.h",
"pb_decode.h",
"pb_encode.h",
],
includes = ["."],
visibility = ["//visibility:public"],
)


@ -1,40 +1,33 @@
package(default_visibility = ["//visibility:public"])
# Description:
# libpng is the official PNG reference library.
prefix_dir = "libpng-1.2.53"
PNG_SOURCES = [
"png.c",
"pngerror.c",
"pngget.c",
"pngmem.c",
"pngpread.c",
"pngread.c",
"pngrio.c",
"pngrtran.c",
"pngrutil.c",
"pngset.c",
"pngtrans.c",
"pngwio.c",
"pngwrite.c",
"pngwtran.c",
"pngwutil.c",
]
genrule(
name = "configure",
srcs = glob(
["**/*"],
exclude = [prefix_dir + "/config.h"],
),
outs = [prefix_dir + "/config.h"],
cmd = "pushd external/png_archive/%s; workdir=$$(mktemp -d -t tmp.XXXXXXXXXX); cp -a * $$workdir; pushd $$workdir; ./configure --enable-shared=no --with-pic=no; popd; popd; cp $$workdir/config.h $(@D); rm -rf $$workdir;" % prefix_dir,
)
licenses(["notice"]) # BSD/MIT-like license
cc_library(
name = "png",
srcs = [prefix_dir + "/" + source for source in PNG_SOURCES],
hdrs = glob(["**/*.h"]) + [":configure"],
includes = [prefix_dir],
linkopts = ["-lz"],
srcs = [
"png.c",
"pngerror.c",
"pngget.c",
"pngmem.c",
"pngpread.c",
"pngread.c",
"pngrio.c",
"pngrtran.c",
"pngrutil.c",
"pngset.c",
"pngtrans.c",
"pngwio.c",
"pngwrite.c",
"pngwtran.c",
"pngwutil.c",
],
hdrs = [
"png.h",
"pngconf.h",
],
includes = ["."],
linkopts = ["-lm"],
visibility = ["//visibility:public"],
deps = ["@zlib_archive//:zlib"],
)


@ -1,13 +1,12 @@
genrule(
name = "copy_six",
srcs = ["six-1.10.0/six.py"],
outs = ["six.py"],
cmd = "cp $< $(@)",
)
# Description:
# Six provides simple utilities for wrapping over differences between Python 2
# and Python 3.
licenses(["notice"]) # MIT
py_library(
name = "six",
srcs = ["six.py"],
visibility = ["//visibility:public"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
)


@ -93,10 +93,12 @@ filegroup(
":all_files",
"//tensorflow/c:all_files",
"//tensorflow/cc:all_files",
"//tensorflow/cc/saved_model:all_files",
"//tensorflow/contrib:all_files",
"//tensorflow/contrib/android:all_files",
"//tensorflow/contrib/bayesflow:all_files",
"//tensorflow/contrib/copy_graph:all_files",
"//tensorflow/contrib/crf:all_files",
"//tensorflow/contrib/cudnn_rnn:all_files",
"//tensorflow/contrib/distributions:all_files",
"//tensorflow/contrib/factorization:all_files",
@ -129,7 +131,11 @@ filegroup(
"//tensorflow/contrib/slim/python/slim/nets:all_files",
"//tensorflow/contrib/tensor_forest:all_files",
"//tensorflow/contrib/tensor_forest/hybrid:all_files",
"//tensorflow/contrib/tensorboard:all_files",
"//tensorflow/contrib/testing:all_files",
"//tensorflow/contrib/tfprof/python/tools/tfprof:all_files",
"//tensorflow/contrib/tfprof/tools/tfprof:all_files",
"//tensorflow/contrib/tfprof/tools/tfprof/internal:all_files",
"//tensorflow/contrib/training:all_files",
"//tensorflow/contrib/util:all_files",
"//tensorflow/core:all_files",
@ -142,6 +148,7 @@ filegroup(
"//tensorflow/core/platform/default/build_config:all_files",
"//tensorflow/core/platform/hadoop:all_files",
"//tensorflow/core/util/ctc:all_files",
"//tensorflow/core/util/tensor_bundle:all_files",
"//tensorflow/examples/android:all_files",
"//tensorflow/examples/how_tos/reading_data:all_files",
"//tensorflow/examples/image_retraining:all_files",
@ -166,6 +173,7 @@ filegroup(
"//tensorflow/python/debug:all_files",
"//tensorflow/python/kernel_tests:all_files",
"//tensorflow/python/saved_model:all_files",
"//tensorflow/python/saved_model/example:all_files",
"//tensorflow/python/tools:all_files",
"//tensorflow/tensorboard:all_files",
"//tensorflow/tensorboard/app:all_files",
@ -176,7 +184,6 @@ filegroup(
"//tensorflow/tensorboard/lib:all_files",
"//tensorflow/tensorboard/lib/python:all_files",
"//tensorflow/tensorboard/scripts:all_files",
"//tensorflow/third_party/hadoop:all_files",
"//tensorflow/tools/dist_test/server:all_files",
"//tensorflow/tools/docker:all_files",
"//tensorflow/tools/docker/notebooks:all_files",
@ -185,6 +192,7 @@ filegroup(
"//tensorflow/tools/proto_text:all_files",
"//tensorflow/tools/test:all_files",
"//tensorflow/user_ops:all_files",
"//third_party/hadoop:all_files",
],
visibility = [":__subpackages__"],
)


@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
@ -1574,6 +1575,40 @@ void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def,
status->status = MessageToBuffer(def, output_graph_def);
}
struct TF_ImportGraphDefOptions {
tensorflow::ImportGraphDefOptions opts;
};
TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() {
return new TF_ImportGraphDefOptions;
}
void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts) {
delete opts;
}
void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
const char* prefix) {
opts->opts.prefix = prefix;
}
void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
const TF_ImportGraphDefOptions* opts,
TF_Status* status) {
GraphDef def;
if (!def.ParseFromArray(graph_def->data, graph_def->length)) {
status->status = InvalidArgument("Invalid GraphDef");
return;
}
mutex_lock l(graph->mu);
const int last_node_id = graph->graph.num_node_ids();
status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
&graph->refiner);
if (!status->status.ok()) return;
for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
auto* node = graph->graph.FindNodeId(i);
if (node != nullptr) graph->name_map[node->name()] = node;
}
}
// TF_SessionWithGraph functions ----------------------------------------------
TF_SessionWithGraph* TF_NewSessionWithGraph(TF_Graph* graph,


@ -739,20 +739,36 @@ extern TF_Operation* TF_GraphOperationByName(TF_Graph* graph,
// }
extern TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos);
// Note: The following two functions may fail on very large protos in the
// future.
// Write out a serialized representation of `graph` (as a GraphDef protocol
// message) to `output_graph_def`.
//
// May fail on very large graphs in the future.
extern void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def,
TF_Status* status);
// TF_ImportGraphDefOptions holds options that can be passed to
// TF_GraphImportGraphDef.
typedef struct TF_ImportGraphDefOptions TF_ImportGraphDefOptions;
extern TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions();
extern void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts);
// Set the prefix to be prepended to the names of nodes in `graph_def` that will
// be imported into `graph`.
extern void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
const char* prefix);
// Import the graph serialized in `graph_def` into `graph`.
extern void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
const TF_ImportGraphDefOptions* options,
TF_Status* status);
// Note: The following function may fail on very large protos in the future.
extern void TF_OperationToNodeDef(TF_Operation* oper,
TF_Buffer* output_node_def,
TF_Status* status);
// TODO(cwhipkey): Query shape for operation outputs.
// TODO(ashankar): Import GraphDef into TF_Graph.
// TODO(andydavis): Function to add gradients to a graph.
// TODO(josh11b): Register OpDef, available to all operations added
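For readers coming from Python, the options-plus-prefix flow declared above mirrors the existing `tf.import_graph_def`, whose `name` argument plays the role of `TF_ImportGraphDefOptionsSetPrefix`. A minimal sketch (the file path is hypothetical):

```python
import tensorflow as tf

graph_def = tf.GraphDef()
with open("graph.pb", "rb") as f:  # "graph.pb" is a hypothetical path
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default():
    # Imported nodes are prefixed "imported/", as in the C API test below.
    tf.import_graph_def(graph_def, name="imported")
```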


@ -45,7 +45,7 @@ TF_Tensor* TF_Tensor_EncodeStrings(const Tensor& src);
namespace {
TEST(CApi, Status) {
TEST(CAPI, Status) {
TF_Status* s = TF_NewStatus();
EXPECT_EQ(TF_OK, TF_GetCode(s));
EXPECT_EQ(string(), TF_Message(s));
@ -60,7 +60,7 @@ static void Deallocator(void* data, size_t, void* arg) {
*reinterpret_cast<bool*>(arg) = true;
}
TEST(CApi, Tensor) {
TEST(CAPI, Tensor) {
const int num_bytes = 6 * sizeof(float);
float* values =
reinterpret_cast<float*>(tensorflow::cpu_allocator()->AllocateRaw(
@ -80,7 +80,7 @@ TEST(CApi, Tensor) {
EXPECT_TRUE(deallocator_called);
}
TEST(CApi, AllocateTensor) {
TEST(CAPI, AllocateTensor) {
const int num_bytes = 6 * sizeof(float);
int64_t dims[] = {2, 3};
TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, num_bytes);
@ -92,7 +92,7 @@ TEST(CApi, AllocateTensor) {
TF_DeleteTensor(t);
}
TEST(CApi, LibraryLoadFunctions) {
TEST(CAPI, LibraryLoadFunctions) {
// Load the library.
TF_Status* status = TF_NewStatus();
TF_Library* lib =
@ -139,7 +139,7 @@ static void TestEncodeDecode(int line, const std::vector<string>& data) {
}
}
TEST(CApi, TensorEncodeDecodeStrings) {
TEST(CAPI, TensorEncodeDecodeStrings) {
TestEncodeDecode(__LINE__, {});
TestEncodeDecode(__LINE__, {"hello"});
TestEncodeDecode(__LINE__,
@ -149,12 +149,12 @@ TEST(CApi, TensorEncodeDecodeStrings) {
TestEncodeDecode(__LINE__, {"small", big, "small2"});
}
TEST(CApi, SessionOptions) {
TEST(CAPI, SessionOptions) {
TF_SessionOptions* opt = TF_NewSessionOptions();
TF_DeleteSessionOptions(opt);
}
TEST(CApi, SessionWithRunMetadata) {
TEST(CAPI, SessionWithRunMetadata) {
TF_Status* s = TF_NewStatus();
TF_SessionOptions* opt = TF_NewSessionOptions();
TF_Session* session = TF_NewSession(opt, s);
@ -230,7 +230,7 @@ TEST(CAPI, StatusEnum) {
EXPECT_EQ(TF_DATA_LOSS, static_cast<TF_Code>(tensorflow::error::DATA_LOSS));
}
TEST(CApi, GetAllOpList) {
TEST(CAPI, GetAllOpList) {
TF_Buffer* buf = TF_GetAllOpList();
tensorflow::OpList op_list;
EXPECT_TRUE(op_list.ParseFromArray(buf->data, buf->length));
@ -646,6 +646,47 @@ TEST(CAPI, Graph) {
TF_DeleteStatus(s);
}
TEST(CAPI, ImportGraphDef) {
TF_Status* s = TF_NewStatus();
TF_Graph* graph = TF_NewGraph();
// Create a graph with two nodes: x and 3
Placeholder(graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr);
ScalarConst(3, graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr);
// Export to a GraphDef
TF_Buffer* graph_def = TF_NewBuffer();
TF_GraphToGraphDef(graph, graph_def, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Import it again, with a prefix, in a fresh graph.
TF_DeleteGraph(graph);
graph = TF_NewGraph();
TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
TF_GraphImportGraphDef(graph, graph_def, opts, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteImportGraphDefOptions(opts);
TF_DeleteBuffer(graph_def);
TF_Operation* scalar = TF_GraphOperationByName(graph, "imported/scalar");
TF_Operation* feed = TF_GraphOperationByName(graph, "imported/feed");
ASSERT_TRUE(scalar != nullptr);
ASSERT_TRUE(feed != nullptr);
// Can add nodes to the imported graph without trouble.
Add(feed, scalar, graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteGraph(graph);
TF_DeleteStatus(s);
}
class CSessionWithGraph {
public:
CSessionWithGraph(TF_Graph* graph, TF_Status* s) {


@ -48,6 +48,42 @@ tf_cc_test(
],
)
cc_library(
name = "gradient_checker",
srcs = ["framework/gradient_checker.cc"],
hdrs = ["framework/gradient_checker.h"],
deps = [
":cc_ops",
":client_session",
":grad_op_registry",
":gradients",
":ops",
":scope",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_cc_test(
name = "framework_gradient_checker_test",
srcs = ["framework/gradient_checker_test.cc"],
deps = [
":cc_ops",
":grad_op_registry",
":grad_ops",
":gradient_checker",
":testutil",
"//tensorflow/core:all_kernels",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
cc_library(
name = "grad_ops",
deps = [


@ -0,0 +1,165 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/framework/gradient_checker.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/standard_ops.h"
namespace tensorflow {
using namespace ops; // NOLINT(build/namespaces)
namespace {
// TODO(andydavis) Support returning relative error (as opposed to max error)
// between theoretical and numerical jacobians:
// fabs(jac_t - jac_n) / max(fabs(jac_t), fabs(jac_n))
// TODO(andydavis) Vectorize and/or multi-thread Jacobian computations if
// performance becomes an issue.
template <typename T>
Status ComputeTheoreticalJacobianTranspose(
const Scope& scope, const ops::Output& x, const TensorShape& x_shape,
const Tensor& x_data, const ops::Output& y, const TensorShape& y_shape,
Tensor* jacobian_t) {
// Call AddSymbolicGradients to get 'dx' (we will feed 'dy').
auto dy = Cast(scope, Const(scope, 1.0, y_shape), x.type());
std::vector<ops::Output> outputs;
TF_RETURN_IF_ERROR(AddSymbolicGradients(scope, {y}, {x}, {dy}, &outputs));
auto dx = outputs[0];
// Initialize 'dy_data' to zeros.
Tensor dy_data(y.type(), y_shape);
auto dy_data_flat = dy_data.flat<T>();
dy_data_flat.setZero();
// Compute the theoretical Jacobian one row at a time by backproping '1.0'
// for each element of 'dy', while holding all other elements of 'dy' at zero.
ClientSession session(scope);
std::vector<Tensor> dxout;
const int64 x_size = x_shape.num_elements();
const int64 dy_size = y_shape.num_elements();
auto jacobian = jacobian_t->matrix<T>();
for (int c = 0; c < dy_size; ++c) {
dy_data_flat(c) = 1.0;
TF_RETURN_IF_ERROR(session.Run({{x, x_data}, {dy, dy_data}}, {dx}, &dxout));
auto dx_flat = dxout[0].flat<T>();
for (int r = 0; r < x_size; ++r) {
jacobian(r, c) = dx_flat(r);
}
dy_data_flat(c) = 0.0;
}
return Status::OK();
}
template <typename T>
Status ComputeNumericJacobianTranspose(const Scope& scope, const ops::Output& x,
const TensorShape& x_shape,
const ops::Output& y,
const TensorShape& y_shape,
const T delta, Tensor* x_data,
Tensor* jacobian_t) {
const int64 x_size = x_shape.num_elements();
const int64 y_size = y_shape.num_elements();
auto x_data_flat = x_data->flat<T>();
// Compute the numeric Jacobian one column at a time by perturbing each
// element of 'x_data' (positively and negatively) by 'delta', and
// updating the jacobian with the centered difference.
ClientSession session(scope);
std::vector<Tensor> yout;
auto jacobian = jacobian_t->matrix<T>();
for (int r = 0; r < x_size; ++r) {
// Store current value of 'x' at 'r'.
T v = x_data_flat(r);
// Evaluate at positive delta.
x_data_flat(r) = v + delta;
TF_RETURN_IF_ERROR(session.Run({{x, *x_data}}, {y}, &yout));
Tensor y_pos = yout[0];
// Evaluate at negative delta.
x_data_flat(r) = v - delta;
TF_RETURN_IF_ERROR(session.Run({{x, *x_data}}, {y}, &yout));
Tensor y_neg = yout[0];
// Compute element-wise centered difference and store in Jacobian.
auto y_pos_flat = y_pos.flat<T>();
auto y_neg_flat = y_neg.flat<T>();
const T scale = 2 * delta;
for (int c = 0; c < y_size; ++c) {
jacobian(r, c) = (y_pos_flat(c) - y_neg_flat(c)) / scale;
}
// Restore pre-perturbation value.
x_data_flat(r) = v;
}
return Status::OK();
}
} // namespace
template <typename T>
Status ComputeGradientError(const Scope& scope, const ops::Output& x,
const TensorShape& x_shape, const ops::Output& y,
const TensorShape& y_shape, T* max_error) {
const int64 x_size = x_shape.num_elements();
const int64 y_size = y_shape.num_elements();
// Initialize 'x_data' to random values.
Tensor x_data(x.type(), x_shape);
auto x_data_flat = x_data.flat<T>();
x_data_flat.setRandom();
// Initialize theoretical Jacobian to zeros.
Tensor jacobian_t(x.type(), {x_size, y_size});
auto jacobian_t_flat = jacobian_t.flat<T>();
jacobian_t_flat.setZero();
// Compute theoretical Jacobian.
TF_RETURN_IF_ERROR(ComputeTheoreticalJacobianTranspose<T>(
scope, x, x_shape, x_data, y, y_shape, &jacobian_t));
// Initialize numeric Jacobian to zeros.
Tensor jacobian_n(x.type(), {x_size, y_size});
auto jacobian_n_flat = jacobian_n.flat<T>();
jacobian_n_flat.setZero();
// Compute numeric Jacobian.
TF_RETURN_IF_ERROR(ComputeNumericJacobianTranspose<T>(
scope, x, x_shape, y, y_shape, 1e-3, &x_data, &jacobian_n));
// Compute the maximum error between theoretical and numeric Jacobians.
*max_error = 0.0;
auto jac_t = jacobian_t.matrix<T>();
auto jac_n = jacobian_n.matrix<T>();
for (int r = 0; r < x_size; ++r) {
for (int c = 0; c < y_size; ++c) {
*max_error = std::max(*max_error, std::fabs(jac_t(r, c) - jac_n(r, c)));
}
}
return Status::OK();
}
#define INSTANTIATE_GRAD_ERR_TYPE(T) \
template Status ComputeGradientError<T>( \
const Scope& scope, const ops::Output& x, const TensorShape& x_shape, \
const ops::Output& y, const TensorShape& y_shape, T* max_error)
INSTANTIATE_GRAD_ERR_TYPE(float);
INSTANTIATE_GRAD_ERR_TYPE(double);
} // namespace tensorflow
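In the notation of the code above, the numeric Jacobian entries are standard centered differences and the reported error is the maximum over all entries:

$$ J^{(n)}_{rc} = \frac{y_c(x + \delta e_r) - y_c(x - \delta e_r)}{2\delta}, \qquad \text{max\_error} = \max_{r,c} \left| J^{(t)}_{rc} - J^{(n)}_{rc} \right|, $$

where $e_r$ is the $r$-th standard basis vector, $\delta$ is the 1e-3 passed by `ComputeGradientError`, and $J^{(t)}$ is the theoretical Jacobian obtained by backpropagating one-hot `dy` vectors.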


@ -0,0 +1,35 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_
#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/core/framework/tensor.h"
namespace tensorflow {
// Returns in 'max_error' the maximum element-wise error for dy/dx between the
// computed and numeric Jacobian matrices where 'x' and 'y' are tensors.
// This function adds operations to the graph associated with 'scope'.
template <typename T>
Status ComputeGradientError(const Scope& scope, const ops::Output& x,
const TensorShape& x_shape, const ops::Output& y,
const TensorShape& y_shape, T* max_error);
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_


@ -0,0 +1,70 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/framework/gradient_checker.h"
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/equal_graph_def.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
using namespace ops; // NOLINT(build/namespaces)
namespace {
TEST(GradientCheckerTest, BasicFloat) {
Scope scope = Scope::NewRootScope();
TensorShape shape({2, 4, 3});
auto x = Placeholder(scope, DT_FLOAT, Placeholder::Shape(shape));
auto y = Square(scope, x);
float max_error;
TF_ASSERT_OK(
ComputeGradientError<float>(scope, x, shape, y, shape, &max_error));
EXPECT_LT(max_error, 1e-4);
}
TEST(GradientCheckerTest, BasicDouble) {
Scope scope = Scope::NewRootScope();
TensorShape shape({2, 4, 3});
auto x = Placeholder(scope, DT_DOUBLE, Placeholder::Shape(shape));
auto y = Square(scope, x);
double max_error;
TF_ASSERT_OK(
ComputeGradientError<double>(scope, x, shape, y, shape, &max_error));
EXPECT_LT(max_error, 1e-10);
}
TEST(GradientCheckerTest, MatMulGrad) {
Scope scope = Scope::NewRootScope();
TensorShape x_shape({4, 3});
TensorShape y_shape({3, 2});
TensorShape z_shape({4, 2});
auto x = Placeholder(scope, DT_DOUBLE, Placeholder::Shape(x_shape));
auto y = Const(scope, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, y_shape);
auto z = MatMul(scope, x, y);
double max_error;
TF_ASSERT_OK(
ComputeGradientError<double>(scope, x, x_shape, z, z_shape, &max_error));
EXPECT_LT(max_error, 1e-10);
}
} // namespace
} // namespace tensorflow


@ -0,0 +1,68 @@
# Description:
# TensorFlow SavedModel.
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
cc_library(
name = "constants",
hdrs = ["constants.h"],
)
cc_library(
name = "loader",
srcs = ["loader.cc"],
hdrs = ["loader.h"],
deps = [
":constants",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
],
)
tf_cc_test(
name = "loader_test",
srcs = ["loader_test.cc"],
data = [
":saved_model_half_plus_two",
],
linkstatic = 1,
deps = [
":constants",
":loader",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
filegroup(
name = "saved_model_half_plus_two",
srcs = glob(["testdata/half_plus_two/*"]),
)
# -----------------------------------------------------------------------------
# Google-internal targets.
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)


@ -0,0 +1,36 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_
#define THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_
namespace tensorflow {
// SavedModel proto filename.
constexpr char kSavedModelFilenamePb[] = "saved_model.pb";
// SavedModel text format proto filename.
constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt";
// SavedModel variables filename.
constexpr char kSavedModelVariablesFilename[] = "saved_model_variables";
// Commonly used tags.
constexpr char kSavedModelTagServe[] = "serve";
constexpr char kSavedModelTagTrain[] = "train";
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_


@ -0,0 +1,126 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/saved_model/loader.h"
#include <unordered_set>
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
namespace {
Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) {
const string saved_model_path =
io::JoinPath(export_dir, kSavedModelFilenamePb);
return ReadBinaryProto(Env::Default(), saved_model_path, saved_model_proto);
}
Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto,
const std::unordered_set<string>& tags,
MetaGraphDef* meta_graph_def_to_load) {
for (const MetaGraphDef& meta_graph_def : saved_model_proto.meta_graphs()) {
// Get tags from the meta_graph_def.
std::unordered_set<string> graph_tags;
for (const string& tag : meta_graph_def.meta_info_def().tags()) {
graph_tags.insert(tag);
}
// Match with the set of tags provided.
if (graph_tags == tags) {
*meta_graph_def_to_load = meta_graph_def;
return Status::OK();
}
}
return Status(error::Code::NOT_FOUND,
"Could not find meta graph def matching supplied tags.");
}
Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def,
const SessionOptions& session_options,
std::unique_ptr<Session>* session) {
session->reset(NewSession(session_options));
return (*session)->Create(meta_graph_def.graph_def());
}
Status Restore(const RunOptions& run_options, const string& export_dir,
const StringPiece restore_op_name,
const StringPiece variable_filename_const_op_name,
Session* session) {
const string variables_path =
io::JoinPath(export_dir, kSavedModelVariablesFilename);
if (!Env::Default()->FileExists(variables_path)) {
return Status(error::Code::NOT_FOUND,
"Could not find checkpointed variables.");
}
// Add variables to the graph.
Tensor variables_path_tensor(DT_STRING, TensorShape({}));
variables_path_tensor.scalar<string>()() = variables_path;
std::vector<std::pair<string, Tensor>> inputs = {
{variable_filename_const_op_name.ToString(), variables_path_tensor}};
RunMetadata run_metadata;
return session->Run(run_options, inputs, {}, {restore_op_name.ToString()},
nullptr /* outputs */, &run_metadata);
}
} // namespace
Status LoadSavedModel(const string& export_dir,
const std::unordered_set<string>& tags,
const SessionOptions& session_options,
const RunOptions& run_options,
SavedModelBundle* const bundle) {
if (!MaybeSavedModelDirectory(export_dir)) {
return Status(error::Code::NOT_FOUND,
"SavedModel not found in export directory: " + export_dir);
}
LOG(INFO) << "Loading SavedModel from: " << export_dir;
SavedModel saved_model_proto;
TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto));
TF_RETURN_IF_ERROR(
FindMetaGraphDefToLoad(saved_model_proto, tags, &bundle->meta_graph_def));
TF_RETURN_IF_ERROR(LoadMetaGraphIntoSession(
bundle->meta_graph_def, session_options, &bundle->session));
TF_RETURN_IF_ERROR(
Restore(run_options, export_dir,
bundle->meta_graph_def.saver_def().restore_op_name(),
bundle->meta_graph_def.saver_def().filename_tensor_name(),
bundle->session.get()));
LOG(INFO) << "Done loading SavedModel.";
return Status::OK();
}
bool MaybeSavedModelDirectory(const string& export_dir) {
const string saved_model_pb_path =
io::JoinPath(export_dir, kSavedModelFilenamePb);
const string saved_model_pbtxt_path =
io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
return Env::Default()->FileExists(saved_model_pb_path) ||
Env::Default()->FileExists(saved_model_pbtxt_path);
}
} // namespace tensorflow


@ -0,0 +1,55 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// SavedModel loading functions and SavedModelBundle struct.
#ifndef THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_LOADER_H_
#define THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_LOADER_H_
#include <string>
#include <unordered_set>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/public/session.h"
namespace tensorflow {
// SavedModel representation once the SavedModel is loaded from storage.
struct SavedModelBundle {
std::unique_ptr<Session> session;
MetaGraphDef meta_graph_def;
};
// Loads a SavedModel from the specified export directory. The meta graph def to
// be loaded is identified by the supplied tags, corresponding exactly to the
// set of tags used at SavedModel build time. Returns a SavedModel bundle with a
// session and the requested meta graph def, if found.
Status LoadSavedModel(const string& export_dir,
const std::unordered_set<string>& tags,
const SessionOptions& session_options,
const RunOptions& run_options,
SavedModelBundle* const bundle);
// Checks whether the provided directory could contain a SavedModel. Note that
// the method does not load any data by itself. If the method returns `false`,
// the export directory definitely does not contain a SavedModel. If the method
// returns `true`, the export directory may contain a SavedModel but provides no
// guarantee that it can be loaded.
bool MaybeSavedModelDirectory(const string& export_dir);
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_LOADER_H_
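This C++ loader parallels the Python-side SavedModel loader under tensorflow/python/saved_model; a hedged sketch of the equivalent call there, assuming that module's load() API and an illustrative export path:

```python
import tensorflow as tf
from tensorflow.python.saved_model import loader  # module path assumed

with tf.Session(graph=tf.Graph()) as sess:
    # The tag set must exactly match the tags used at SavedModel build
    # time, as with the C++ LoadSavedModel ("serve" mirrors
    # kSavedModelTagServe). "/tmp/half_plus_two" is illustrative.
    loader.load(sess, ["serve"], "/tmp/half_plus_two")
```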


@ -0,0 +1,129 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
constexpr char kTestData[] = "cc/saved_model/testdata/half_plus_two";
class LoaderTest : public ::testing::Test {
protected:
LoaderTest() {}
void CheckSavedModelBundle(const SavedModelBundle& bundle) {
// Validate the half plus two behavior.
Tensor input = test::AsTensor<float>({0, 1, 2, 3}, TensorShape({4, 1}));
// Retrieve the regression signature from meta graph def.
const auto signature_def_map = bundle.meta_graph_def.signature_def();
const auto signature_def = signature_def_map.at("regression");
const string input_name = signature_def.inputs().at("input").name();
const string output_name = signature_def.outputs().at("output").name();
std::vector<Tensor> outputs;
TF_ASSERT_OK(bundle.session->Run({{input_name, input}}, {output_name}, {},
&outputs));
ASSERT_EQ(outputs.size(), 1);
test::ExpectTensorEqual<float>(
outputs[0],
test::AsTensor<float>({2, 2.5, 3, 3.5}, TensorShape({4, 1})));
}
};
TEST_F(LoaderTest, TagMatch) {
SavedModelBundle bundle;
SessionOptions session_options;
RunOptions run_options;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestData);
TF_ASSERT_OK(LoadSavedModel(export_dir, {kSavedModelTagServe},
session_options, run_options, &bundle));
CheckSavedModelBundle(bundle);
}
TEST_F(LoaderTest, NoTagMatch) {
SavedModelBundle bundle;
RunOptions run_options;
SessionOptions session_options;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestData);
Status st = LoadSavedModel(export_dir, {"missing-tag"}, session_options,
run_options, &bundle);
EXPECT_FALSE(st.ok());
EXPECT_TRUE(
StringPiece(st.error_message())
.contains("Could not find meta graph def matching supplied tags."))
<< st.error_message();
}
TEST_F(LoaderTest, NoTagMatchMultiple) {
SavedModelBundle bundle;
RunOptions run_options;
SessionOptions session_options;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestData);
Status st = LoadSavedModel(export_dir, {kSavedModelTagServe, "missing-tag"},
session_options, run_options, &bundle);
EXPECT_FALSE(st.ok());
EXPECT_TRUE(
StringPiece(st.error_message())
.contains("Could not find meta graph def matching supplied tags."))
<< st.error_message();
}
TEST_F(LoaderTest, InvalidExportPath) {
SavedModelBundle bundle;
RunOptions run_options;
SessionOptions session_options;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path");
Status st = LoadSavedModel(export_dir, {kSavedModelTagServe}, session_options,
run_options, &bundle);
EXPECT_FALSE(st.ok());
}
TEST_F(LoaderTest, MaybeSavedModelDirectory) {
// Valid SavedModel directory.
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestData);
EXPECT_TRUE(MaybeSavedModelDirectory(export_dir));
// Directory that does not exist.
const string missing_export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path");
EXPECT_FALSE(MaybeSavedModelDirectory(missing_export_dir));
// Directory that exists but is an invalid SavedModel location.
const string invalid_export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), "cc/saved_model/testdata");
EXPECT_FALSE(MaybeSavedModelDirectory(invalid_export_dir));
}
} // namespace
} // namespace tensorflow


@ -0,0 +1,2 @@
model_checkpoint_path: "/tmp/saved_model/half_plus_two/saved_model_variables"
all_model_checkpoint_paths: "/tmp/saved_model/half_plus_two/saved_model_variables"

Binary file not shown.


@ -15,6 +15,7 @@ py_library(
deps = [
"//tensorflow/contrib/bayesflow:bayesflow_py",
"//tensorflow/contrib/copy_graph:copy_graph_py",
"//tensorflow/contrib/crf:crf_py",
"//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py",
"//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/contrib/factorization:factorization_py",
@ -35,6 +36,7 @@ py_library(
"//tensorflow/contrib/slim:nets",
"//tensorflow/contrib/tensor_forest:tensor_forest_py",
"//tensorflow/contrib/tensor_forest/hybrid:ops_lib",
"//tensorflow/contrib/tensorboard",
"//tensorflow/contrib/testing:testing_py",
"//tensorflow/contrib/training:training_py",
"//tensorflow/contrib/util:util_py",


@ -21,6 +21,7 @@ from __future__ import print_function
# Add projects here, they will show up under tf.contrib.
from tensorflow.contrib import bayesflow
from tensorflow.contrib import copy_graph
from tensorflow.contrib import crf
from tensorflow.contrib import cudnn_rnn
from tensorflow.contrib import distributions
from tensorflow.contrib import factorization
@ -38,6 +39,7 @@ from tensorflow.contrib import quantization
from tensorflow.contrib import rnn
from tensorflow.contrib import slim
from tensorflow.contrib import tensor_forest
from tensorflow.contrib import tensorboard
from tensorflow.contrib import testing
from tensorflow.contrib import training
from tensorflow.contrib import util


@ -55,7 +55,7 @@ Log E_q[ f(Z) p(Z) / q(Z) ]
C := Max[ Log[f(Z)] + Log[p(Z)] - Log[q(Z)] ].
```
The maximum value of the exponentiated term will be 0.0, and the the expecation
The maximum value of the exponent will be 0.0, and the expectation
can be evaluated in a stable manner.
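Spelled out, the shift is the usual log-sum-exp stabilization:

$$ \log E_q\left[ \frac{f(Z)\, p(Z)}{q(Z)} \right] = \log E_q\left[ \exp\big( \log f(Z) + \log p(Z) - \log q(Z) - C \big) \right] + C, $$

so with $C$ chosen as the maximum, the largest exponent is exactly zero and nothing overflows.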
## Ops
@ -252,9 +252,7 @@ def expectation(f, p, z=None, n=None, seed=None, name='expectation'):
The user supplies either a `Tensor` of samples `z`, or a number of samples to draw, `n`.
Args:
f: Callable mapping samples from `sampling_dist_q` to `Tensors` with
shape broadcastable to `q.batch_shape`.
For example, `f` works "just like" `sampling_dist_q.log_prob`.
f: Callable mapping samples from `p` to `Tensors`.
p: `tf.contrib.distributions.BaseDistribution`.
z: `Tensor` of samples from `p`, produced by `p.sample_n`.
n: Integer `Tensor`. Number of samples to generate if `z` is not provided.
@ -262,7 +260,36 @@ def expectation(f, p, z=None, n=None, seed=None, name='expectation'):
name: A name to give this `Op`.
Returns:
A `Tensor` with same `dtype` as `p`, and shape equal to `p.batch_shape`.
A `Tensor` with the same `dtype` as `p`.
Example:
```python
N_samples = 10000
distributions = tf.contrib.distributions
dist = distributions.Uniform([0.0, 0.0], [1.0, 2.0])
elementwise_mean = lambda x: x
mean_sum = lambda x: tf.reduce_sum(x, 1)
estimate_elementwise_mean_tf = monte_carlo.expectation(elementwise_mean,
dist,
n=N_samples)
estimate_mean_sum_tf = monte_carlo.expectation(mean_sum,
dist,
n=N_samples)
with tf.Session() as sess:
estimate_elementwise_mean, estimate_mean_sum = (
sess.run([estimate_elementwise_mean_tf, estimate_mean_sum_tf]))
print estimate_elementwise_mean
>>> np.array([ 0.50018013 1.00097895], dtype=np.float32)
print estimate_mean_sum
>>> 1.49571
```
"""
with ops.name_scope(name, values=[n, z]):
z = _get_samples(p, z, n, seed)


@ -19,27 +19,19 @@ from __future__ import print_function
import fnmatch
import os
import platform
import re
import sys
from setuptools import find_packages, setup, Command, Extension
from setuptools import find_packages, setup, Command
from setuptools.command.install import install as InstallCommandBase
from setuptools.dist import Distribution
_VERSION = '0.10.0-cmake-experimental'
numpy_version = "1.8.2"
if platform.system() == "Darwin":
# There are bugs with numpy pip installation on OS X prior to
# 1.10.1, so on mac we require a higher version than on other
# platforms.
numpy_version = "1.10.1"
REQUIRED_PACKAGES = [
'numpy >= %s' % numpy_version,
'numpy >= 1.11.0',
'six >= 1.10.0',
'protobuf == 3.0.0b2',
'protobuf == 3.0.0',
]
# python3 requires wheel 0.26


@ -0,0 +1,40 @@
# Description:
# Contains classes to construct a CRF layer
# APIs here are meant to evolve over time.
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
package(default_visibility = ["//tensorflow:__subpackages__"])
load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
py_library(
name = "crf_py",
srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
srcs_version = "PY2AND3",
)
cuda_py_tests(
name = "crf_test",
srcs = ["python/kernel_tests/crf_test.py"],
additional_deps = [
":crf_py",
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)


@ -0,0 +1,76 @@
# CRF
The CRF module implements a linear-chain CRF layer for learning to predict tag sequences. This variant of the CRF is factored into unary potentials for every element in the sequence and binary potentials for every transition between output tags.
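A standard way to write the factorization just described: for an input sequence $\mathbf{x}$ and tag sequence $\mathbf{y}$ of length $T$,

$$ s(\mathbf{x}, \mathbf{y}) = \sum_{t=1}^{T} U_t(y_t) + \sum_{t=2}^{T} A_{y_{t-1}, y_t}, \qquad \log p(\mathbf{y} \mid \mathbf{x}) = s(\mathbf{x}, \mathbf{y}) - \log \sum_{\mathbf{y}'} e^{s(\mathbf{x}, \mathbf{y}')}, $$

where $U$ holds the unary scores (the linear layer in the example below) and $A$ the learned transition parameters; `crf_log_likelihood` returns this log-likelihood and `viterbi_decode` finds the highest-scoring sequence.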
### Usage
Below is an example of the API, which learns a CRF for some random data. The linear layer in the example can be replaced by any neural network.
```python
import numpy as np
import tensorflow as tf
# Data settings.
num_examples = 10
num_words = 20
num_features = 100
num_tags = 5
# Random features.
x = np.random.rand(num_examples, num_words, num_features).astype(np.float32)
# Random tag indices representing the gold sequence.
y = np.random.randint(num_tags, size=[num_examples, num_words]).astype(np.int32)
# All sequences in this example have the same length, but they can be variable in a real model.
sequence_lengths = np.full(num_examples, num_words - 1, dtype=np.int32)
# Train and evaluate the model.
with tf.Graph().as_default():
with tf.Session() as session:
# Add the data to the TensorFlow graph.
x_t = tf.constant(x)
y_t = tf.constant(y)
sequence_lengths_t = tf.constant(sequence_lengths)
# Compute unary scores from a linear layer.
weights = tf.get_variable("weights", [num_features, num_tags])
matricized_x_t = tf.reshape(x_t, [-1, num_features])
matricized_unary_scores = tf.batch_matmul(matricized_x_t, weights)
unary_scores = tf.reshape(matricized_unary_scores,
[num_examples, num_words, num_tags])
# Compute the log-likelihood of the gold sequences and keep the transition
# params for inference at test time.
log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
unary_scores, y_t, sequence_lengths_t)
# Add a training op to tune the parameters.
loss = tf.reduce_mean(-log_likelihood)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
# Train for a fixed number of iterations.
session.run(tf.initialize_all_variables())
for i in range(1000):
tf_unary_scores, tf_transition_params, _ = session.run(
[unary_scores, transition_params, train_op])
if i % 100 == 0:
correct_labels = 0
total_labels = 0
for tf_unary_scores_, y_, sequence_length_ in zip(tf_unary_scores, y,
sequence_lengths):
# Remove padding from the scores and tag sequence.
tf_unary_scores_ = tf_unary_scores_[:sequence_length_]
y_ = y_[:sequence_length_]
# Compute the highest scoring sequence.
viterbi_sequence, _ = tf.contrib.crf.viterbi_decode(
tf_unary_scores_, tf_transition_params)
# Evaluate word-level accuracy.
correct_labels += np.sum(np.equal(viterbi_sequence, y_))
total_labels += sequence_length_
accuracy = 100.0 * correct_labels / float(total_labels)
print("Accuracy: %.2f%%" % accuracy)
```

View File

@ -0,0 +1,39 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Linear-chain CRF layer.
## This package provides functions for building a linear-chain CRF layer.
@@crf_sequence_score
@@crf_log_norm
@@crf_log_likelihood
@@crf_unary_score
@@crf_binary_score
@@CrfForwardRnnCell
@@viterbi_decode
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.crf.python.ops.crf import _lengths_to_masks
from tensorflow.contrib.crf.python.ops.crf import crf_binary_score
from tensorflow.contrib.crf.python.ops.crf import crf_log_likelihood
from tensorflow.contrib.crf.python.ops.crf import crf_log_norm
from tensorflow.contrib.crf.python.ops.crf import crf_sequence_score
from tensorflow.contrib.crf.python.ops.crf import crf_unary_score
from tensorflow.contrib.crf.python.ops.crf import CrfForwardRnnCell
from tensorflow.contrib.crf.python.ops.crf import viterbi_decode

View File

@ -0,0 +1,18 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Linear-chain CRF."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

View File

@ -0,0 +1,200 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for CRF."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import numpy as np
import tensorflow as tf
class CrfTest(tf.test.TestCase):
def testCrfSequenceScore(self):
inputs = np.array(
[[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
tag_indices = np.array([1, 2, 1, 0], dtype=np.int32)
transition_params = np.array(
[[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
sequence_lengths = np.array(3, dtype=np.int32)
with self.test_session() as sess:
sequence_score = tf.contrib.crf.crf_sequence_score(
inputs=tf.expand_dims(inputs, 0),
tag_indices=tf.expand_dims(tag_indices, 0),
sequence_lengths=tf.expand_dims(sequence_lengths, 0),
transition_params=tf.constant(transition_params))
sequence_score = tf.squeeze(sequence_score, [0])
tf_sequence_score = sess.run(sequence_score)
expected_unary_score = sum(inputs[i][tag_indices[i]]
for i in range(sequence_lengths))
expected_binary_score = sum(
transition_params[tag_indices[i], tag_indices[i + 1]]
for i in range(sequence_lengths - 1))
expected_sequence_score = expected_unary_score + expected_binary_score
self.assertAllClose(tf_sequence_score, expected_sequence_score)
def testCrfUnaryScore(self):
inputs = np.array(
[[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
tag_indices = np.array([1, 2, 1, 0], dtype=np.int32)
sequence_lengths = np.array(3, dtype=np.int32)
with self.test_session() as sess:
unary_score = tf.contrib.crf.crf_unary_score(
tag_indices=tf.expand_dims(tag_indices, 0),
sequence_lengths=tf.expand_dims(sequence_lengths, 0),
inputs=tf.expand_dims(inputs, 0))
unary_score = tf.squeeze(unary_score, [0])
tf_unary_score = sess.run(unary_score)
expected_unary_score = sum(inputs[i][tag_indices[i]]
for i in range(sequence_lengths))
self.assertAllClose(tf_unary_score, expected_unary_score)
def testCrfBinaryScore(self):
tag_indices = np.array([1, 2, 1, 0], dtype=np.int32)
transition_params = np.array(
[[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
sequence_lengths = np.array(3, dtype=np.int32)
with self.test_session() as sess:
binary_score = tf.contrib.crf.crf_binary_score(
tag_indices=tf.expand_dims(tag_indices, 0),
sequence_lengths=tf.expand_dims(sequence_lengths, 0),
transition_params=tf.constant(transition_params))
binary_score = tf.squeeze(binary_score, [0])
tf_binary_score = sess.run(binary_score)
expected_binary_score = sum(
transition_params[tag_indices[i], tag_indices[i + 1]]
for i in range(sequence_lengths - 1))
self.assertAllClose(tf_binary_score, expected_binary_score)
def testCrfLogNorm(self):
inputs = np.array(
[[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
transition_params = np.array(
[[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
sequence_lengths = np.array(3, dtype=np.int32)
with self.test_session() as sess:
all_sequence_scores = []
# Compare the dynamic program with brute force computation.
for tag_indices in itertools.product(
range(num_tags), repeat=sequence_lengths):
tag_indices = list(tag_indices)
tag_indices.extend([0] * (num_words - sequence_lengths))
all_sequence_scores.append(
tf.contrib.crf.crf_sequence_score(
inputs=tf.expand_dims(inputs, 0),
tag_indices=tf.expand_dims(tag_indices, 0),
sequence_lengths=tf.expand_dims(sequence_lengths, 0),
transition_params=tf.constant(transition_params)))
brute_force_log_norm = tf.reduce_logsumexp(all_sequence_scores)
log_norm = tf.contrib.crf.crf_log_norm(
inputs=tf.expand_dims(inputs, 0),
sequence_lengths=tf.expand_dims(sequence_lengths, 0),
transition_params=tf.constant(transition_params))
log_norm = tf.squeeze(log_norm, [0])
tf_brute_force_log_norm, tf_log_norm = sess.run(
[brute_force_log_norm, log_norm])
self.assertAllClose(tf_log_norm, tf_brute_force_log_norm)
def testCrfLogLikelihood(self):
inputs = np.array(
[[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
transition_params = np.array(
[[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
sequence_lengths = np.array(3, dtype=np.int32)
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
with self.test_session() as sess:
all_sequence_log_likelihoods = []
# Make sure all probabilities sum to 1.
for tag_indices in itertools.product(
range(num_tags), repeat=sequence_lengths):
tag_indices = list(tag_indices)
tag_indices.extend([0] * (num_words - sequence_lengths))
sequence_log_likelihood, _ = tf.contrib.crf.crf_log_likelihood(
inputs=tf.expand_dims(inputs, 0),
tag_indices=tf.expand_dims(tag_indices, 0),
sequence_lengths=tf.expand_dims(sequence_lengths, 0),
transition_params=tf.constant(transition_params))
all_sequence_log_likelihoods.append(sequence_log_likelihood)
total_log_likelihood = tf.reduce_logsumexp(all_sequence_log_likelihoods)
tf_total_log_likelihood = sess.run(total_log_likelihood)
self.assertAllClose(tf_total_log_likelihood, 0.0)
def testLengthsToMasks(self):
with self.test_session() as sess:
sequence_lengths = [4, 1, 8, 2]
max_sequence_length = max(sequence_lengths)
mask = tf.contrib.crf._lengths_to_masks(sequence_lengths,
max_sequence_length)
tf_mask = sess.run(mask)
self.assertEqual(len(tf_mask), len(sequence_lengths))
for m, l in zip(tf_mask, sequence_lengths):
self.assertAllEqual(m[:l], [1] * l)
self.assertAllEqual(m[l:], [0] * (len(m) - l))
def testViterbiDecode(self):
inputs = np.array(
[[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
transition_params = np.array(
[[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
sequence_lengths = np.array(3, dtype=np.int32)
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
with self.test_session() as sess:
all_sequence_scores = []
all_sequences = []
# Compare the dynamic program with brute force computation.
for tag_indices in itertools.product(
range(num_tags), repeat=sequence_lengths):
tag_indices = list(tag_indices)
tag_indices.extend([0] * (num_words - sequence_lengths))
all_sequences.append(tag_indices)
sequence_score = tf.contrib.crf.crf_sequence_score(
inputs=tf.expand_dims(inputs, 0),
tag_indices=tf.expand_dims(tag_indices, 0),
sequence_lengths=tf.expand_dims(sequence_lengths, 0),
transition_params=tf.constant(transition_params))
sequence_score = tf.squeeze(sequence_score, [0])
all_sequence_scores.append(sequence_score)
tf_all_sequence_scores = sess.run(all_sequence_scores)
expected_max_sequence_index = np.argmax(tf_all_sequence_scores)
expected_max_sequence = all_sequences[expected_max_sequence_index]
expected_max_score = tf_all_sequence_scores[expected_max_sequence_index]
actual_max_sequence, actual_max_score = tf.contrib.crf.viterbi_decode(
inputs[:sequence_lengths], transition_params)
self.assertAllClose(actual_max_score, expected_max_score)
self.assertEqual(actual_max_sequence,
expected_max_sequence[:sequence_lengths])
if __name__ == "__main__":
tf.test.main()

View File

@ -0,0 +1,18 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for building a linear-chain CRF layer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

View File

@ -0,0 +1,311 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module for constructing a linear-chain CRF.
The following snippet is an example of a CRF layer on top of a batched sequence
of unary scores (logits for every word). This example also decodes the most
likely sequence at test time:
log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
unary_scores, gold_tags, sequence_lengths)
loss = tf.reduce_mean(-log_likelihood)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
tf_unary_scores, tf_sequence_lengths, tf_transition_params, _ = session.run(
[unary_scores, sequence_lengths, transition_params, train_op])
for tf_unary_scores_, tf_sequence_length_ in zip(tf_unary_scores,
tf_sequence_lengths):
# Remove padding.
tf_unary_scores_ = tf_unary_scores_[:tf_sequence_length_]
# Compute the highest score and its tag sequence.
viterbi_sequence, viterbi_score = tf.contrib.crf.viterbi_decode(
tf_unary_scores_, tf_transition_params)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope as vs
__all__ = ["crf_sequence_score", "crf_log_norm", "crf_log_likelihood",
"crf_unary_score", "crf_binary_score", "CrfForwardRnnCell",
"viterbi_decode"]
def _lengths_to_masks(lengths, max_length):
"""Creates a binary matrix that can be used to mask away padding.
Args:
lengths: A vector of integers representing lengths.
max_length: An integer indicating the maximum length. All values in
lengths should be less than or equal to max_length.
Returns:
masks: Masks that can be used to get rid of padding.
"""
tiled_ranges = array_ops.tile(
array_ops.expand_dims(math_ops.range(max_length), 0),
[array_ops.shape(lengths)[0], 1])
lengths = array_ops.expand_dims(lengths, 1)
masks = math_ops.to_float(
math_ops.to_int64(tiled_ranges) < math_ops.to_int64(lengths))
return masks
def crf_sequence_score(inputs, tag_indices, sequence_lengths,
transition_params):
"""Computes the unnormalized score for a tag sequence.
Args:
inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
to use as input to the CRF layer.
tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we
compute the unnormalized score.
sequence_lengths: A [batch_size] vector of true sequence lengths.
transition_params: A [num_tags, num_tags] transition matrix.
Returns:
sequence_scores: A [batch_size] vector of unnormalized sequence scores.
"""
# Compute the scores of the given tag sequence.
unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs)
binary_scores = crf_binary_score(tag_indices, sequence_lengths,
transition_params)
sequence_scores = unary_scores + binary_scores
return sequence_scores
def crf_log_norm(inputs, sequence_lengths, transition_params):
"""Computes the normalization for a CRF.
Args:
inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
to use as input to the CRF layer.
sequence_lengths: A [batch_size] vector of true sequence lengths.
transition_params: A [num_tags, num_tags] transition matrix.
Returns:
log_norm: A [batch_size] vector of normalizers for a CRF.
"""
# Split up the first and rest of the inputs in preparation for the forward
# algorithm.
first_input = array_ops.slice(inputs, [0, 0, 0], [-1, 1, -1])
first_input = array_ops.squeeze(first_input, [1])
rest_of_input = array_ops.slice(inputs, [0, 1, 0], [-1, -1, -1])
# Compute the alpha values in the forward algorithm in order to get the
# partition function.
forward_cell = CrfForwardRnnCell(transition_params)
_, alphas = rnn.dynamic_rnn(
cell=forward_cell,
inputs=rest_of_input,
sequence_length=sequence_lengths - 1,
initial_state=first_input,
dtype=dtypes.float32)
log_norm = math_ops.reduce_logsumexp(alphas, [1])
return log_norm
def crf_log_likelihood(inputs,
tag_indices,
sequence_lengths,
transition_params=None):
"""Computes the log-likehood of tag sequences in a CRF.
Args:
inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
to use as input to the CRF layer.
tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we
compute the log-likelihood.
sequence_lengths: A [batch_size] vector of true sequence lengths.
transition_params: A [num_tags, num_tags] transition matrix, if available.
Returns:
log_likelihood: A [batch_size] `Tensor` containing the log-likelihood of
each given sequence of tag indices.
transition_params: A [num_tags, num_tags] transition matrix. This is either
provided by the caller or created in this function.
"""
# Get shape information.
num_tags = inputs.get_shape()[2].value
# Get the transition matrix if not provided.
if transition_params is None:
transition_params = vs.get_variable("transitions", [num_tags, num_tags])
sequence_scores = crf_sequence_score(inputs, tag_indices, sequence_lengths,
transition_params)
log_norm = crf_log_norm(inputs, sequence_lengths, transition_params)
# Normalize the scores to get the log-likelihood.
log_likelihood = sequence_scores - log_norm
return log_likelihood, transition_params
def crf_unary_score(tag_indices, sequence_lengths, inputs):
"""Computes the unary scores of tag sequences.
Args:
tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
sequence_lengths: A [batch_size] vector of true sequence lengths.
inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials.
Returns:
unary_scores: A [batch_size] vector of unary scores.
"""
batch_size = array_ops.shape(inputs)[0]
max_seq_len = array_ops.shape(inputs)[1]
num_tags = array_ops.shape(inputs)[2]
flattened_inputs = array_ops.reshape(inputs, [-1])
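# Build flat indices so that gathering offsets + tag_indices from the
# flattened inputs selects inputs[b, t, tag_indices[b, t]] for every
# batch entry b and time step t.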
offsets = array_ops.expand_dims(
math_ops.range(batch_size) * max_seq_len * num_tags, 1)
offsets += array_ops.expand_dims(math_ops.range(max_seq_len) * num_tags, 0)
flattened_tag_indices = array_ops.reshape(offsets + tag_indices, [-1])
unary_scores = array_ops.reshape(
array_ops.gather(flattened_inputs, flattened_tag_indices),
[batch_size, max_seq_len])
masks = _lengths_to_masks(sequence_lengths, array_ops.shape(tag_indices)[1])
unary_scores = math_ops.reduce_sum(unary_scores * masks, 1)
return unary_scores
def crf_binary_score(tag_indices, sequence_lengths, transition_params):
"""Computes the binary scores of tag sequences.
Args:
tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
sequence_lengths: A [batch_size] vector of true sequence lengths.
transition_params: A [num_tags, num_tags] matrix of binary potentials.
Returns:
binary_scores: A [batch_size] vector of binary scores.
"""
# Get shape information.
num_tags = transition_params.get_shape()[0]
num_transitions = array_ops.shape(tag_indices)[1] - 1
# Truncate by one on each side of the sequence to get the start and end
# indices of each transition.
start_tag_indices = array_ops.slice(tag_indices, [0, 0],
[-1, num_transitions])
end_tag_indices = array_ops.slice(tag_indices, [0, 1], [-1, num_transitions])
# Encode the indices in a flattened representation.
flattened_transition_indices = start_tag_indices * num_tags + end_tag_indices
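# For example, with num_tags = 3 a transition from tag 1 to tag 2 maps to
# flat index 1 * 3 + 2 = 5 in the flattened transition matrix.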
flattened_transition_params = array_ops.reshape(transition_params, [-1])
# Get the binary scores based on the flattened representation.
binary_scores = array_ops.gather(flattened_transition_params,
flattened_transition_indices)
masks = _lengths_to_masks(sequence_lengths, array_ops.shape(tag_indices)[1])
truncated_masks = array_ops.slice(masks, [0, 1], [-1, -1])
binary_scores = math_ops.reduce_sum(binary_scores * truncated_masks, 1)
return binary_scores
class CrfForwardRnnCell(rnn_cell.RNNCell):
"""Computes the alpha values in a linear-chain CRF.
See http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
"""
def __init__(self, transition_params):
"""Initialize the CrfForwardRnnCell.
Args:
transition_params: A [num_tags, num_tags] matrix of binary potentials.
This matrix is expanded into a [1, num_tags, num_tags] tensor in preparation
for the broadcast summation occurring within the cell.
"""
self._transition_params = array_ops.expand_dims(transition_params, 0)
self._num_tags = transition_params.get_shape()[0].value
@property
def state_size(self):
return self._num_tags
@property
def output_size(self):
return self._num_tags
def __call__(self, inputs, state, scope=None):
"""Build the CrfForwardRnnCell.
Args:
inputs: A [batch_size, num_tags] matrix of unary potentials.
state: A [batch_size, num_tags] matrix containing the previous alpha
values.
scope: Unused variable scope of this cell.
Returns:
new_alphas, new_alphas: A pair of [batch_size, num_tags] matrices
containing the new alpha values.
"""
state = array_ops.expand_dims(state, 2)
# This addition op broadcasts self._transition_params along the zeroth
# dimension and state along the second dimension. This performs the
# multiplication of previous alpha values and the current binary potentials
# in log space.
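# In index notation:
#   new_alphas[b, j] = inputs[b, j]
#       + logsumexp_i(state[b, i] + transition_params[i, j]).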
transition_scores = state + self._transition_params
new_alphas = inputs + math_ops.reduce_logsumexp(transition_scores, [1])
# Both the state and the output of this RNN cell contain the alphas values.
# The output value is currently unused and simply satisfies the RNN API.
# This could be useful in the future if we need to compute marginal
# probabilities, which would require the accumulated alpha values at every
# time step.
return new_alphas, new_alphas
def viterbi_decode(score, transition_params):
"""Decode the highest scoring sequence of tags outside of TensorFlow.
This should only be used at test time.
Args:
score: A [seq_len, num_tags] matrix of unary potentials.
transition_params: A [num_tags, num_tags] matrix of binary potentials.
Returns:
viterbi: A [seq_len] list of integers containing the highest scoring tag
indices.
viterbi_score: A float containing the score for the viterbi sequence.
"""
trellis = np.zeros_like(score)
backpointers = np.zeros_like(score, dtype=np.int32)
trellis[0] = score[0]
for t in range(1, score.shape[0]):
v = np.expand_dims(trellis[t - 1], 1) + transition_params
trellis[t] = score[t] + np.max(v, 0)
backpointers[t] = np.argmax(v, 0)
viterbi = [np.argmax(trellis[-1])]
for bp in reversed(backpointers[1:]):
viterbi.append(bp[viterbi[-1]])
viterbi.reverse()
viterbi_score = np.max(trellis[-1])
return viterbi, viterbi_score
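# A minimal worked example (hypothetical inputs): with two tags, a zero
# transition matrix, and
#   score = np.array([[1., 0.], [0., 1.], [1., 0.]], dtype=np.float32),
# viterbi_decode(score, np.zeros([2, 2], dtype=np.float32)) returns
# ([0, 1, 0], 3.0): each step independently picks its best tag.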

View File

@ -58,13 +58,6 @@ class KLTest(tf.test.TestCase):
self.assertAllEqual([float("nan")], kl_ok.eval())
def testRegistrationFailures(self):
with self.assertRaisesRegexp(TypeError, "is not a subclass of"):
tf.contrib.distributions.RegisterKL(
tf.contrib.distributions.Normal, object)(lambda x: x)
with self.assertRaisesRegexp(TypeError, "is not a subclass of"):
tf.contrib.distributions.RegisterKL(
object, tf.contrib.distributions.Normal)(lambda x: x)
class MyDist(tf.contrib.distributions.Normal):
pass

View File

@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import contextlib
import functools
import numpy as np
import tensorflow as tf
@ -69,9 +68,9 @@ def make_univariate_mixture(batch_shape, num_components):
logits = tf.random_uniform(
list(batch_shape) + [num_components], -1, 1, dtype=tf.float32) - 50.
components = [
(distributions_py.Normal,
{"mu": np.float32(np.random.randn(*list(batch_shape))),
"sigma": np.float32(10 * np.random.rand(*list(batch_shape)))})
distributions_py.Normal(
mu=np.float32(np.random.randn(*list(batch_shape))),
sigma=np.float32(10 * np.random.rand(*list(batch_shape))))
for _ in range(num_components)
]
cat = distributions_py.Categorical(logits, dtype=tf.int32)
@ -82,10 +81,10 @@ def make_multivariate_mixture(batch_shape, num_components, event_shape):
logits = tf.random_uniform(
list(batch_shape) + [num_components], -1, 1, dtype=tf.float32) - 50.
components = [
(distributions_py.MultivariateNormalDiag,
{"mu": np.float32(np.random.randn(*list(batch_shape + event_shape))),
"diag_stdev": np.float32(10 * np.random.rand(
*list(batch_shape + event_shape)))})
distributions_py.MultivariateNormalDiag(
mu=np.float32(np.random.randn(*list(batch_shape + event_shape))),
diag_stdev=np.float32(10 * np.random.rand(
*list(batch_shape + event_shape))))
for _ in range(num_components)
]
cat = distributions_py.Categorical(logits, dtype=tf.int32)
@ -116,7 +115,7 @@ class MixtureTest(tf.test.TestCase):
r"cat.num_classes != len"):
distributions_py.Mixture(
distributions_py.Categorical([0.1, 0.5]), # 2 classes
[(distributions_py.Normal, {"mu": 1.0, "sigma": 2.0})])
[distributions_py.Normal(mu=1.0, sigma=2.0)])
with self.assertRaisesWithPredicateMatch(
ValueError, r"\(\) and \(2,\) are not compatible"):
# The value error is raised because the batch shapes of the
@ -124,13 +123,13 @@ class MixtureTest(tf.test.TestCase):
# vector of size (2,).
distributions_py.Mixture(
distributions_py.Categorical([-0.5, 0.5]), # scalar batch
[(distributions_py.Normal, {"mu": 1.0, "sigma": 2.0}), # scalar dist
(distributions_py.Normal, {"mu": [1.0, 1.0], "sigma": [2.0, 2.0]})])
[distributions_py.Normal(mu=1.0, sigma=2.0), # scalar dist
distributions_py.Normal(mu=[1.0, 1.0], sigma=[2.0, 2.0])])
with self.assertRaisesWithPredicateMatch(ValueError, r"Could not infer"):
cat_logits = tf.placeholder(shape=[1, None], dtype=tf.int32)
distributions_py.Mixture(
distributions_py.Categorical(cat_logits),
[(distributions_py.Normal, {"mu": [1.0], "sigma": [2.0]})])
[distributions_py.Normal(mu=[1.0], sigma=[2.0])])
def testBrokenShapesDynamic(self):
with self.test_session():
@ -138,8 +137,8 @@ class MixtureTest(tf.test.TestCase):
d1_param = tf.placeholder(dtype=tf.float32)
d = distributions_py.Mixture(
distributions_py.Categorical([0.1, 0.2]),
[(distributions_py.Normal, {"mu": d0_param, "sigma": d0_param}),
(distributions_py.Normal, {"mu": d1_param, "sigma": d1_param})],
[distributions_py.Normal(mu=d0_param, sigma=d0_param),
distributions_py.Normal(mu=d1_param, sigma=d1_param)],
validate_args=True)
with self.assertRaisesOpError(r"batch shape must match"):
d.sample().eval(feed_dict={d0_param: [2.0, 3.0], d1_param: [1.0]})
@ -150,42 +149,24 @@ class MixtureTest(tf.test.TestCase):
with self.assertRaisesWithPredicateMatch(TypeError, "Categorical"):
distributions_py.Mixture(None, [])
cat = distributions_py.Categorical([0.3, 0.2])
# components must be a list of tuples
with self.assertRaisesWithPredicateMatch(TypeError, "tuples of the form"):
# components must be a list of distributions
with self.assertRaisesWithPredicateMatch(
TypeError, "all .* must be Distribution instances"):
distributions_py.Mixture(cat, [None])
# components tuples must be size 2
with self.assertRaisesWithPredicateMatch(TypeError, "tuples of the form"):
distributions_py.Mixture(cat, [tuple()])
# components tuples must be size 2
with self.assertRaisesWithPredicateMatch(TypeError, "tuples of the form"):
distributions_py.Mixture(cat, [(None)])
# components tuples must be of the form (callable, dict)
with self.assertRaisesWithPredicateMatch(TypeError, "tuples of the form"):
distributions_py.Mixture(cat, [(None, None)])
# components tuples must be size 2
with self.assertRaisesWithPredicateMatch(TypeError, "tuples of the form"):
distributions_py.Mixture(cat, [(None, None, None)])
# components tuples must be of the form (callable, dict)
with self.assertRaisesWithPredicateMatch(TypeError, "tuples of the form"):
distributions_py.Mixture(cat, [(lambda x: x, None)])
# components tuples must be of the form (callable, dict)
with self.assertRaisesWithPredicateMatch(TypeError, "tuples of the form"):
distributions_py.Mixture(cat, [(None, {})])
with self.assertRaisesWithPredicateMatch(TypeError, "same dtype"):
distributions_py.Mixture(
cat,
[(distributions_py.Normal, {"mu": [1.0], "sigma": [2.0]}),
(distributions_py.Normal, {"mu": [np.float16(1.0)],
"sigma": [np.float16(2.0)]})])
[distributions_py.Normal(mu=[1.0], sigma=[2.0]),
distributions_py.Normal(mu=[np.float16(1.0)],
sigma=[np.float16(2.0)])])
with self.assertRaisesWithPredicateMatch(ValueError, "non-empty list"):
distributions_py.Mixture(distributions_py.Categorical([0.3, 0.2]), None)
with self.assertRaisesWithPredicateMatch(TypeError,
"either be continuous or not"):
distributions_py.Mixture(
cat,
[(distributions_py.Normal, {"mu": [1.0], "sigma": [2.0]}),
(functools.partial(distributions_py.Bernoulli, dtype=tf.float32),
{"logits": [1.0]})])
[distributions_py.Normal(mu=[1.0], sigma=[2.0]),
distributions_py.Bernoulli(dtype=tf.float32, logits=[1.0])])
def testMeanUnivariate(self):
with self.test_session() as sess:
@ -196,7 +177,7 @@ class MixtureTest(tf.test.TestCase):
self.assertEqual(batch_shape, mean.get_shape())
cat_probs = tf.nn.softmax(dist.cat.logits)
dist_means = [d.mean() for d in dist.distributions]
dist_means = [d.mean() for d in dist.components]
mean_value, cat_probs_value, dist_means_value = sess.run(
[mean, cat_probs, dist_means])
@ -217,7 +198,7 @@ class MixtureTest(tf.test.TestCase):
self.assertEqual(batch_shape + (4,), mean.get_shape())
cat_probs = tf.nn.softmax(dist.cat.logits)
dist_means = [d.mean() for d in dist.distributions]
dist_means = [d.mean() for d in dist.components]
mean_value, cat_probs_value, dist_means_value = sess.run(
[mean, cat_probs, dist_means])
@ -243,7 +224,7 @@ class MixtureTest(tf.test.TestCase):
self.assertEqual(x.shape, p_x.get_shape())
cat_probs = tf.nn.softmax([dist.cat.logits])[0]
dist_probs = [d.prob(x) for d in dist.distributions]
dist_probs = [d.prob(x) for d in dist.components]
p_x_value, cat_probs_value, dist_probs_value = sess.run(
[p_x, cat_probs, dist_probs])
@ -269,7 +250,7 @@ class MixtureTest(tf.test.TestCase):
self.assertEqual(x.shape[:-1], p_x.get_shape())
cat_probs = tf.nn.softmax([dist.cat.logits])[0]
dist_probs = [d.prob(x) for d in dist.distributions]
dist_probs = [d.prob(x) for d in dist.components]
p_x_value, cat_probs_value, dist_probs_value = sess.run(
[p_x, cat_probs, dist_probs])
@ -292,7 +273,7 @@ class MixtureTest(tf.test.TestCase):
self.assertEqual(x.shape, p_x.get_shape())
cat_probs = tf.nn.softmax(dist.cat.logits)
dist_probs = [d.prob(x) for d in dist.distributions]
dist_probs = [d.prob(x) for d in dist.components]
p_x_value, cat_probs_value, dist_probs_value = sess.run(
[p_x, cat_probs, dist_probs])
@ -318,7 +299,7 @@ class MixtureTest(tf.test.TestCase):
self.assertEqual(x.shape[:-1], p_x.get_shape())
cat_probs = tf.nn.softmax(dist.cat.logits)
dist_probs = [d.prob(x) for d in dist.distributions]
dist_probs = [d.prob(x) for d in dist.components]
p_x_value, cat_probs_value, dist_probs_value = sess.run(
[p_x, cat_probs, dist_probs])
@ -430,7 +411,7 @@ class MixtureTest(tf.test.TestCase):
self.assertEqual(batch_shape, entropy_lower_bound.get_shape())
cat_probs = tf.nn.softmax(dist.cat.logits)
dist_entropy = [d.entropy() for d in dist.distributions]
dist_entropy = [d.entropy() for d in dist.components]
entropy_lower_bound_value, cat_probs_value, dist_entropy_value = (
sess.run([entropy_lower_bound, cat_probs, dist_entropy]))
@ -486,8 +467,7 @@ class MixtureBenchmark(tf.test.Benchmark):
tf.Variable(np.random.rand(batch_size, num_features))
for _ in range(num_components)]
components = list(
(distributions_py.MultivariateNormalDiag,
{"mu": mu, "diag_stdev": sigma})
distributions_py.MultivariateNormalDiag(mu=mu, diag_stdev=sigma)
for (mu, sigma) in zip(mus, sigmas))
return distributions_py.Mixture(cat, components)
@ -524,8 +504,7 @@ class MixtureBenchmark(tf.test.Benchmark):
psd(np.random.rand(batch_size, num_features, num_features)))
for _ in range(num_components)]
components = list(
(distributions_py.MultivariateNormalFull,
{"mu": mu, "sigma": sigma})
distributions_py.MultivariateNormalFull(mu=mu, sigma=sigma)
for (mu, sigma) in zip(mus, sigmas))
return distributions_py.Mixture(cat, components)

View File

@ -33,7 +33,7 @@ class QuantizedDistributionTest(tf.test.TestCase):
self.assertTrue(np.isfinite(array).all())
def test_quantization_of_uniform_with_cutoffs_having_no_effect(self):
with self.test_session():
with self.test_session() as sess:
# The Quantized uniform with cutoffs == None divides the real line into:
# R = ...(-1, 0](0, 1](1, 2](2, 3](3, 4]...
# j = ... 0 1 2 3 4 ...
@ -60,34 +60,38 @@ class QuantizedDistributionTest(tf.test.TestCase):
b=3.0)
# pmf
pmf_n1, pmf_0, pmf_1, pmf_2, pmf_3, pmf_4, pmf_5 = sess.run(
qdist.pmf([-1., 0., 1., 2., 3., 4., 5.]))
# uniform had no mass below -1.
self.assertAllClose(0., qdist.pmf(-1.).eval())
self.assertAllClose(0., pmf_n1)
# uniform had no mass below 0.
self.assertAllClose(0., qdist.pmf(0.).eval())
self.assertAllClose(0., pmf_0)
# uniform put 1/3 of its mass in each of (0, 1], (1, 2], (2, 3],
# which are the intervals j = 1, 2, 3.
self.assertAllClose(1 / 3, qdist.pmf(1.).eval())
self.assertAllClose(1 / 3, qdist.pmf(2.).eval())
self.assertAllClose(1 / 3, qdist.pmf(3.).eval())
self.assertAllClose(1 / 3, pmf_1)
self.assertAllClose(1 / 3, pmf_2)
self.assertAllClose(1 / 3, pmf_3)
# uniform had no mass in (3, 4] or (4, 5], which are j = 4, 5.
self.assertAllClose(0 / 3, qdist.pmf(4.).eval())
self.assertAllClose(0 / 3, qdist.pmf(5.).eval())
self.assertAllClose(0 / 3, pmf_4)
self.assertAllClose(0 / 3, pmf_5)
# cdf
self.assertAllClose(0., qdist.cdf(-1.).eval())
self.assertAllClose(0., qdist.cdf(0.).eval())
self.assertAllClose(1 / 3, qdist.cdf(1.).eval())
self.assertAllClose(2 / 3, qdist.cdf(2.).eval())
cdf_n1, cdf_0, cdf_1, cdf_2, cdf_2p5, cdf_3, cdf_4, cdf_5 = sess.run(
qdist.cdf([-1., 0., 1., 2., 2.5, 3., 4., 5.]))
self.assertAllClose(0., cdf_n1)
self.assertAllClose(0., cdf_0)
self.assertAllClose(1 / 3, cdf_1)
self.assertAllClose(2 / 3, cdf_2)
# Note fractional values allowed for cdfs of discrete distributions.
# And adding 0.5 makes no difference because the quantized dist has
# mass only on the integers, never in between.
self.assertAllClose(2 / 3, qdist.cdf(2.5).eval())
self.assertAllClose(3 / 3, qdist.cdf(3.).eval())
self.assertAllClose(3 / 3, qdist.cdf(4.).eval())
self.assertAllClose(3 / 3, qdist.cdf(5.).eval())
self.assertAllClose(2 / 3, cdf_2p5)
self.assertAllClose(3 / 3, cdf_3)
self.assertAllClose(3 / 3, cdf_4)
self.assertAllClose(3 / 3, cdf_5)
def test_quantization_of_uniform_with_cutoffs_in_the_middle(self):
with self.test_session():
with self.test_session() as sess:
# The uniform is supported on [-3, 3]
# Consider partitions the real line in intervals
# ...(-3, -2](-2, -1](-1, 0](0, 1](1, 2](2, 3] ...
@ -103,25 +107,27 @@ class QuantizedDistributionTest(tf.test.TestCase):
b=3.0)
# cdf
cdf_n3, cdf_n2, cdf_n1, cdf_0, cdf_0p5, cdf_1, cdf_10 = sess.run(
qdist.cdf([-3., -2., -1., 0., 0.5, 1.0, 10.0]))
# Uniform had no mass on (-4, -3] or (-3, -2]
self.assertAllClose(0., qdist.cdf(-3.).eval())
self.assertAllClose(0., qdist.cdf(-2.).eval())
self.assertAllClose(0., cdf_n3)
self.assertAllClose(0., cdf_n2)
# Uniform had 1/6 of its mass in each of (-3, -2], and (-2, -1], which
# were collapsed into (-infty, -1], which is now the "-1" interval.
self.assertAllClose(1 / 3, qdist.cdf(-1.).eval())
self.assertAllClose(1 / 3, cdf_n1)
# The j=0 interval contained mass from (-3, 0], which is 1/2 of the
# uniform's mass.
self.assertAllClose(1 / 2, qdist.cdf(0.).eval())
self.assertAllClose(1 / 2, cdf_0)
# Adding 0.5 makes no difference because the quantized dist has mass on
# the integers, not in between them.
self.assertAllClose(1 / 2, qdist.cdf(0.5).eval())
self.assertAllClose(1 / 2, cdf_0p5)
# After applying the cutoff, all mass was either in the interval
# (0, infty), or below. (0, infty) is the interval indexed by j=1,
# so pmf(1) should equal 1.
self.assertAllClose(1., qdist.cdf(1.0).eval())
self.assertAllClose(1., cdf_1)
# Since no mass of qdist is above 1,
# pmf(10) = P[Y <= 10] = P[Y <= 1] = pmf(1).
self.assertAllClose(1., qdist.cdf(10.0).eval())
self.assertAllClose(1., cdf_10)
def test_quantization_of_batch_of_uniforms(self):
batch_shape = (5, 5)
@ -231,10 +237,12 @@ class QuantizedDistributionTest(tf.test.TestCase):
# The smallest value the samples can take on is 1, which corresponds to
# the interval (0, 1]. Recall we use ceiling in the sampling definition.
self.assertLess(0.5, samps.min())
for x in range(1, 10):
x_vals = np.arange(1, 11).astype(np.float32)
pmf_vals = qdist.pmf(x_vals).eval()
for ii in range(10):
self.assertAllClose(
qdist.pmf(float(x)).eval(),
(samps == x).mean(),
pmf_vals[ii],
(samps == x_vals[ii]).mean(),
atol=std_err_bound)
def test_normal_cdf_and_survival_function(self):

View File

@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import distribution
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@ -32,8 +31,8 @@ def kl(dist_a, dist_b, allow_nan=False, name=None):
"""Get the KL-divergence KL(dist_a || dist_b).
Args:
dist_a: instance of distributions.Distribution.
dist_b: instance of distributions.Distribution.
dist_a: The first distribution.
dist_b: The second distribution.
allow_nan: If `False` (default), a runtime error is raised
if the KL returns NaN values for any batch entry of the given
distributions. If `True`, the KL may return a NaN for the given entry.
@ -43,18 +42,9 @@ def kl(dist_a, dist_b, allow_nan=False, name=None):
A Tensor with the batchwise KL-divergence between dist_a and dist_b.
Raises:
TypeError: If dist_a or dist_b is not an instance of Distribution.
NotImplementedError: If no KL method is defined for distribution types
of dist_a and dist_b.
"""
if not isinstance(dist_a, distribution.Distribution):
raise TypeError(
"dist_a is not an instance of Distribution, received type: %s"
% type(dist_a))
if not isinstance(dist_b, distribution.Distribution):
raise TypeError(
"dist_b is not an instance of Distribution, received type: %s"
% type(dist_b))
kl_fn = _DIVERGENCES.get((type(dist_a), type(dist_b)), None)
if kl_fn is None:
raise NotImplementedError(
@ -94,16 +84,7 @@ class RegisterKL(object):
Args:
dist_cls_a: the class of the first argument of the KL divergence.
dist_cls_b: the class of the second argument of the KL divergence.
Raises:
TypeError: if dist_cls_a or dist_cls_b are not subclasses of
Distribution.
"""
if not issubclass(dist_cls_a, distribution.Distribution):
raise TypeError("%s is not a subclass of Distribution" % dist_cls_a)
if not issubclass(dist_cls_b, distribution.Distribution):
raise TypeError("%s is not a subclass of Distribution" % dist_cls_b)
self._key = (dist_cls_a, dist_cls_b)
def __call__(self, kl_fn):

View File

@ -56,43 +56,15 @@ class Mixture(distribution.Distribution):
all having matching dtype, batch shape, event shape, and continuity
properties (the components).
The user does not pass the list of distributions directly, but rather a
list of `(constructor, batch_tensor_params_dict)` pairs,
called `components`. The list of distributions is created via:
```python
distributions = [
c(**params_dict) for (c, params_dict) in zip(*components)
]
```
This form allows for certain types of batch-shape optimizations within
this class.
An example of `components`:
```python
components = [
(tf.contrib.distributions.Normal, {"mu": 3.0, "sigma": 1.0}),
(functools.partial(tf.contrib.distributions.Normal, validate_args=False),
{"mu": 3.0, "sigma": 2.0}),
(tf.contrib.distributions.Normal.from_params,
{"mu": 1.0, "sigma": -1.0})
]
```
The `num_classes` of `cat` must be possible to infer at graph construction
time and match `len(distributions)`.
time and match `len(components)`.
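An example of `components` (the parameter values here are illustrative):
```python
components = [
tf.contrib.distributions.Normal(mu=3.0, sigma=1.0),
tf.contrib.distributions.Normal(mu=3.0, sigma=2.0),
]
```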
Args:
cat: A `Categorical` distribution instance, representing the probabilities
of `distributions`.
components: A list or tuple of `(constructor, batch_tensor_params)`
tuples. The `constructor` must be a callable, and `batch_tensor_params`
must be a dict mapping constructor kwargs to batchwise parameters.
Each `Distribution` instance created by calling
`constructor(**batch_tensor_params)` must have the same type, be defined
on the same domain, and have matching `event_shape` and `batch_shape`.
components: A list or tuple of `Distribution` instances.
Each instance must have the same type, be defined on the same domain,
and have matching `event_shape` and `batch_shape`.
validate_args: `Boolean`, default `False`. If `True`, raise a runtime
error if batch or event ranks are inconsistent between cat and any of
the distributions. This is only checked if the ranks cannot be
@ -106,16 +78,13 @@ class Mixture(distribution.Distribution):
Raises:
TypeError: If cat is not a `Categorical`, or `components` is not
a list or tuple, or the elements of `components` are not
tuples of the form `(callable, dict)`, or the objects resulting
from calling `callable(**dict)` are not instances of `Distribution`, or
the resulting instances of `Distribution` do not have matching
continuity properties, or do not have matching `dtype`.
ValueError: If `components` is an empty list or tuple, or the
distributions created from `components` do not have a statically known event
rank. If `cat.num_classes` cannot be inferred at graph creation time,
instances of `Distribution`, or do not have matching `dtype`.
ValueError: If `components` is an empty list or tuple, or its
elements do not have a statically known event rank.
If `cat.num_classes` cannot be inferred at graph creation time,
or the constant value of `cat.num_classes` is not equal to
`len(distributions)`, or all `distributions` and `cat` do not have
matching static batch shapes, or all components' distributions do not
`len(components)`, or all `components` and `cat` do not have
matching static batch shapes, or all components do not
have matching static event shapes.
"""
if not isinstance(cat, categorical.Categorical):
@ -126,52 +95,29 @@ class Mixture(distribution.Distribution):
if not isinstance(components, (list, tuple)):
raise TypeError("components must be a list or tuple, but saw: %s" %
components)
if not all(isinstance(c, tuple) and len(c) == 2 and
callable(c[0]) and isinstance(c[1], dict)
for c in components):
if not all(isinstance(c, distribution.Distribution) for c in components):
raise TypeError(
"all entries in components must be tuples of the form "
"(make, params), where make is callable and params is a dict,"
"all entries in components must be Distribution instances"
" but saw: %s" % components)
def _make_tensors(d):
return dict((k, ops.convert_to_tensor(v, name="tensor_%s" % k))
for (k, v) in d.items())
with ops.name_scope(name, values=[cat.logits]):
components_tensor_params = list((make, _make_tensors(batch_params))
for (make, batch_params) in components)
distributions = [make(**batch_params)
for (make, batch_params) in components_tensor_params]
# Store components internally with their batch params having been
# converted to tensors.
# TODO(ebrevdo): Use self._components to optimize sampling.
self._components = components_tensor_params
if not all(isinstance(d, distribution.Distribution) for d in distributions):
dtype = components[0].dtype
if not all(d.dtype == dtype for d in components):
raise TypeError("All components must have the same dtype, but saw "
"dtypes: %s" % [(d.name, d.dtype) for d in components])
is_continuous = components[0].is_continuous
if not all(d.is_continuous == is_continuous for d in components):
raise TypeError(
"all entries in distributions must be instances of Distribution, "
"but saw: %s" % distributions)
dtype = distributions[0].dtype
if not all(d.dtype == dtype for d in distributions):
raise TypeError("All distributions must have the same dtype, but saw "
"dtypes: %s" % [(d.name, d.dtype) for d in distributions])
is_continuous = distributions[0].is_continuous
if not all(d.is_continuous == is_continuous for d in distributions):
raise TypeError(
"All distributions must either be continuous or not, but continuity "
"values are: %s" % [(d.name, d.is_continuous) for d in distributions])
static_event_shape = distributions[0].get_event_shape()
"All components must either be continuous or not, but continuity "
"values are: %s" % [(d.name, d.is_continuous) for d in components])
static_event_shape = components[0].get_event_shape()
static_batch_shape = cat.get_batch_shape()
for d in distributions:
for d in components:
static_event_shape = static_event_shape.merge_with(d.get_event_shape())
static_batch_shape = static_batch_shape.merge_with(d.get_batch_shape())
if static_event_shape.ndims is None:
raise ValueError(
"Expected to know rank(event_shape) from distributions, but "
"none of the distributions provide a static number of ndims")
"Expected to know rank(event_shape) from components, but "
"none of the components provide a static number of ndims")
# Ensure that all batch and event ndims are consistent.
with ops.name_scope(name, values=[cat.logits]):
@ -180,42 +126,42 @@ class Mixture(distribution.Distribution):
if static_num_components is None:
raise ValueError(
"Could not infer number of classes from cat and unable "
"to compare this value to the number of distributions passed in.")
"to compare this value to the number of components passed in.")
# Possibly convert from numpy 0-D array.
static_num_components = int(static_num_components)
if static_num_components != len(distributions):
raise ValueError("cat.num_classes != len(distributions): %d vs. %d" %
(static_num_components, len(distributions)))
if static_num_components != len(components):
raise ValueError("cat.num_classes != len(components): %d vs. %d" %
(static_num_components, len(components)))
cat_batch_shape = cat.batch_shape()
cat_batch_rank = array_ops.size(cat_batch_shape)
if validate_args:
batch_shapes = [d.batch_shape() for d in distributions]
batch_shapes = [d.batch_shape() for d in components]
batch_ranks = [array_ops.size(bs) for bs in batch_shapes]
check_message = ("distributions[%d] batch shape must match cat "
check_message = ("components[%d] batch shape must match cat "
"batch shape")
self._assertions = [
check_ops.assert_equal(
cat_batch_rank, batch_ranks[di], message=check_message % di)
for di in range(len(distributions))
for di in range(len(components))
]
self._assertions += [
check_ops.assert_equal(
cat_batch_shape, batch_shapes[di], message=check_message % di)
for di in range(len(distributions))
for di in range(len(components))
]
else:
self._assertions = []
self._cat = cat
self._distributions = list(distributions)
self._components = list(components)
self._num_components = static_num_components
self._static_event_shape = static_event_shape
self._static_batch_shape = static_batch_shape
super(Mixture, self).__init__(
dtype=dtype,
parameters={"cat": self._cat, "distributions": self._distributions,
parameters={"cat": self._cat, "components": self._components,
"num_components": self._num_components},
is_reparameterized=False,
is_continuous=is_continuous,
@ -228,8 +174,8 @@ class Mixture(distribution.Distribution):
return self._cat
@property
def distributions(self):
return self._distributions
def components(self):
return self._components
@property
def num_components(self):
@ -242,14 +188,14 @@ class Mixture(distribution.Distribution):
return self._static_batch_shape
def _event_shape(self):
return self._distributions[0].event_shape()
return self._components[0].event_shape()
def _get_event_shape(self):
return self._static_event_shape
def _mean(self):
with ops.control_dependencies(self._assertions):
distribution_means = [d.mean() for d in self.distributions]
distribution_means = [d.mean() for d in self.components]
cat_probs = self._cat_probs(log_probs=False)
# This was checked to not be None at construction time.
static_event_rank = self.get_event_shape().ndims
@ -271,7 +217,7 @@ class Mixture(distribution.Distribution):
def _log_prob(self, x):
with ops.control_dependencies(self._assertions):
x = ops.convert_to_tensor(x, name="x")
distribution_log_probs = [d.log_prob(x) for d in self.distributions]
distribution_log_probs = [d.log_prob(x) for d in self.components]
cat_log_probs = self._cat_probs(log_probs=True)
final_log_probs = [
cat_lp + d_lp
@ -351,7 +297,7 @@ class Mixture(distribution.Distribution):
samples_class = [None for _ in range(self.num_components)]
for c in range(self.num_components):
n_class = array_ops.size(partitioned_samples_indices[c])
samples_class_c = self.distributions[c].sample_n(n_class, seed=seed)
samples_class_c = self.components[c].sample_n(n_class, seed=seed)
# Pull out the correct batch entries from each index.
# To do this, we may have to flatten the batch shape.
@ -395,7 +341,7 @@ class Mixture(distribution.Distribution):
r"""A lower bound on the entropy of this mixture model.
The bound below is not always very tight, and its usefulness depends
on the mixture probabilities and the distributions in use.
on the mixture probabilities and the components in use.
A lower bound is useful for ELBO when the `Mixture` is the variational
distribution:
@ -432,7 +378,7 @@ class Mixture(distribution.Distribution):
"""
with self._name_scope(name, values=[self.cat.logits]):
with ops.control_dependencies(self._assertions):
distribution_entropies = [d.entropy() for d in self.distributions]
distribution_entropies = [d.entropy() for d in self.components]
cat_probs = self._cat_probs(log_probs=False)
partial_entropies = [
c_p * m for (c_p, m) in zip(cat_probs, distribution_entropies)

View File

@ -35,7 +35,7 @@ namespace {
// The complete set of audio file formats that are supported by the op. These
// strings are defined by FFmpeg and documented here:
// https://www.ffmpeg.org/ffmpeg-formats.html
const char* kValidFileFormats[] = {"mp3", "ogg", "wav"};
const char* kValidFileFormats[] = {"mp3", "mp4", "ogg", "wav"};
// Writes binary data to a file.
Status WriteFile(const string& filename, tensorflow::StringPiece contents) {

View File

@ -61,10 +61,30 @@ class DecodeAudioOpTest(tf.test.TestCase):
self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000, 1)
self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000, 2)
def testMonoMp4Mp3Codec(self):
# mp3 compressed audio streams in mp4 container.
self._loadFileAndTest('mono_16khz_mp3.mp4', 'mp4', 2.77, 20000, 1)
self._loadFileAndTest('mono_16khz_mp3.mp4', 'mp4', 2.77, 20000, 2)
def testMonoMp4AacCodec(self):
# aac compressed audio streams in mp4 container.
self._loadFileAndTest('mono_32khz_aac.mp4', 'mp4', 2.77, 20000, 1)
self._loadFileAndTest('mono_32khz_aac.mp4', 'mp4', 2.77, 20000, 2)
def testStereoMp3(self):
self._loadFileAndTest('stereo_48khz.mp3', 'mp3', 0.79, 50000, 1)
self._loadFileAndTest('stereo_48khz.mp3', 'mp3', 0.79, 20000, 2)
def testStereoMp4Mp3Codec(self):
# mp3 compressed audio streams in mp4 container.
self._loadFileAndTest('stereo_48khz_mp3.mp4', 'mp4', 0.79, 50000, 1)
self._loadFileAndTest('stereo_48khz_mp3.mp4', 'mp4', 0.79, 20000, 2)
def testStereoMp4AacCodec(self):
# aac compressed audio streams in mp4 container.
self._loadFileAndTest('stereo_48khz_aac.mp4', 'mp4', 0.79, 50000, 1)
self._loadFileAndTest('stereo_48khz_aac.mp4', 'mp4', 0.79, 20000, 2)
def testMonoWav(self):
self._loadFileAndTest('mono_10khz.wav', 'wav', 0.57, 5000, 1)
self._loadFileAndTest('mono_10khz.wav', 'wav', 0.57, 10000, 4)

View File

@ -35,11 +35,14 @@ def decode_audio(contents, file_format=None, samples_per_second=None,
channel_count=None):
"""Create an op that decodes the contents of an audio file.
Note that ffmpeg is free to select the "best" audio track from an mp4.
https://trac.ffmpeg.org/wiki/Map
Args:
contents: The binary contents of the audio file to decode. This is a
scalar.
file_format: A string specifying which format the contents will conform
to. This can be mp3, ogg, or wav.
to. This can be mp3, mp4, ogg, or wav.
samples_per_second: The number of samples per second that is assumed.
In some cases, resampling will occur to generate the correct sample
rate.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -125,7 +125,8 @@ class ClassifierTest(tf.test.TestCase):
default_signature = signatures.default_signature
return default_signature
def testExportMonitorRegressionSignature(self):
# Disable this test case until b/31032996 is fixed.
def _testExportMonitorRegressionSignature(self):
iris = tf.contrib.learn.datasets.load_iris()
est = tf.contrib.learn.Classifier(model_fn=logistic_model_fn, n_classes=3)
export_dir = tempfile.mkdtemp() + 'export/'

View File

@ -19,11 +19,269 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tempfile
from tensorflow.contrib import layers
from tensorflow.contrib import metrics as metrics_lib
from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework import deprecated_arg_values
from tensorflow.contrib.framework import list_variables
from tensorflow.contrib.framework import load_variable
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.layers.python.layers import optimizers
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import metric_spec
from tensorflow.contrib.learn.python.learn import session_run_hook
from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators import dnn_linear_combined
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.utils import checkpoints
from tensorflow.contrib.learn.python.learn.utils import export
from tensorflow.contrib.losses.python.losses import loss_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import training as train
class DNNClassifier(dnn_linear_combined.DNNLinearCombinedClassifier):
_CENTERED_BIAS = "centered_bias"
_CENTERED_BIAS_WEIGHT = "centered_bias_weight"
_CLASSES = "classes"
_LOGISTIC = "logistic"
_PROBABILITIES = "probabilities"
# The default learning rate of 0.05 is a historical artifact of the initial
# implementation, but seems a reasonable choice.
_LEARNING_RATE = 0.05
def _as_iterable(preds, output):
for pred in preds:
yield pred[output]
def _get_feature_dict(features):
if isinstance(features, dict):
return features
return {"": features}
def _get_optimizer(optimizer):
if callable(optimizer):
return optimizer()
else:
return optimizer
def _add_hidden_layer_summary(value, tag):
logging_ops.scalar_summary("%s:fraction_of_zero_values" % tag,
nn.zero_fraction(value))
logging_ops.histogram_summary("%s:activation" % tag, value)
def _centered_bias(num_label_columns):
centered_bias = variables.Variable(
array_ops.zeros([num_label_columns]),
collections=[_CENTERED_BIAS, ops.GraphKeys.VARIABLES],
name=_CENTERED_BIAS_WEIGHT)
logging_ops.scalar_summary(
["centered_bias %d" % cb for cb in range(num_label_columns)],
array_ops.reshape(centered_bias, [-1]))
return centered_bias
def _centered_bias_step(targets, loss_fn, num_label_columns):
centered_bias = ops.get_collection(_CENTERED_BIAS)
batch_size = array_ops.shape(targets)[0]
logits = array_ops.reshape(
array_ops.tile(centered_bias[0], [batch_size]),
[batch_size, num_label_columns])
loss = loss_fn(logits, targets)
return train.AdagradOptimizer(0.1).minimize(loss, var_list=centered_bias)
def _get_weight_tensor(features, weight_column_name):
"""Returns the weight tensor of shape [batch_size] or 1."""
if weight_column_name is None:
return 1.0
else:
return array_ops.reshape(
math_ops.to_float(features[weight_column_name]),
shape=(-1,))
def _rescale_eval_loss(loss, weights):
"""Rescales evaluation loss according to the given weights.
The rescaling is needed because in the training loss weights are not
considered in the denominator, whereas for the evaluation loss we should
divide by the sum of weights.
The rescaling factor is:
R = sum_{i} 1 / sum_{i} w_{i}
Args:
loss: the scalar weighted loss.
weights: weight coefficients. Either a scalar, or a `Tensor` of shape
[batch_size].
Returns:
The given loss multiplied by the rescaling factor.
"""
rescaling_factor = math_ops.reduce_mean(weights)
return math_ops.div(loss, rescaling_factor)
def _predictions(logits, n_classes):
"""Returns predictions for the given logits and n_classes."""
predictions = {}
if n_classes == 2:
predictions[_LOGISTIC] = math_ops.sigmoid(logits)
logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
predictions[_PROBABILITIES] = nn.softmax(logits)
predictions[_CLASSES] = array_ops.reshape(
math_ops.argmax(logits, 1), shape=(-1, 1))
return predictions
def _dnn_classifier_model_fn(features, targets, mode, params):
"""Deep Neural Net model_fn.
Args:
features: `Tensor` or dict of `Tensor` (depends on data passed to `fit`).
targets: `Tensor` of shape [batch_size, 1] or [batch_size] target labels of
dtype `int32` or `int64` in the range `[0, n_classes)`.
mode: Defines whether this is training, evaluation or prediction.
See `ModeKeys`.
params: A dict of hyperparameters.
The following hyperparameters are expected:
* hidden_units: List of hidden units per layer.
* feature_columns: An iterable containing all the feature columns used by
the model.
* n_classes: number of target classes.
* weight_column_name: A string defining the weight feature column, or
None if there are no weights.
* optimizer: string, `Optimizer` object, or callable that defines the
optimizer to use for training.
* activation_fn: Activation function applied to each layer. If `None`,
will use `tf.nn.relu`.
* dropout: When not `None`, the probability we will drop out a given
coordinate.
* gradient_clip_norm: A float > 0. If provided, gradients are
clipped to their global norm with this clipping ratio.
* enable_centered_bias: A bool. If True, estimator will learn a centered
bias variable for each class. Rest of the model structure learns the
residual after centered bias.
* num_ps_replicas: The number of parameter server replicas.
Returns:
predictions: A dict of `Tensor` objects.
loss: A scalar containing the loss of the step.
train_op: The op for training.
"""
hidden_units = params["hidden_units"]
feature_columns = params["feature_columns"]
n_classes = params["n_classes"]
weight_column_name = params["weight_column_name"]
optimizer = params["optimizer"]
activation_fn = params["activation_fn"]
dropout = params["dropout"]
gradient_clip_norm = params["gradient_clip_norm"]
enable_centered_bias = params["enable_centered_bias"]
num_ps_replicas = params["num_ps_replicas"]
features = _get_feature_dict(features)
parent_scope = "dnn"
num_label_columns = 1 if n_classes == 2 else n_classes
if n_classes == 2:
loss_fn = loss_ops.sigmoid_cross_entropy
else:
loss_fn = loss_ops.sparse_softmax_cross_entropy
input_layer_partitioner = (
partitioned_variables.min_max_variable_partitioner(
max_partitions=num_ps_replicas,
min_slice_size=64 << 20))
with variable_scope.variable_scope(
parent_scope + "/input_from_feature_columns",
values=features.values(),
partitioner=input_layer_partitioner) as scope:
net = layers.input_from_feature_columns(
columns_to_tensors=features,
feature_columns=feature_columns,
weight_collections=[parent_scope],
scope=scope)
hidden_layer_partitioner = (
partitioned_variables.min_max_variable_partitioner(
max_partitions=num_ps_replicas))
for layer_id, num_hidden_units in enumerate(hidden_units):
with variable_scope.variable_scope(
parent_scope + "/hiddenlayer_%d" % layer_id,
values=[net],
partitioner=hidden_layer_partitioner) as scope:
net = layers.fully_connected(
net,
num_hidden_units,
activation_fn=activation_fn,
variables_collections=[parent_scope],
scope=scope)
if dropout is not None and mode == estimator.ModeKeys.TRAIN:
net = layers.dropout(
net,
keep_prob=(1.0 - dropout))
_add_hidden_layer_summary(net, scope.name)
with variable_scope.variable_scope(
parent_scope + "/logits",
values=[net],
partitioner=hidden_layer_partitioner) as scope:
logits = layers.fully_connected(
net,
num_label_columns,
activation_fn=None,
variables_collections=[parent_scope],
scope=scope)
_add_hidden_layer_summary(logits, scope.name)
if enable_centered_bias:
logits = nn.bias_add(logits, _centered_bias(num_label_columns))
if mode == estimator.ModeKeys.TRAIN:
loss = loss_fn(logits, targets,
weight=_get_weight_tensor(features, weight_column_name))
train_ops = [optimizers.optimize_loss(
loss=loss, global_step=contrib_variables.get_global_step(),
learning_rate=_LEARNING_RATE, optimizer=_get_optimizer(optimizer),
clip_gradients=gradient_clip_norm, name=parent_scope)]
if enable_centered_bias:
train_ops.append(_centered_bias_step(targets, loss_fn, num_label_columns))
return None, loss, control_flow_ops.group(*train_ops)
elif mode == estimator.ModeKeys.EVAL:
predictions = _predictions(logits=logits, n_classes=n_classes)
weight = _get_weight_tensor(features, weight_column_name)
training_loss = loss_fn(logits, targets, weight=weight)
loss = _rescale_eval_loss(training_loss, weight)
return predictions, loss, []
else: # mode == estimator.ModeKeys.INFER:
predictions = _predictions(logits=logits, n_classes=n_classes)
return predictions, None, []
class DNNClassifier(evaluable.Evaluable, trainable.Trainable):
"""A classifier for TensorFlow DNN models.
Example:
@ -124,36 +382,211 @@ class DNNClassifier(dnn_linear_combined.DNNLinearCombinedClassifier):
Returns:
A `DNNClassifier` estimator.
Raises:
ValueError: If `n_classes` < 2.
"""
if enable_centered_bias is None:
enable_centered_bias = True
dnn_linear_combined._changing_default_center_bias() # pylint: disable=protected-access
super(DNNClassifier, self).__init__(
model_dir=model_dir,
n_classes=n_classes,
weight_column_name=weight_column_name,
dnn_feature_columns=feature_columns,
dnn_optimizer=optimizer,
dnn_hidden_units=hidden_units,
dnn_activation_fn=activation_fn,
dnn_dropout=dropout,
gradient_clip_norm=gradient_clip_norm,
enable_centered_bias=enable_centered_bias,
config=config)
self.feature_columns = feature_columns
self.optimizer = optimizer
self.activation_fn = activation_fn
self.dropout = dropout
self.hidden_units = hidden_units
self._feature_columns_inferred = False
self._hidden_units = hidden_units
self._feature_columns = feature_columns
self._model_dir = model_dir or tempfile.mkdtemp()
if n_classes <= 1:
raise ValueError(
"Classification requires n_classes >= 2. Given: {}".format(n_classes))
self._n_classes = n_classes
self._weight_column_name = weight_column_name
optimizer = optimizer or "Adagrad"
num_ps_replicas = config.num_ps_replicas if config else 0
self._estimator = estimator.Estimator(
model_fn=_dnn_classifier_model_fn,
model_dir=self._model_dir,
config=config,
params={
"hidden_units": hidden_units,
"feature_columns": feature_columns,
"n_classes": n_classes,
"weight_column_name": weight_column_name,
"optimizer": optimizer,
"activation_fn": activation_fn,
"dropout": dropout,
"gradient_clip_norm": gradient_clip_norm,
"enable_centered_bias": enable_centered_bias,
"num_ps_replicas": num_ps_replicas,
})
def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
monitors=None, max_steps=None):
"""See trainable.Trainable."""
# TODO(roumposg): Remove when deprecated monitors are removed.
if monitors is not None:
deprecated_monitors = [
m for m in monitors
if not isinstance(m, session_run_hook.SessionRunHook)
]
for monitor in deprecated_monitors:
monitor.set_estimator(self)
monitor._lock_estimator() # pylint: disable=protected-access
result = self._estimator.fit(x=x, y=y, input_fn=input_fn, steps=steps,
batch_size=batch_size, monitors=monitors,
max_steps=max_steps)
if monitors is not None:
for monitor in deprecated_monitors:
monitor._unlock_estimator() # pylint: disable=protected-access
return result
def evaluate(self, x=None, y=None, input_fn=None, feed_fn=None,
batch_size=None, steps=None, metrics=None, name=None):
"""See evaluable.Evaluable."""
if metrics is None:
metrics = {}
metrics.update({
"accuracy": metric_spec.MetricSpec(
metric_fn=metrics_lib.streaming_accuracy,
prediction_key=_CLASSES,
weight_key=self._weight_column_name)})
if self._n_classes == 2:
metrics.update({
"auc": metric_spec.MetricSpec(
metric_fn=metrics_lib.streaming_auc,
prediction_key=_LOGISTIC,
weight_key=self._weight_column_name)})
return self._estimator.evaluate(
x=x, y=y, input_fn=input_fn, feed_fn=feed_fn, batch_size=batch_size,
steps=steps, metrics=metrics, name=name)
@deprecated_arg_values(
estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
as_iterable=False)
def predict(self, x=None, input_fn=None, batch_size=None, as_iterable=False):
"""Returns predicted classes for given features.
Args:
x: features.
input_fn: Input function. If set, x must be None.
batch_size: Override default batch size.
as_iterable: If True, return an iterable which keeps yielding predictions
for each example until inputs are exhausted. Note: The inputs must
terminate if you want the iterable to terminate (e.g. be sure to pass
num_epochs=1 if you are using something like read_batch_features).
Returns:
Numpy array of predicted classes (or an iterable of predicted classes if
as_iterable is True).
"""
preds = self._estimator.predict(x=x, input_fn=input_fn,
batch_size=batch_size, outputs=[_CLASSES],
as_iterable=as_iterable)
if as_iterable:
return _as_iterable(preds, output=_CLASSES)
return preds[_CLASSES].reshape(-1)
@deprecated_arg_values(
estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
as_iterable=False)
def predict_proba(
self, x=None, input_fn=None, batch_size=None, as_iterable=False):
"""Returns prediction probabilities for given features.
Args:
x: features.
input_fn: Input function. If set, x and y must be None.
batch_size: Override default batch size.
as_iterable: If True, return an iterable which keeps yielding predictions
for each example until inputs are exhausted. Note: The inputs must
terminate if you want the iterable to terminate (e.g. be sure to pass
num_epochs=1 if you are using something like read_batch_features).
Returns:
Numpy array of predicted probabilities (or an iterable of predicted
probabilities if as_iterable is True).
"""
preds = self._estimator.predict(x=x, input_fn=input_fn,
batch_size=batch_size,
outputs=[_PROBABILITIES],
as_iterable=as_iterable)
if as_iterable:
return _as_iterable(preds, output=_PROBABILITIES)
return preds[_PROBABILITIES]
def get_variable_names(self):
"""Returns list of all variable names in this model.
Returns:
List of names.
"""
return [name for name, _ in list_variables(self._model_dir)]
def get_variable_value(self, name):
"""Returns value of the variable given by name.
Args:
name: string, name of the tensor.
Returns:
`Tensor` object.
"""
return load_variable(self._model_dir, name)
def export(self,
export_dir,
input_fn=None,
input_feature_key=None,
use_deprecated_input_fn=True,
signature_fn=None,
default_batch_size=1,
exports_to_keep=None):
"""See BasEstimator.export."""
def default_input_fn(unused_estimator, examples):
return layers.parse_feature_columns_from_examples(
examples, self._feature_columns)
self._estimator.export(
export_dir=export_dir,
input_fn=input_fn or default_input_fn,
input_feature_key=input_feature_key,
use_deprecated_input_fn=use_deprecated_input_fn,
signature_fn=(
signature_fn or export.classification_signature_fn_with_prob),
prediction_key=_PROBABILITIES,
default_batch_size=default_batch_size,
exports_to_keep=exports_to_keep)
@property
def model_dir(self):
return self._model_dir
@property
@deprecated("2016-10-13", "This method inspects the private state of the "
"object, and should not be used")
def weights_(self):
return self.dnn_weights_
hiddenlayer_weights = [checkpoints.load_variable(
self._model_dir, name=("dnn/hiddenlayer_%d/weights" % i))
for i, _ in enumerate(self._hidden_units)]
logits_weights = [checkpoints.load_variable(
self._model_dir, name="dnn/logits/weights")]
return hiddenlayer_weights + logits_weights
@property
@deprecated("2016-10-13", "This method inspects the private state of the "
"object, and should not be used")
def bias_(self):
return self.dnn_bias_
hiddenlayer_bias = [checkpoints.load_variable(
self._model_dir, name=("dnn/hiddenlayer_%d/biases" % i))
for i, _ in enumerate(self._hidden_units)]
logits_bias = [checkpoints.load_variable(
self._model_dir, name="dnn/logits/biases")]
centered_bias = [checkpoints.load_variable(
self._model_dir, name=_CENTERED_BIAS_WEIGHT)]
return hiddenlayer_bias + logits_bias + centered_bias
@property
def config(self):
return self._estimator.config
class DNNRegressor(dnn_linear_combined.DNNLinearCombinedRegressor):

View File

@ -27,13 +27,8 @@ import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators import _sklearn
from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils
# pylint: disable=g-import-not-at-top
try:
from sklearn.cross_validation import cross_val_score
HAS_SKLEARN = True
except ImportError:
HAS_SKLEARN = False
from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec
from tensorflow.python.ops import math_ops
def _prepare_iris_data_for_logistic_regression():
@ -350,6 +345,7 @@ class DNNClassifierTest(tf.test.TestCase):
# For the case of binary classification, the 2nd column of "predictions"
# denotes the model predictions.
predictions = tf.slice(predictions, [0, 1], [-1, 1])
targets = math_ops.cast(targets, predictions.dtype)
return tf.reduce_sum(tf.mul(predictions, targets))
classifier = tf.contrib.learn.DNNClassifier(
@ -362,9 +358,15 @@ class DNNClassifierTest(tf.test.TestCase):
input_fn=_input_fn_train,
steps=100,
metrics={
'my_accuracy': tf.contrib.metrics.streaming_accuracy,
('my_precision', 'classes'): tf.contrib.metrics.streaming_precision,
('my_metric', 'probabilities'): _my_metric_op
'my_accuracy': MetricSpec(
metric_fn=tf.contrib.metrics.streaming_accuracy,
prediction_key='classes'),
'my_precision': MetricSpec(
metric_fn=tf.contrib.metrics.streaming_precision,
prediction_key='classes'),
'my_metric': MetricSpec(
metric_fn=_my_metric_op,
prediction_key='probabilities')
})
self.assertTrue(
set(['loss', 'my_accuracy', 'my_precision', 'my_metric'
@ -375,21 +377,14 @@ class DNNClassifierTest(tf.test.TestCase):
# Test the case where the 2nd element of the key is neither "classes" nor
# "probabilities".
with self.assertRaises(ValueError):
classifier.evaluate(
input_fn=_input_fn_train,
steps=100,
metrics={('bad_name', 'bad_type'): tf.contrib.metrics.streaming_auc})
# Test the case where the tuple of the key doesn't have 2 elements.
with self.assertRaises(ValueError):
with self.assertRaisesRegexp(KeyError, 'bad_type'):
classifier.evaluate(
input_fn=_input_fn_train,
steps=100,
metrics={
('bad_length_name', 'classes', 'bad_length'):
tf.contrib.metrics.streaming_accuracy
})
'bad_name': MetricSpec(
metric_fn=tf.contrib.metrics.streaming_auc,
prediction_key='bad_type')})
def testTrainSaveLoad(self):
"""Tests that insures you can save and reload a trained model."""
@ -466,6 +461,31 @@ class DNNClassifierTest(tf.test.TestCase):
self.assertGreater(scores['accuracy'], 0.9)
self.assertLess(scores['loss'], 0.3)
def testExport(self):
"""Tests export model for servo."""
def input_fn():
return {
'age': tf.constant([1]),
'language': tf.SparseTensor(values=['english'],
indices=[[0, 0]],
shape=[1, 1])
}, tf.constant([[1]])
language = tf.contrib.layers.sparse_column_with_hash_bucket('language', 100)
feature_columns = [
tf.contrib.layers.real_valued_column('age'),
tf.contrib.layers.embedding_column(language, dimension=1)
]
classifier = tf.contrib.learn.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[3, 3])
classifier.fit(input_fn=input_fn, steps=100)
export_dir = tempfile.mkdtemp()
classifier.export(export_dir)
def testDisableCenteredBias(self):
"""Tests that we can disable centered bias."""
cont_features = [
@ -484,32 +504,6 @@ class DNNClassifierTest(tf.test.TestCase):
self.assertGreater(scores['accuracy'], 0.8)
self.assertLess(scores['loss'], 0.3)
def testSklearnCompatibility(self):
"""Tests compatibility with sklearn"""
if not HAS_SKLEARN:
return
iris = tf.contrib.learn.datasets.load_iris()
cont_features = [
tf.contrib.layers.real_valued_column('', dimension=4)]
kwargs = {
'n_classes': 3,
'feature_columns': cont_features,
'optimizer' : 'Adam',
'hidden_units' : [3, 4]
}
classifier = tf.contrib.learn.DNNClassifier(**kwargs)
scores = cross_val_score(
classifier,
iris.data[1:5],
iris.target[1:5],
scoring='accuracy',
fit_params={'steps': 100}
)
self.assertAllClose(scores, [1, 1, 1])
class DNNRegressorTest(tf.test.TestCase):

View File

@ -234,7 +234,7 @@ def read_keyed_batch_features(file_pattern,
queue_capacity=10000,
reader_num_threads=1,
feature_queue_capacity=100,
num_enqueue_threads=2,
num_queue_runners=2,
parser_num_threads=None,
parse_fn=None,
name=None):
@ -266,8 +266,8 @@ def read_keyed_batch_features(file_pattern,
queue_capacity: Capacity for input queue.
reader_num_threads: The number of threads to read examples.
feature_queue_capacity: Capacity of the parsed features queue.
num_enqueue_threads: Number of threads to enqueue the parsed example queue.
Using multiple threads to enqueue the parsed example queue helps maintain
num_queue_runners: Number of queue runners to start for the feature queue.
Adding multiple queue runners for the parsed example queue helps maintain
a full queue when the subsequent computations overall are cheaper than
parsing.
parser_num_threads: (Deprecated) The number of threads to parse examples.
@ -300,14 +300,14 @@ def read_keyed_batch_features(file_pattern,
feature_map,
keys=keys,
feature_queue_capacity=feature_queue_capacity,
num_enqueue_threads=num_enqueue_threads,
num_queue_runners=num_queue_runners,
name=scope)
def queue_parsed_features(parsed_features,
keys=None,
feature_queue_capacity=100,
num_enqueue_threads=2,
num_queue_runners=2,
name=None):
"""Speeds up parsing by using queues to do it asynchronously.
@ -326,8 +326,8 @@ def queue_parsed_features(parsed_features,
parsed_features: A dict of string key to `Tensor` or `SparseTensor` objects.
keys: `Tensor` of string keys.
feature_queue_capacity: Capacity of the parsed features queue.
num_enqueue_threads: Number of threads to enqueue the parsed example queue.
Using multiple threads to enqueue the parsed example queue helps maintain
num_queue_runners: Number of queue runners to start for the feature queue.
Adding multiple queue runners for the parsed example queue helps maintain
a full queue when the subsequent computations overall are cheaper than
parsing.
name: Name of resulting op.
@ -374,14 +374,14 @@ def queue_parsed_features(parsed_features,
math_ops.cast(input_queue.size(), dtypes.float32)
* (1. / feature_queue_capacity))
# Use multiple threads to enqueue so the queue is always full. Adding more
# than two threads may hog the cpu on the worker to fill up the queue.
enqueue_ops = [input_queue.enqueue(tensors_to_enqueue)
for _ in range(num_enqueue_threads)]
queue_runner.add_queue_runner(queue_runner.QueueRunner(
input_queue, enqueue_ops,
queue_closed_exception_types=(errors.OutOfRangeError,
errors.CancelledError)))
# Add multiple queue runners so that the queue is always full. Adding more
# than two queue-runners may hog the cpu on the worker to fill up the queue.
for _ in range(num_queue_runners):
queue_runner.add_queue_runner(
queue_runner.QueueRunner(
input_queue, [input_queue.enqueue(tensors_to_enqueue)],
queue_closed_exception_types=(errors.OutOfRangeError,
errors.CancelledError)))
dequeued_tensors = input_queue.dequeue()

View File

@ -83,8 +83,6 @@ def run(experiment_fn, output_dir, schedule=None):
# Get the schedule
config = experiment.estimator.config
schedule = schedule or _get_default_schedule(config)
if not schedule:
raise ValueError('Must specify a schedule')
# Execute the schedule
if not hasattr(experiment, schedule):
@ -107,19 +105,36 @@ def run(experiment_fn, output_dir, schedule=None):
return task()
def _is_distributed(config):
"""Returns true if this is a distributed job."""
if not config.cluster_spec:
return False
# This is considered a distributed job if there is more than one task
# in the cluster spec.
task_count = 0
for job in config.cluster_spec.jobs:
for _ in config.cluster_spec.job_tasks(job):
task_count += 1
return task_count > 1
def _get_default_schedule(config):
"""Returns the default schedule for the provided RunConfig."""
if not config or not config.job_name:
return None
if not config or not _is_distributed(config):
return 'local_run'
if not config.job_name or config.job_name == 'master':
# TODO(rhaertel): handle the case there are more
# than one masters or explicitly disallow.
if not config.job_name:
raise ValueError('Must specify a schedule')
if config.job_name == 'master':
# TODO(rhaertel): handle the case where there is more than one master
# or explicitly disallow such a case.
return 'local_run'
elif config.job_name == 'ps':
return 'run_std_server'
elif config.job_name == 'worker':
return 'train'
return ValueError('No default schedule for task type: %s' %
(config.job_name,))
raise ValueError('No default schedule for task type: %s' % (config.job_name,))

View File

@ -335,7 +335,7 @@ def get_rnn_model(rnn_size, cell_type, num_layers, input_op_fn, bidirectional,
fw_cell, attn_length=attn_length, attn_size=attn_size,
attn_vec_size=attn_vec_size, state_is_tuple=False)
bw_cell = contrib_rnn.AttentionCellWrapper(
fw_cell, attn_length=attn_length, attn_size=attn_size,
bw_cell, attn_length=attn_length, attn_size=attn_size,
attn_vec_size=attn_vec_size, state_is_tuple=False)
rnn_fw_cell = nn.rnn_cell.MultiRNNCell([fw_cell] * num_layers,
state_is_tuple=False)

View File

@ -39,7 +39,7 @@ class TestExperiment(tf.contrib.learn.Experiment):
return Estimator()
def local_run(self):
return "train_and_evaluate"
return "local_run"
def train(self):
return "train"
@ -62,6 +62,18 @@ def build_non_experiment(output_dir):
# pylint: enable=unused-argument
def build_distributed_cluster_spec():
return tf.train.ClusterSpec(
{"ps": ["localhost:1234", "localhost:1235"],
"worker": ["localhost:1236", "localhost:1237"],
"master": ["localhost:1238"],
"foo_has_no_default_schedule": ["localhost:1239"]})
def build_non_distributed_cluster_spec():
return tf.train.ClusterSpec({"foo": ["localhost:1234"]})
class MainTest(tf.test.TestCase):
def setUp(self):
@ -76,7 +88,9 @@ class MainTest(tf.test.TestCase):
schedule="simple_task"))
def test_schedule_from_tf_config(self):
os.environ["TF_CONFIG"] = json.dumps({"task": {"type": "worker"}})
os.environ["TF_CONFIG"] = json.dumps(
{"cluster": build_distributed_cluster_spec().as_dict(),
"task": {"type": "worker"}})
# RunConfig constructor will set job_name from TF_CONFIG.
config = run_config.RunConfig()
self.assertEqual(
@ -85,28 +99,35 @@ class MainTest(tf.test.TestCase):
output_dir="/tmp"))
def test_schedule_from_manually_specified_job_name(self):
config = run_config.RunConfig(job_name="worker")
config = run_config.RunConfig(
job_name="worker", cluster_spec=build_distributed_cluster_spec())
self.assertEqual(
"train",
learn_runner.run(lambda output_dir: TestExperiment(config=config),
output_dir="/tmp"))
def test_schedule_from_config_runs_train_and_evaluate_on_master(self):
config = run_config.RunConfig(job_name="master", task=0, is_chief=True)
def test_schedule_from_config_runs_local_run_on_master(self):
config = run_config.RunConfig(
job_name="master",
cluster_spec=build_distributed_cluster_spec(),
task=0,
is_chief=True)
self.assertEqual(
"train_and_evaluate",
"local_run",
learn_runner.run(lambda output_dir: TestExperiment(config=config),
output_dir="/tmp"))
def test_schedule_from_config_runs_serve_on_ps(self):
config = run_config.RunConfig(job_name="ps")
config = run_config.RunConfig(
job_name="ps", cluster_spec=build_distributed_cluster_spec())
self.assertEqual(
"run_std_server",
learn_runner.run(lambda output_dir: TestExperiment(config=config),
output_dir="/tmp"))
def test_schedule_from_config_runs_train_on_worker(self):
config = run_config.RunConfig(job_name="worker")
config = run_config.RunConfig(
job_name="worker", cluster_spec=build_distributed_cluster_spec())
self.assertEqual(
"train",
learn_runner.run(lambda output_dir: TestExperiment(config=config),
@ -117,13 +138,27 @@ class MainTest(tf.test.TestCase):
learn_runner.run, build_experiment, "",
"simple_task")
def test_fail_no_schedule_and_no_config(self):
self.assertRaisesRegexp(ValueError, "Must specify a schedule",
learn_runner.run, build_experiment, "/tmp")
def test_no_schedule_and_no_config_runs_local_run(self):
self.assertEqual(
"local_run",
learn_runner.run(build_experiment,
output_dir="/tmp"))
def test_no_schedule_and_non_distributed_runs_local_run(self):
config = run_config.RunConfig(
cluster_spec=build_non_distributed_cluster_spec())
self.assertEqual(
"local_run",
learn_runner.run(lambda output_dir: TestExperiment(config=config),
output_dir="/tmp"))
def test_fail_job_name_with_no_default_schedule(self):
self.assertRaisesRegexp(ValueError, "Must specify a schedule",
learn_runner.run, build_experiment, "/tmp")
config = run_config.RunConfig(
job_name="foo_has_no_default_schedule",
cluster_spec=build_distributed_cluster_spec())
create_experiment_fn = lambda output_dir: TestExperiment(config=config)
self.assertRaisesRegexp(ValueError, "No default schedule",
learn_runner.run, create_experiment_fn, "/tmp")
def test_fail_non_callable(self):
self.assertRaisesRegexp(TypeError, "Experiment builder .* is not callable",
@ -148,7 +183,8 @@ class MainTest(tf.test.TestCase):
"default")
def test_fail_schedule_from_config_with_no_job_name(self):
config = run_config.RunConfig(job_name=None)
config = run_config.RunConfig(
job_name=None, cluster_spec=build_distributed_cluster_spec())
self.assertRaisesRegexp(
ValueError,
"Must specify a schedule",

View File

@ -77,13 +77,6 @@ Node* Ones(Graph* const g, const int n) {
return test::graph::Constant(g, data);
}
Node* StringIota(Graph* const g, const int n) {
Tensor data(DT_STRING, TensorShape({n}));
test::FillFn<string>(
&data, [](const int i) { return strings::StrCat(strings::Hex(i)); });
return test::graph::Constant(g, data);
}
Node* SparseExampleIndices(Graph* const g, const int sparse_features_per_group,
const int num_examples) {
const int x_size = num_examples * 4;

View File

@ -60,9 +60,6 @@ HOST_OBJDIR := $(MAKEFILE_DIR)/gen/host_obj/
HOST_BINDIR := $(MAKEFILE_DIR)/gen/host_bin/
HOST_GENDIR := $(MAKEFILE_DIR)/gen/host_obj/
# Find the current Eigen version from the Bazel configuration
EIGEN_VERSION := $(shell grep eigen_version tensorflow/workspace.bzl | head -1 | sed -e 's/.*eigen_version.*=.*"\(.*\)"/\1/')
# Settings for the host compiler.
HOST_CXX := $(CC_PREFIX) gcc
HOST_CXXFLAGS := --std=c++11
@ -75,7 +72,7 @@ HOST_LDOPTS += -L/usr/local/lib
HOST_INCLUDES := \
-I. \
-I$(MAKEFILE_DIR)/downloads/ \
-I$(MAKEFILE_DIR)/downloads/eigen-eigen-$(EIGEN_VERSION) \
-I$(MAKEFILE_DIR)/downloads/eigen \
-I$(HOST_GENDIR)
ifeq ($(HAS_GEN_HOST_PROTOC),true)
HOST_INCLUDES += -I$(MAKEFILE_DIR)/gen/protobuf-host/include
@ -148,7 +145,7 @@ DEPFLAGS = -MT $@ -MMD -MP -MF $(DEPDIR)/$*.Td
INCLUDES := \
-I. \
-I$(MAKEFILE_DIR)/downloads/ \
-I$(MAKEFILE_DIR)/downloads/eigen-eigen-$(EIGEN_VERSION) \
-I$(MAKEFILE_DIR)/downloads/eigen \
-I$(PROTOGENDIR) \
-I$(PBTGENDIR)
ifeq ($(HAS_GEN_HOST_PROTOC),true)
@ -240,7 +237,7 @@ ifeq ($(TARGET),ANDROID)
-I$(NDK_ROOT)/sources/cxx-stl/gnu-libstdc++/4.9/libs/armeabi/include \
-I. \
-I$(MAKEFILE_DIR)/downloads/ \
-I$(MAKEFILE_DIR)/downloads/eigen-eigen-$(EIGEN_VERSION) \
-I$(MAKEFILE_DIR)/downloads/eigen \
-I$(MAKEFILE_DIR)/gen/protobuf/include \
-I$(PROTOGENDIR) \
-I$(PBTGENDIR)
@ -570,6 +567,12 @@ clean:
rm -rf $(MAKEFILE_DIR)/gen
rm -rf tensorflow/core/util/version_info.cc
# Gets rid of all generated files except protobuf libs generated
# before calling make. This allows users not to recompile proto libs every time.
clean_except_protobuf_libs:
find $(MAKEFILE_DIR)/gen -mindepth 1 -maxdepth 1 ! -name "protobuf" ! -name "protobuf-host" -exec rm -r "{}" \;
rm -rf tensorflow/core/util/version_info.cc
# Gets rid of target files only, leaving the host alone. Also leaves the lib
# directory untouched deliberately, so we can persist multiple architectures
# across builds for iOS.

View File

@ -46,18 +46,19 @@ shift $((OPTIND - 1))
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd ${SCRIPT_DIR}/../../../
# Remove any old files first.
make -f tensorflow/contrib/makefile/Makefile clean
if [[ "${ONLY_MAKE_TENSORFLOW}" != "true" ]]; then
# Remove any old files first.
make -f tensorflow/contrib/makefile/Makefile clean
rm -rf tensorflow/contrib/makefile/downloads
# Pull down the required versions of the frameworks we need.
tensorflow/contrib/makefile/download_dependencies.sh
fi
# Compile protobuf for the target Android device architectures.
# Compile protobuf for the target Android device architectures.
CC_PREFIX="${CC_PREFIX}" NDK_ROOT="${NDK_ROOT}" \
tensorflow/contrib/makefile/compile_android_protobuf.sh -c
else
# Only clean files generated by make
make -f tensorflow/contrib/makefile/Makefile clean_except_protobuf_libs
fi
if [[ "${USE_HEXAGON}" == "true" ]]; then
HEXAGON_PARENT_DIR=$(cd ../hexagon && pwd)

View File

@ -1,4 +1,4 @@
#!/bin/bash -ex
#!/bin/bash
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -14,57 +14,52 @@
# limitations under the License.
# ==============================================================================
set -e
DOWNLOADS_DIR=tensorflow/contrib/makefile/downloads
BZL_FILE_PATH=tensorflow/workspace.bzl
mkdir -p ${DOWNLOADS_DIR}
EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/.*tar\.gz' "${BZL_FILE_PATH}")"
GEMMLOWP_URL="$(grep -o 'http.*github.com/google/gemmlowp/.*tar\.gz' "${BZL_FILE_PATH}")"
GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz"
PROTOBUF_URL="$(grep -o 'http.*github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}")"
RE2_URL="$(grep -o 'http.*github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}")"
# Grab the current Eigen version name from the Bazel build file
EIGEN_HASH=$(cat "${BZL_FILE_PATH}" | egrep "eigen_version.*=.*\".*\"" | awk '{ print $3 }')
# Trim trailing and preceding double quotes
EIGEN_HASH="${EIGEN_HASH%\"}"
EIGEN_HASH="${EIGEN_HASH#\"}"
if [[ -z "${EIGEN_HASH}" ]]; then
echo >&2 "Eigen hash does not exist."
exit 1
else
echo "Eigen hash = ${EIGEN_HASH}"
fi
curl "https://bitbucket.org/eigen/eigen/get/${EIGEN_HASH}.tar.gz" \
-o /tmp/eigen-${EIGEN_HASH}.tar.gz
tar xzf /tmp/eigen-${EIGEN_HASH}.tar.gz -C ${DOWNLOADS_DIR}
# Link to the downloaded Eigen library from a permanent directory name, since
# the downloaded name changes with every version.
cd ${DOWNLOADS_DIR}
rm -rf eigen-latest
ln -s eigen-eigen-${EIGEN_HASH} eigen-latest
# TODO(petewarden) - Some new code in Eigen triggers a clang bug with iOS arm64,
# so work around it by patching the source.
function replace_by_sed() {
# TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64,
# so work around it by patching the source.
replace_by_sed() {
local regex="${1}"
shift
if echo "${OSTYPE}" | grep -q darwin; then
sed -e "$1" -i '' "$2"
sed -i '' -e "${regex}" "$@"
else
sed -e "$1" -i "$2"
sed -i -e "${regex}" "$@"
fi
}
download_and_extract() {
local usage="Usage: download_and_extract URL DIR"
local url="${1:?${usage}}"
local dir="${2:?${usage}}"
echo "downloading ${url}" >&2
mkdir -p "${dir}"
tar -C "${dir}" --strip-components=1 -xz < <(curl -Ls "${url}")
}
download_and_extract "${EIGEN_URL}" "${DOWNLOADS_DIR}/eigen"
download_and_extract "${GEMMLOWP_URL}" "${DOWNLOADS_DIR}/gemmlowp"
download_and_extract "${GOOGLETEST_URL}" "${DOWNLOADS_DIR}/googletest"
download_and_extract "${PROTOBUF_URL}" "${DOWNLOADS_DIR}/protobuf"
download_and_extract "${RE2_URL}" "${DOWNLOADS_DIR}/re2"
replace_by_sed 's#static uint32x4_t p4ui_CONJ_XOR = vld1q_u32( conj_XOR_DATA );#static uint32x4_t p4ui_CONJ_XOR; // = vld1q_u32( conj_XOR_DATA ); - Removed by script#' \
eigen-latest/Eigen/src/Core/arch/NEON/Complex.h
"${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h"
replace_by_sed 's#static uint32x2_t p2ui_CONJ_XOR = vld1_u32( conj_XOR_DATA );#static uint32x2_t p2ui_CONJ_XOR;// = vld1_u32( conj_XOR_DATA ); - Removed by scripts#' \
eigen-latest/Eigen/src/Core/arch/NEON/Complex.h
"${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h"
replace_by_sed 's#static uint64x2_t p2ul_CONJ_XOR = vld1q_u64( p2ul_conj_XOR_DATA );#static uint64x2_t p2ul_CONJ_XOR;// = vld1q_u64( p2ul_conj_XOR_DATA ); - Removed by script#' \
eigen-latest/Eigen/src/Core/arch/NEON/Complex.h
git clone https://github.com/google/re2.git re2
git clone https://github.com/google/gemmlowp.git gemmlowp
git clone https://github.com/google/protobuf.git protobuf
git clone https://github.com/google/googletest.git googletest
"${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h"
# TODO(satok): Remove this once protobuf/autogen.sh is fixed.
replace_by_sed 's#https://googlemock.googlecode.com/files/gmock-1.7.0.zip#http://download.tensorflow.org/deps/gmock-1.7.0.zip#' \
protobuf/autogen.sh
"${DOWNLOADS_DIR}/protobuf/autogen.sh"
echo "download_dependencies.sh completed successfully."
echo "download_dependencies.sh completed successfully." >&2

View File

@ -35,7 +35,6 @@ tensorflow/core/lib/io/record_writer.cc
tensorflow/core/lib/io/record_reader.cc
tensorflow/core/lib/io/random_inputstream.cc
tensorflow/core/lib/io/path.cc
tensorflow/core/lib/io/match.cc
tensorflow/core/lib/io/iterator.cc
tensorflow/core/lib/io/inputstream_interface.cc
tensorflow/core/lib/io/inputbuffer.cc

View File

@ -8,6 +8,7 @@ tensorflow/core/protobuf/queue_runner.pb.h
tensorflow/core/protobuf/named_tensor.pb.h
tensorflow/core/protobuf/meta_graph.pb.h
tensorflow/core/protobuf/config.pb.h
tensorflow/core/protobuf/tensor_bundle.pb.h
tensorflow/core/lib/core/error_codes.pb.h
tensorflow/core/framework/versions.pb.h
tensorflow/core/framework/variable.pb.h

View File

@ -2,6 +2,7 @@ tensorflow/core/util/saved_tensor_slice.pb_text.cc
tensorflow/core/util/memmapped_file_system.pb_text.cc
tensorflow/core/protobuf/saver.pb_text.cc
tensorflow/core/protobuf/config.pb_text.cc
tensorflow/core/protobuf/tensor_bundle.pb_text.cc
tensorflow/core/lib/core/error_codes.pb_text.cc
tensorflow/core/framework/versions.pb_text.cc
tensorflow/core/framework/types.pb_text.cc

View File

@ -8,6 +8,7 @@ tensorflow/core/protobuf/queue_runner.proto
tensorflow/core/protobuf/named_tensor.proto
tensorflow/core/protobuf/meta_graph.proto
tensorflow/core/protobuf/config.proto
tensorflow/core/protobuf/tensor_bundle.proto
tensorflow/core/lib/core/error_codes.proto
tensorflow/core/framework/versions.proto
tensorflow/core/framework/variable.proto

View File

@ -118,6 +118,7 @@ time.
@@streaming_mean_cosine_distance
@@streaming_percentage_less
@@streaming_sensitivity_at_specificity
@@streaming_sparse_average_precision_at_k
@@streaming_sparse_precision_at_k
@@streaming_sparse_recall_at_k
@@streaming_specificity_at_sensitivity
@ -167,6 +168,7 @@ from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_recall_at
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_recall_at_thresholds
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_root_mean_squared_error
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_sensitivity_at_specificity
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_sparse_average_precision_at_k
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_sparse_precision_at_k
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_sparse_recall_at_k
from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_specificity_at_sensitivity

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -46,7 +46,7 @@ class QuantizedMatMulOpForHexagonTest : public OpsTestBase {
<< ", hexagon binary version = "
<< hexagon_gemm_wrapper_GetHexagonBinaryVersion() << ")";
LOG(INFO) << "Cpu frequency = "
<< profile_utils::CpuUtils::GetCpuFrequency();
<< profile_utils::CpuUtils::GetCycleCounterFrequency();
#else
LOG(WARNING) << "Hexagon libs are not linked.";
#endif

View File

@ -663,7 +663,7 @@ def train(train_op,
raise ValueError('Cannot provide trace_every_n_steps because '
'logdir=None')
if sync_optimizer and startup_delay_steps > 0:
if sync_optimizer is not None and startup_delay_steps > 0:
raise ValueError(
'startup_delay_steps must be zero when sync_optimizer is supplied.')
@ -697,7 +697,7 @@ def train(train_op,
cleanup_op = None
if is_chief and sync_optimizer:
if is_chief and sync_optimizer is not None:
if not isinstance(sync_optimizer,
sync_replicas_optimizer.SyncReplicasOptimizer):
raise ValueError(
@ -761,7 +761,7 @@ def train(train_op,
number_of_steps or sys.maxint))
sv.start_queue_runners(sess)
logging.info('Starting Queues.')
if is_chief and sync_optimizer:
if is_chief and sync_optimizer is not None:
sv.start_queue_runners(sess, [chief_queue_runner])
try:
while not sv.should_stop():

View File

@ -0,0 +1,68 @@
# Description:
# TensorBoard module containing volatile or experimental code.
package(default_visibility = ["//tensorflow:internal"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
# For platform specific build config
load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
tf_proto_library(
name = "protos_all",
srcs = glob(["**/*.proto"]),
go_api_version = 2,
visibility = ["//visibility:public"],
)
# API methods in `tf.contrib.tensorboard` package.
py_library(
name = "tensorboard",
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
deps = [":plugins"],
)
# API methods in `tf.contrib.tensorboard.plugins` package.
py_library(
name = "plugins",
srcs = ["plugins/__init__.py"],
srcs_version = "PY2AND3",
deps = [":projector"],
)
# API methods and protos in `tf.contrib.tensorboard.plugins.projector` package.
py_library(
name = "projector",
srcs = ["plugins/projector/__init__.py"],
srcs_version = "PY2AND3",
deps = [
":protos_all_py",
"//tensorflow/python:lib",
],
)
py_test(
name = "projector_api_test",
size = "small",
srcs = ["plugins/projector/projector_api_test.py"],
srcs_version = "PY2AND3",
deps = [
":projector",
"//tensorflow:tensorflow_py",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)

View File

@ -0,0 +1,22 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""tensorboard module containing volatile or experimental code."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Add projects here, they will show up under tf.contrib.tensorboard.
from tensorflow.contrib.tensorboard import plugins

View File

@ -0,0 +1,22 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""tensorboard plugins module containing volatile or experimental code."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Add projects here, they will show up under tf.contrib.tensorboard.plugins
from tensorflow.contrib.tensorboard.plugins import projector

View File

@ -0,0 +1,54 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Public API for the Embedding Projector."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from google.protobuf import text_format
from tensorflow.contrib.tensorboard.plugins.projector.projector_config_pb2 import EmbeddingInfo
from tensorflow.contrib.tensorboard.plugins.projector.projector_config_pb2 import ProjectorConfig
from tensorflow.python.lib.io import file_io
PROJECTOR_FILENAME = 'projector_config.pbtxt'
def visualize_embeddings(summary_writer, config):
"""Stores a config file used by the embedding projector.
Args:
summary_writer: The summary writer used for writing events.
config: `tf.contrib.tensorboard.plugins.projector.ProjectorConfig`
proto that holds the configuration for the projector such as paths to
checkpoint files and metadata files for the embeddings. If
`config.model_checkpoint_path` is `None`, it defaults to the
`logdir` used by the summary_writer.
Raises:
ValueError: If the summary writer does not have a `logdir`.
"""
logdir = summary_writer.get_logdir()
# Sanity checks.
if logdir is None:
raise ValueError('Summary writer must have a logdir')
# Saving the config file in the logdir.
config_pbtxt = text_format.MessageToString(config)
file_io.write_string_to_file(
os.path.join(logdir, PROJECTOR_FILENAME), config_pbtxt)
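For reference, a minimal usage sketch of this API; the log directory, tensor name, and metadata path below are illustrative:

```python
import tensorflow as tf

writer = tf.train.SummaryWriter('/tmp/logdir')
config = tf.contrib.tensorboard.plugins.projector.ProjectorConfig()
embedding = config.embedding.add()
embedding.tensor_name = 'word_embedding'
embedding.metadata_path = '/tmp/logdir/metadata.tsv'
# model_checkpoint_path is left unset, so it defaults to the writer's logdir.
tf.contrib.tensorboard.plugins.projector.visualize_embeddings(writer, config)
```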

View File

@ -0,0 +1,49 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""API tests for the projector plugin in TensorBoard."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import shutil
import tensorflow as tf
from google.protobuf import text_format
class ProjectorApiTest(tf.test.TestCase):
def testVisualizeEmbeddings(self):
# Create a dummy configuration.
config = tf.contrib.tensorboard.plugins.projector.ProjectorConfig()
config.model_checkpoint_path = 'test'
emb1 = config.embedding.add()
emb1.tensor_name = 'tensor1'
emb1.metadata_path = 'metadata1'
# Call the API method to save the configuration to a temporary dir.
temp_dir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, temp_dir)
writer = tf.train.SummaryWriter(temp_dir)
tf.contrib.tensorboard.plugins.projector.visualize_embeddings(writer,
config)
# Read the configuration from disk and make sure it matches the original.
with tf.gfile.GFile(os.path.join(temp_dir, 'projector_config.pbtxt')) as f:
config2 = tf.contrib.tensorboard.plugins.projector.ProjectorConfig()
text_format.Parse(f.read(), config2)
self.assertEqual(config, config2)

View File

@ -13,19 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/file_system.h"
syntax = "proto3";
#include "tensorflow/core/platform/test.h"
package tensorflow;
namespace tensorflow {
namespace {
TEST(FileSystemTest, GetNameFromURI) {
EXPECT_EQ("foo", GetNameFromURI("file://foo"));
EXPECT_EQ("file:/", GetNameFromURI("file:/"));
EXPECT_EQ("file:", GetNameFromURI("file:"));
EXPECT_EQ("bar", GetNameFromURI("bar"));
message EmbeddingInfo {
string tensor_name = 1;
string metadata_path = 2;
}
} // namespace
} // namespace tensorflow
message ProjectorConfig {
string model_checkpoint_path = 1;
repeated EmbeddingInfo embedding = 2;
}

View File

@ -0,0 +1,453 @@
# tfprof: A Profiling Tool for TensorFlow Models
go/tfprof
Author: Xin Pan (xpan@google.com, github: panyx0718)
Consultants: Jon Shlens (shlens@google.com), Pete Warden (petewarden@google.com)
[TOC]
## Introduction
tfprof is a profiling tool for TensorFlow that analyzes model architectures
and measures system performance.
### Major Features
1. Measure model parameters, float operations, tensor shapes.
2. Measure op execution times, requested memory size and device placement.
3. Inspect checkpoint tensors' shapes and their values.
4. Explore model based on name scope or graph structure.
5. Selectively grouping/filtering/accounting/ordering ops.
### Interfaces
[CLI Tutorials](#cli-tutorials):
It supports interactive mode for exploration and single-shot mode for
scripts. Outputs can be dumped to files or printed in the terminal.
Python API Tutorials: Python API is not released yet.
## CLI Tutorials
Tutorials are based on a 32-layer ResNet.
TODO(xpan): Provide graph.pbtxt, model.ckpt, tfprof_log and run_meta download.
### Examples
1) Start `tfprof` command line tool
```shell
# Build the tool.
bazel build -c opt tensorflow/contrib/tfprof/...
# Help information, including detailed 'option' instructions.
bazel-bin/tensorflow/contrib/tfprof/tools/tfprof/tfprof help
#
# The following commands will start tfprof interactive mode.
#
# Profile model shapes and parameters only.
bazel-bin/tensorflow/contrib/tfprof/tools/tfprof/tfprof \
--graph_path=/graph.pbtxt
#
# Additionally profile checkpoint statistics and values.
# Use '-account_type_regexes _checkpoint_variables' to select
# checkpoint tensors.
bazel-bin/tensorflow/contrib/tfprof/tools/tfprof/tfprof \
--graph_path=graph.pbtxt \
--checkpoint_path=model.ckpt
#
# Additionally profile ops requested memory and timing.
# See CLI Input Files section on generating run_meta file.
bazel-bin/tensorflow/contrib/tfprof/tools/tfprof/tfprof \
--graph_path=graph.pbtxt \
--run_meta_path=run_meta \
--checkpoint_path=model.ckpt
#
# tfprof_log is used to define customized op types and float ops.
# Use tfprof_logger.write_op_log() to create tfprof_log.
# See 11) in Examples section on generating tfprof_log file.
bazel-bin/tensorflow/contrib/tfprof/tools/tfprof/tfprof \
--graph_path=graph.pbtxt \
--run_meta_path=run_meta \
--op_log_path=tfprof_log \
--checkpoint_path=model.ckpt
```
Note that `graph.pbtxt` is in ASCII text format.
2) Press enter to show the default options
```shell
tfprof>
tfprof>
-max_depth 4
-min_bytes 0
-min_micros 0
-min_params 0
-min_float_ops 0
-device_regexes .*
-order_by name
-account_type_regexes Variable
-start_name_regexes .*
-trim_name_regexes
-show_name_regexes .*
-hide_name_regexes IsVariableInitialized_[0-9]+,save\/.*,^zeros[0-9_]*
-account_displayed_op_only false
# supported select fields. Availability depends on --[run_meta|checkpoint|op_log]_path.
# [bytes|micros|params|float_ops|num_hidden_ops|tensor_value|device|op_types]
-select params
-viz false
-dump_to_file
```
3) I want to see the `BatchNorm`'s gamma value in the checkpoint.
```shell
# Requires --graph_path, --checkpoint_path.
tfprof> scope -show_name_regexes unit_1_0.*gamma -select tensor_value -max_depth 5
_TFProfRoot ()
unit_1_0/shared_activation/init_bn/gamma ()
[1.80 2.10 2.06 1.91 2.26 1.86 1.81 1.37 1.78 1.85 1.96 1.54 2.04 2.34 2.22 1.99 ],
unit_1_0/sub2/bn2/gamma ()
[1.57 1.83 1.30 1.25 1.59 1.14 1.26 0.82 1.19 1.10 1.48 1.01 0.82 1.23 1.21 1.14 ],
```
4) I want to see my checkpoint tensors' shapes and numbers of parameters.
```shell
# Requires --graph_path, --checkpoint_path.
# Increase -max_depth to see all tensors.
tfprof> scope -account_type_regexes _checkpoint_variables -select params -max_depth 4
_TFProfRoot (--/930.58k params)
global_step (0/0 params)
init/init_conv/DW (3x3x3x16, 432/864 params)
pool_logit/DW (64x10, 640/1.28k params)
pool_logit/DW/Momentum (64x10, 640/640 params)
pool_logit/biases (10, 10/20 params)
pool_logit/biases/Momentum (10, 10/10 params)
unit_last/final_bn/beta (64, 64/128 params)
unit_last/final_bn/gamma (64, 64/128 params)
unit_last/final_bn/moving_mean (64, 64/64 params)
unit_last/final_bn/moving_variance (64, 64/64 params)
```
5) I defined an op named cost to calculate the loss. I want to know which ops
it depends on take a long time to run. Hint: Use the graph command to explore
graph dependencies.
```shell
# Requires --graph_path, --run_meta_path.
tfprof> graph -start_name_regexes cost.* -max_depth 100 -min_micros 10000 -select micros -account_type_regexes .*
_TFProfRoot (0us/3.61sec)
init/init_conv/Conv2D (11.75ms/3.10sec)
random_shuffle_queue_DequeueMany (3.09sec/3.09sec)
unit_1_0/sub2/conv2/Conv2D (74.14ms/3.19sec)
unit_1_3/sub2/conv2/Conv2D (60.75ms/3.34sec)
unit_2_4/sub2/conv2/Conv2D (73.58ms/3.54sec)
unit_3_3/sub2/conv2/Conv2D (10.26ms/3.60sec)
```
6) I want to know the expensive operations during back propagation.
Hint: TensorFlow prepends `gradients` to your defined name scopes. Use the scope
command to explore based on name scope hierarchies.
```shell
# Requires --graph_path, --run_meta_path.
tfprof> scope -start_name_regexes gradient.* -max_depth 100 -min_micros 20000 -select micros -account_type_regexes .*
_TFProfRoot (0us/2.29sec)
gradients/unit_1_0/sub1/conv1/Conv2D_grad/Conv2DBackpropFilter (54.96ms/54.96ms)
gradients/unit_1_0/sub2/conv2/Conv2D_grad/Conv2DBackpropFilter (83.63ms/83.63ms)
gradients/unit_1_1/sub1/conv1/Conv2D_grad/Conv2DBackpropFilter (99.25ms/99.25ms)
gradients/unit_1_2/sub1/conv1/Conv2D_grad/Conv2DBackpropFilter (95.40ms/95.40ms)
gradients/unit_1_2/sub2/conv2/Conv2D_grad/Conv2DBackpropFilter (99.83ms/99.83ms)
gradients/unit_1_3/sub1/conv1/Conv2D_grad/Conv2DBackpropFilter (95.39ms/95.39ms)
...
```
7) Show the number of float operations in the model.
Note: the float operations calculation depends on
1) op.RegisterStatistics: if an op doesn't
have RegisterStatistics defined, its float operations cannot be counted; and
2) fully defined shapes, which are also necessary in order to calculate flops.
The float operations number is provided by the tensorflow::tfprof::OpLog logged from
the Python API.
```shell
# Requires --graph_path, --op_log_path.
tfprof> scope -min_float_ops 1 -max_depth 10 -select float_ops -account_type_regexes .*
_TFProfRoot (0/17.63b flops)
gradients/pool_logit/xw_plus_b/MatMul_grad/MatMul (163.84k/163.84k flops)
gradients/pool_logit/xw_plus_b/MatMul_grad/MatMul_1 (163.84k/163.84k flops)
init/init_conv/Conv2D (113.25m/113.25m flops)
pool_logit/xw_plus_b (1.28k/165.12k flops)
pool_logit/xw_plus_b/MatMul (163.84k/163.84k flops)
unit_1_0/sub1/conv1/Conv2D (603.98m/603.98m flops)
unit_1_0/sub2/conv2/Conv2D (603.98m/603.98m flops)
unit_1_1/sub1/conv1/Conv2D (603.98m/603.98m flops)
unit_1_1/sub2/conv2/Conv2D (603.98m/603.98m flops)
...
```
8) Show the number of parameters of all `tf.trainable_variables()` in the model.
```shell
# Requires --graph_path --op_log_path.
# store option for future commands.
tfprof> set -account_type_regexes _trainable_variables
tfprof> scope -max_depth 4 -select params
_TFProfRoot (--/464.15k params)
init/init_conv/DW (3x3x3x16, 432/432 params)
pool_logit/DW (64x10, 640/640 params)
pool_logit/biases (10, 10/10 params)
unit_last/final_bn/beta (64, 64/64 params)
unit_last/final_bn/gamma (64, 64/64 params)
```
Where does “_trainable_variables” come from? It is from the OpLog file
generated by the write_op_log() Python API. write_op_log() helps users create some
common op types implicitly. Users can define their own op types and log them
through the write_op_log() API.
9) What if I'm lazy and don't want to define op types? I have given my ops
well-defined names in my model's code, and want to use those names to select a
group of ops. Let's try it!
```shell
tfprof> set -account_type_regexes .*
tfprof> scope -show_name_regexes unit_2_1.*DW -max_depth 100 -account_displayed_op_only
_TFProfRoot (0/18.43k params)
unit_2_1/sub1/conv1/DW (3x3x32x32, 9.22k/9.22k params)
unit_2_1/sub2/conv2/DW (3x3x32x32, 9.22k/9.22k params)
```
The above command allows you to filter ops that match specific names.
`-account_displayed_op_only` asks tfprof to only account ops displayed
in the terminal. Otherwise, tfprof accounts all ops matched by
`-account_type_regexes` recursively even if they are hidden due to some
options such as -max_depth.
10) TensorFlow has built-in op types. For example, the built-in op type `Variable`
seems to include the `Variable`s created by your model. However, be careful when
depending on it, because TensorFlow creates extra `Variable` ops implicitly and
the implicitly created ops can have the same prefix as the `Variable`s you
defined.
In the following example, extra `Variable`s are created and “/Momentum” is
appended to their names. This might cause your “model capacity” calculation
to come out wrong.
```shell
tfprof> scope -account_type_regexes Variable -max_depth 4 -select params
_TFProfRoot (--/930.58k params)
global_step (1/1 params)
init/init_conv/DW (3x3x3x16, 432/864 params)
pool_logit/DW (64x10, 640/1.28k params)
pool_logit/DW/Momentum (64x10, 640/640 params)
pool_logit/biases (10, 10/20 params)
pool_logit/biases/Momentum (10, 10/10 params)
unit_last/final_bn/beta (64, 64/128 params)
unit_last/final_bn/gamma (64, 64/128 params)
unit_last/final_bn/moving_mean (64, 64/64 params)
unit_last/final_bn/moving_variance (64, 64/64 params)
```
11) An example of defining an extra op type for ops using `OpLog`.
First, in Python code, create an `OpLog` proto and add op type
information to it:
```python
op_log = tfprof_log_pb2.OpLog()
entry = op_log.log_entries.add()
entry.name = 'pool_logit/DW'
entry.types.append('pool_logit')
entry = op_log.log_entries.add()
entry.name = 'pool_logit/biases'
# Alternatively:
# var = tf.get_variable(xxx)
# entry.name = var.op.name
entry.types.append('pool_logit')
```
Second, call write_op_log to write the OpLog proto.
```python
tfprof_logger.write_op_log(sess.graph, '/tmp/my_op_log_dir', op_log)
```
Third, when starting the tfprof tool, specify
"--op_log_path /tmp/my_op_log_dir/op_log"
```shell
tfprof> scope -account_type_regexes pool_logit -max_depth 4 -select params
_TFProfRoot (--/650 params)
pool_logit/DW (64x10, 640/640 params)
pool_logit/biases (10, 10/10 params)
```
Note that when you call
`tfprof_logger.write_op_log(...)`, the tool adds all `Variables` inside
`tf.trainable_variables()` to `_trainable_variables`.
12) Run tfprof in one-shot mode and dump the result to a file.
```shell
# Printed to stdout if --dump_to_file is not set.
tfprof scope --graph_path /cns/ij-d/home/xpan/tfprof/graph.pbtxt \
--max_depth 3 \
--dump_to_file "/tmp/dump"
Reading Files...
Parsing GraphDef...
Preparing Views...
cat /tmp/dump
_TFProfRoot (--/930.58k params)
global_step (0/0 params)
pool_logit/DW (64x10, 640/1.28k params)
pool_logit/biases (10, 10/20 params)
```
13) Analyze how balanced the Variables are across parameter servers.
In this tutorial, I'm going to use a seq2seq model, which is split
across several GPUs on workers and several parameter servers.
In tfprof, 'device' is an op_type. For example, if op1 and op2 are placed on
gpu0, they share an op_type called 'gpu0'.
```shell
bazel-bin/tensorflow/contrib/tfprof/tools/tfprof/tfprof \
--graph_path ~/tfprof/textsum/graph.pbtxt \
--run_meta_path ~/tfprof/textsum/run_meta
# Looks like ps task 1 is holding twice as many parameters as task 0.
tfprof> scope -select device,params -account_type_regexes .*ps.*task:0.* -max_depth 1
_TFProfRoot (--/25.81m params)
tfprof> scope -select device,params -account_type_regexes .*ps.*task:1.* -max_depth 1
_TFProfRoot (--/58.84m params)
```
### CLI Input Files
The tfprof command line interface (CLI) loads dumped files from a TensorFlow model
and converts them into in-memory data structures. To use it, users need to specify
the locations of the dumped files. The following are the dumped files loaded
by tfprof:
<b>--graph_path:</b> GraphDef text file (required). Used to build the in-memory
representation of the model. For example, graph.pbtxt written by tf.Supervisor
is a candidate. If you are not using tf.Supervisor, you can easily get a GraphDef
using tf.Graph.as_graph_def() or other APIs.
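For instance, a minimal sketch (using the standard `tf.train.write_graph`
helper; the output directory is a hypothetical placeholder) that produces
such a file:

```python
import tensorflow as tf

# Build the model first, then dump the default graph as a text GraphDef.
# '/tmp/tfprof' is a hypothetical output directory.
graph_def = tf.get_default_graph().as_graph_def()
tf.train.write_graph(graph_def, '/tmp/tfprof', 'graph.pbtxt')
```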
<b>--run_meta_path:</b> tensorflow::RunMetadata.
Used to get the memory and time consumption of
each op in the model. Users need to enable it. For example, the following code
snippet writes a RunMetadata file:
```python
run_options = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
run_metadata = config_pb2.RunMetadata()
# Once in a while, run with these options to collect the RunMetadata.
_ = self._sess.run(..., options=run_options, run_metadata=run_metadata)
with gfile.Open(os.path.join(output_dir, "run_meta"), "w") as f:
f.write(run_metadata.SerializeToString())
```
<b>--op_log_path:</b>
tensorflow::tfprof::OpLog. A proto used to provide extra information
for ops. By giving a group of ops a type name, users can easily aggregate the
statistics for those ops without accidentally missing or including extra ops.
tfprof exposes the following Python API to add op information and logging.
```python
def write_op_log(graph, log_dir, op_log=None)
```
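For example, a minimal usage sketch (the module path and output directory are
assumptions based on this commit's layout):

```python
import tensorflow as tf
from tensorflow.contrib.tfprof.python.tools.tfprof import tfprof_logger

with tf.Session() as sess:
  # ... build and run the model ...
  # Writes a 'tfprof_log' file under log_dir, which can then be passed to
  # the CLI via --op_log_path.
  tfprof_logger.write_op_log(sess.graph, log_dir='/tmp/my_op_log_dir')
```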
<b>--checkpoint_path:</b>
TensorFlow checkpoint. It defines the _checkpoint_variables op type. It also
provides the values of checkpointed tensors.
## Design
### In-memory representation
<b>Scope:</b> This representation organizes ops based on the name scope hierarchy,
similar to a filesystem hierarchy. Hence, it is essentially a tree data structure.
For example, op1 with name “name1/name2” is a child of op2 with name “name1”.
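As an illustrative sketch (not the actual implementation, which lives in
tfprof_scope.cc), such a tree can be derived from op names alone:

```python
from collections import defaultdict

def build_scope_tree(op_names):
  """Map each name scope to its direct children, like a filesystem tree."""
  children = defaultdict(set)
  for name in op_names:
    # Walk up the '/'-separated hierarchy, linking each name to its parent.
    while '/' in name:
      parent, _ = name.rsplit('/', 1)
      children[parent].add(name)
      name = parent
    children['_TFProfRoot'].add(name)  # top-level names hang off the root
  return children

# 'name1/name2' becomes a child of 'name1'.
print(build_scope_tree(['name1/name2', 'name1/name3', 'other']))
```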
<b>Graph:</b> This representation organizes ops based on op inputs. Hence, it is
a graph structure. The graph is a “directed acyclic graph” (hopefully), with
edges pointing from “output to input”. The direction is designed this way so that
users can trace from a “result” back to its “sources”.
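The root-finding rule can be sketched in Python as well (a toy illustration;
the node names come from the test graph later in this diff):

```python
def find_roots(op_inputs):
  """op_inputs maps each op name to the names of its input ops. Edges point
  from outputs to inputs, so roots are the final outputs of the model."""
  consumed = {i for inputs in op_inputs.values() for i in inputs}
  # Roots are ops that are not an input of any other op.
  return [op for op in op_inputs if op not in consumed]

print(find_roots({
    'Conv2D_1': ['Conv2D', 'DW2/read'],
    'Conv2D': ['zeros', 'DW/read'],
    'zeros': [], 'DW/read': [], 'DW2/read': [],
}))  # ['Conv2D_1']: tracing starts from the final output
```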
### Command line options
tfprof's major goals are to measure system performance and quickly analyze
model architectures. Hence, its commands and options should allow users to achieve
these two goals easily.
<b>graph:</b> It is expected that users will mostly use the graph representation to
debug system performance. Hence, tfprof supports the graph command, which pulls up
the graph in-memory representation described above.
<b>scope:</b> It is expected that some users might want to explore their model
statistics using the name scope information they defined in their Python code.
Hence, tfprof supports the “scope” command, which pulls up the tree in-memory
representation.
<b>set:</b> It is used to store options so that the user doesn't need to
re-type the same options again and again in follow-up commands. Note that
tfprof also supports traditional terminal history and auto-completion.
<b>help:</b> Prints help information.
<b>Options:</b> Run “tfprof help” to get detailed explanations.
```python
"-max_depth",
"-min_bytes",
"-min_micros",
"-min_params",
"-min_float_ops",
"-order_by",
"-account_type_regexes",
"-start_name_regexes",
"-trim_name_regexes",
"-show_name_regexes",
"-hide_name_regexes",
"-account_displayed_op_only",
"-select",
"-viz", # Only supported for graph command.
"-dump_to_file",
```
A key design is that stats are aggregated from descendants up to ancestors.
`-account_type_regexes` decides which ops' stats are accounted, based on op
type. Usually, set it to `.*` if no extra type information has been added to
the ops using OpLog. Intuitively, only accounted ops are displayed.
`-min/max` and `-show/hide/trim/start` options are only used to optionally
display or hide ops based on op names and stats. However, they don't prevent
tfprof from accounting the stats of hidden ops. Hence, the stats of an op can
be aggregated by its parent even if it is hidden. `-account_displayed_op_only`
is an option to break this rule; when it is set, only displayed ops are
accounted.
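A toy sketch of the default accounting rule (hypothetical; the real logic lives
in the Account methods of the C++ sources in this diff):

```python
def total_params(node, accounted_types):
  """Aggregate params from descendants up to ancestors. Without
  -account_displayed_op_only, a hidden op's stats still flow to its parent."""
  total = node['params'] if node['type'] in accounted_types else 0
  for child in node.get('children', []):
    total += total_params(child, accounted_types)
  return total

root = {'name': '_TFProfRoot', 'type': '_TFGraphParent', 'params': 0,
        'children': [{'name': 'pool_logit/DW', 'type': 'Variable',
                      'params': 640, 'children': []}]}
print(total_params(root, {'Variable'}))  # 640, even if DW itself is hidden
```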
Regexes are all comma-separated, for example `-show_name_regexes
regex1.*,regex2.*`. It is designed this way because it is convenient and commas
are not expected to show up in op names.
`-order_by` is used to order displayed ops. Displayed ops at the same hierarchy
level (notice the printed indentation) are sorted according to `-order_by`.
## Future Work
* Load SummaryWriter event logs so that it can show the latest summary value.
* Better sorting and aggregation of outputs. Easier comprehension.
* Currently, shape information is based on `graph.pbtxt`. When the shape
information is incomplete, tfprof ignores it. See if it can use `RunMetadata`
and `Checkpoint` to complete shape information.
View File
@ -0,0 +1,31 @@
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
py_library(
name = "tfprof_logger",
srcs = ["tfprof_logger.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/contrib/tfprof/tools/tfprof:protos_all_py",
"//tensorflow/python:framework_for_generated_wrappers",
],
)
# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)
View File
@ -0,0 +1,114 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logging tensorflow::tfprof::OpLog.
OpLog is used to add extra model information for offline analysis by tfprof.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
from tensorflow.contrib.tfprof.python.tools.tfprof import tfprof_log_pb2
from tensorflow.python.framework import ops
TRAINABLE_VARIABLES = '_trainable_variables'
REGISTERED_FLOP_STATS = 'flops'
def _get_logged_ops(graph):
"""Extract trainable model parameters and FLOPs for ops from a Graph.
Args:
graph: tf.Graph.
Returns:
logged_ops: dict mapping from op_name to OpLogEntry.
"""
logged_ops = {}
graph_def = graph.as_graph_def()
for node in graph_def.node:
try:
stats = ops.get_stats_for_node_def(graph, node, REGISTERED_FLOP_STATS)
except ValueError:
      # Catch the exception when shape is incomplete. Skip it.
stats = None
if not stats or not stats.value:
continue
if node.name not in logged_ops:
entry = tfprof_log_pb2.OpLogEntry()
entry.name = node.name
entry.float_ops = stats.value
logged_ops[entry.name] = entry
for v in graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
if v.op.name not in logged_ops:
entry = tfprof_log_pb2.OpLogEntry()
entry.name = v.op.name
entry.types.append(TRAINABLE_VARIABLES)
logged_ops[entry.name] = entry
else:
logged_ops[v.op.name].types.append(TRAINABLE_VARIABLES)
return logged_ops
def _merge_default_with_oplog(graph, op_log=None):
"""Merge the tfprof default extra info with caller's op_log.
Args:
graph: tf.Graph.
op_log: OpLog proto.
Returns:
tmp_op_log: Merged OpLog proto.
"""
tmp_op_log = tfprof_log_pb2.OpLog()
logged_ops = _get_logged_ops(graph)
if not op_log:
tmp_op_log.log_entries.extend(logged_ops.values())
else:
all_ops = dict()
for entry in op_log.log_entries:
all_ops[entry.name] = entry
    for op_name, entry in logged_ops.items():  # works in both PY2 and PY3
if op_name in all_ops:
all_ops[op_name].types.extend(entry.types)
if entry.float_ops > 0 and all_ops[op_name].float_ops == 0:
all_ops[op_name].float_ops = entry.float_ops
else:
all_ops[op_name] = entry
tmp_op_log.log_entries.extend(all_ops.values())
return tmp_op_log
def write_op_log(graph, log_dir, op_log=None):
"""Log provided 'op_log', and add additional model information below.
The API also assigns ops in tf.trainable_variables() an op type called
'_trainable_variables'.
The API also logs 'flops' statistics for ops with op.RegisterStatistics()
defined.
Args:
graph: tf.Graph.
log_dir: directory to write the log file.
op_log: OpLog proto.
"""
op_log = _merge_default_with_oplog(graph, op_log)
with tf.gfile.Open(os.path.join(log_dir, 'tfprof_log'), 'w') as log:
log.write(op_log.SerializeToString())
View File
@ -0,0 +1,52 @@
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)
cc_binary(
name = "tfprof",
srcs = ["tfprof_main.cc"],
deps = [
":protos_all_cc",
"//tensorflow/c:c_api",
"//tensorflow/c:checkpoint_reader",
"//tensorflow/contrib/tfprof/tools/tfprof/internal:tfprof_options",
"//tensorflow/contrib/tfprof/tools/tfprof/internal:tfprof_stats",
"//tensorflow/contrib/tfprof/tools/tfprof/internal:tfprof_utils",
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@linenoise//:linenoise",
],
)
load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
tf_proto_library(
name = "protos_all",
srcs = glob(
["**/*.proto"],
),
cc_api_version = 2,
cc_libs = ["//tensorflow/core:protos_all_cc"],
go_api_version = 2,
java_api_version = 2,
visibility = ["//visibility:public"],
)
View File
@ -0,0 +1,227 @@
package(
default_visibility = ["//tensorflow:__subpackages__"],
)
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
cc_library(
name = "tfprof_stats",
srcs = ["tfprof_stats.cc"],
hdrs = ["tfprof_stats.h"],
deps = [
":tfprof_graph",
":tfprof_node",
":tfprof_options",
":tfprof_scope",
":tfprof_show",
":tfprof_utils",
"//tensorflow/c:checkpoint_reader",
"//tensorflow/contrib/tfprof/tools/tfprof:protos_all_cc",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "tfprof_node",
srcs = ["tfprof_node.cc"],
hdrs = ["tfprof_node.h"],
deps = [
":tfprof_options",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "tfprof_scope",
srcs = ["tfprof_scope.cc"],
hdrs = ["tfprof_scope.h"],
deps = [
":tfprof_constants",
":tfprof_node",
":tfprof_options",
":tfprof_show",
":tfprof_tensor",
":tfprof_utils",
"//tensorflow/c:c_api",
"//tensorflow/c:checkpoint_reader",
"//tensorflow/contrib/tfprof/tools/tfprof:protos_all_cc",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "tfprof_graph",
srcs = ["tfprof_graph.cc"],
hdrs = ["tfprof_graph.h"],
deps = [
":tfprof_constants",
":tfprof_node",
":tfprof_options",
":tfprof_show",
":tfprof_tensor",
":tfprof_utils",
"//tensorflow/c:checkpoint_reader",
"//tensorflow/contrib/tfprof/tools/tfprof:protos_all_cc",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "tfprof_show",
srcs = ["tfprof_show.cc"],
hdrs = ["tfprof_show.h"],
deps = [
":tfprof_constants",
":tfprof_node",
":tfprof_options",
":tfprof_tensor",
":tfprof_utils",
"//tensorflow/c:checkpoint_reader",
"//tensorflow/contrib/tfprof/tools/tfprof:protos_all_cc",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
tf_cc_test(
name = "tfprof_show_test",
srcs = ["tfprof_show_test.cc"],
data = [
"testdata/ckpt",
"testdata/graph.pbtxt",
"testdata/run_meta",
"testdata/tfprof_log",
],
deps = [
":tfprof_constants",
":tfprof_options",
":tfprof_stats",
":tfprof_utils",
"//tensorflow/c:checkpoint_reader",
"//tensorflow/contrib/tfprof/tools/tfprof:protos_all_cc",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
cc_library(
name = "tfprof_utils",
srcs = ["tfprof_utils.cc"],
hdrs = ["tfprof_utils.h"],
deps = [
":tfprof_options",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "tfprof_options",
srcs = ["tfprof_options.cc"],
hdrs = ["tfprof_options.h"],
deps = [
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:lib",
],
)
cc_library(
name = "print_model_analysis",
srcs = ["print_model_analysis.cc"],
hdrs = ["print_model_analysis.h"],
deps = [
":tfprof_options",
":tfprof_stats",
"//tensorflow/c:checkpoint_reader",
"//tensorflow/contrib/tfprof/tools/tfprof:protos_all_cc",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
tf_cc_test(
name = "tfprof_stats_test",
srcs = ["tfprof_stats_test.cc"],
data = [
"testdata/ckpt",
"testdata/graph.pbtxt",
"testdata/run_meta",
"testdata/tfprof_log",
],
deps = [
":tfprof_constants",
":tfprof_options",
":tfprof_stats",
":tfprof_utils",
"//tensorflow/c:checkpoint_reader",
"//tensorflow/contrib/tfprof/tools/tfprof:protos_all_cc",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
cc_library(
name = "tfprof_tensor",
srcs = ["tfprof_tensor.cc"],
hdrs = ["tfprof_tensor.h"],
deps = [
"//tensorflow/contrib/tfprof/tools/tfprof:protos_all_cc",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
)
tf_cc_test(
name = "tfprof_tensor_test",
srcs = ["tfprof_tensor_test.cc"],
data = [
"testdata/ckpt",
"testdata/graph.pbtxt",
],
deps = [
":tfprof_options",
":tfprof_stats",
":tfprof_utils",
"//tensorflow/c:checkpoint_reader",
"//tensorflow/contrib/tfprof/tools/tfprof:protos_all_cc",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
cc_library(
name = "tfprof_constants",
hdrs = ["tfprof_constants.h"],
deps = [
],
)
# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)
View File
@ -0,0 +1,65 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tfprof/tools/tfprof/internal/print_model_analysis.h"
#include <stdio.h>
#include <memory>
#include <utility>
#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_stats.h"
namespace tensorflow {
namespace tfprof {
string PrintModelAnalysis(const string* graph, const string* run_meta,
const string* op_log, const string* command,
const Options* options) {
CHECK(graph) << "graph mustn't be null";
CHECK(command) << "command mustn't be null";
CHECK(options) << "options mustn't be null";
std::unique_ptr<GraphDef> graph_ptr(new GraphDef());
graph_ptr->ParseFromString(*graph);
std::unique_ptr<RunMetadata> run_meta_ptr;
if (run_meta) {
run_meta_ptr.reset(new RunMetadata());
run_meta_ptr->ParseFromString(*run_meta);
}
std::unique_ptr<OpLog> op_log_ptr;
if (op_log) {
op_log_ptr.reset(new OpLog());
op_log_ptr->ParseFromString(*op_log);
}
std::unique_ptr<checkpoint::CheckpointReader> ckpt_reader;
TFStats tf_stats(std::move(graph_ptr), std::move(run_meta_ptr),
std::move(op_log_ptr), std::move(ckpt_reader));
if (options->dump_to_file.empty()) {
printf("\n=========================Options=============================\n");
printf("%s", options->ToString().c_str());
printf("\n==================Model Analysis Report======================\n");
TFProfNode root(tf_stats.PrintGraph(*command, *options));
printf("\n======================End of Report==========================\n");
fflush(stdout);
return root.SerializeAsString();
}
return tf_stats.PrintGraph(*command, *options).SerializeAsString();
}
} // namespace tfprof
} // namespace tensorflow
View File
@ -0,0 +1,45 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_PRINT_MODEL_ANALYSIS_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_PRINT_MODEL_ANALYSIS_H_
#include <string>
#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_options.h"
#include "tensorflow/contrib/tfprof/tools/tfprof/tfprof_log.pb.h"
#include "tensorflow/contrib/tfprof/tools/tfprof/tfprof_output.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/protobuf/config.pb.h"
namespace tensorflow {
namespace tfprof {
// ***This API is only for swig.***
//
// Interface defined for Python API swig. Calls the tfprof core API.
// 'graph', 'run_meta', 'op_log' are serialized GraphDef, RunMetadata,
// OpLog strings, respectively.
// 'graph', 'command' and 'options' are required. Others can be nullptr
// if not available.
string PrintModelAnalysis(const string* graph, const string* run_meta,
const string* op_log, const string* command,
const Options* options);
} // namespace tfprof
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_PRINT_MODEL_ANALYSIS_H_
Binary file not shown.
View File
@ -0,0 +1,636 @@
node {
name: "zeros"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 2
}
dim {
size: 6
}
dim {
size: 6
}
dim {
size: 3
}
}
float_val: 0.0
}
}
}
}
node {
name: "DW"
op: "Variable"
attr {
key: "container"
value {
s: ""
}
}
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 3
}
dim {
size: 3
}
dim {
size: 3
}
dim {
size: 6
}
}
}
}
attr {
key: "shared_name"
value {
s: ""
}
}
}
node {
name: "DW/Initializer/random_normal/shape"
op: "Const"
attr {
key: "_class"
value {
list {
s: "loc:@DW"
}
}
}
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
dim {
size: 4
}
}
tensor_content: "\003\000\000\000\003\000\000\000\003\000\000\000\006\000\000\000"
}
}
}
}
node {
name: "DW/Initializer/random_normal/mean"
op: "Const"
attr {
key: "_class"
value {
list {
s: "loc:@DW"
}
}
}
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 0.0
}
}
}
}
node {
name: "DW/Initializer/random_normal/stddev"
op: "Const"
attr {
key: "_class"
value {
list {
s: "loc:@DW"
}
}
}
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 0.0010000000475
}
}
}
}
node {
name: "DW/Initializer/random_normal/RandomStandardNormal"
op: "RandomStandardNormal"
input: "DW/Initializer/random_normal/shape"
attr {
key: "T"
value {
type: DT_INT32
}
}
attr {
key: "_class"
value {
list {
s: "loc:@DW"
}
}
}
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "seed"
value {
i: 87654321
}
}
attr {
key: "seed2"
value {
i: 5
}
}
}
node {
name: "DW/Initializer/random_normal/mul"
op: "Mul"
input: "DW/Initializer/random_normal/RandomStandardNormal"
input: "DW/Initializer/random_normal/stddev"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@DW"
}
}
}
}
node {
name: "DW/Initializer/random_normal"
op: "Add"
input: "DW/Initializer/random_normal/mul"
input: "DW/Initializer/random_normal/mean"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@DW"
}
}
}
}
node {
name: "DW/Assign"
op: "Assign"
input: "DW"
input: "DW/Initializer/random_normal"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@DW"
}
}
}
attr {
key: "use_locking"
value {
b: true
}
}
attr {
key: "validate_shape"
value {
b: true
}
}
}
node {
name: "DW/read"
op: "Identity"
input: "DW"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@DW"
}
}
}
}
node {
name: "Conv2D"
op: "Conv2D"
input: "zeros"
input: "DW/read"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "data_format"
value {
s: "NHWC"
}
}
attr {
key: "padding"
value {
s: "SAME"
}
}
attr {
key: "strides"
value {
list {
i: 1
i: 2
i: 2
i: 1
}
}
}
attr {
key: "use_cudnn_on_gpu"
value {
b: true
}
}
}
node {
name: "DW2"
op: "Variable"
attr {
key: "container"
value {
s: ""
}
}
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 2
}
dim {
size: 2
}
dim {
size: 6
}
dim {
size: 12
}
}
}
}
attr {
key: "shared_name"
value {
s: ""
}
}
}
node {
name: "DW2/Initializer/random_normal/shape"
op: "Const"
attr {
key: "_class"
value {
list {
s: "loc:@DW2"
}
}
}
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
dim {
size: 4
}
}
tensor_content: "\002\000\000\000\002\000\000\000\006\000\000\000\014\000\000\000"
}
}
}
}
node {
name: "DW2/Initializer/random_normal/mean"
op: "Const"
attr {
key: "_class"
value {
list {
s: "loc:@DW2"
}
}
}
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 0.0
}
}
}
}
node {
name: "DW2/Initializer/random_normal/stddev"
op: "Const"
attr {
key: "_class"
value {
list {
s: "loc:@DW2"
}
}
}
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 0.0010000000475
}
}
}
}
node {
name: "DW2/Initializer/random_normal/RandomStandardNormal"
op: "RandomStandardNormal"
input: "DW2/Initializer/random_normal/shape"
attr {
key: "T"
value {
type: DT_INT32
}
}
attr {
key: "_class"
value {
list {
s: "loc:@DW2"
}
}
}
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "seed"
value {
i: 87654321
}
}
attr {
key: "seed2"
value {
i: 15
}
}
}
node {
name: "DW2/Initializer/random_normal/mul"
op: "Mul"
input: "DW2/Initializer/random_normal/RandomStandardNormal"
input: "DW2/Initializer/random_normal/stddev"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@DW2"
}
}
}
}
node {
name: "DW2/Initializer/random_normal"
op: "Add"
input: "DW2/Initializer/random_normal/mul"
input: "DW2/Initializer/random_normal/mean"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@DW2"
}
}
}
}
node {
name: "DW2/Assign"
op: "Assign"
input: "DW2"
input: "DW2/Initializer/random_normal"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@DW2"
}
}
}
attr {
key: "use_locking"
value {
b: true
}
}
attr {
key: "validate_shape"
value {
b: true
}
}
}
node {
name: "DW2/read"
op: "Identity"
input: "DW2"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@DW2"
}
}
}
}
node {
name: "Conv2D_1"
op: "Conv2D"
input: "Conv2D"
input: "DW2/read"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "data_format"
value {
s: "NHWC"
}
}
attr {
key: "padding"
value {
s: "SAME"
}
}
attr {
key: "strides"
value {
list {
i: 1
i: 2
i: 2
i: 1
}
}
}
attr {
key: "use_cudnn_on_gpu"
value {
b: true
}
}
}
versions {
producer: 13
}
View File
Binary file not shown.
View File
Binary file not shown.
View File
@ -0,0 +1,37 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_TFPROF_CONSTANTS_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_TFPROF_CONSTANTS_H_
namespace tensorflow {
namespace tfprof {
// Op name of root of everything. Aggregates all stats.
static const char* const kTFProfRoot = "_TFProfRoot";
// Op type for nodes that don't represent a physical node in the
// TensorFlow model. They only exist as placeholders to aggregate children.
// For example, kTFProfRoot belongs to this type.
static const char* const kTFGraphParent = "_TFGraphParent";
static const char* const kTFScopeParent = "_kTFScopeParent";
// Op type for tf.trainable_variables().
static const char* const kTrainableVarType = "_trainable_variables";
// Op type for tensors in the checkpoint file.
static const char* const kCkptVarType = "_checkpoint_variables";
} // namespace tfprof
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_TFPROF_CONSTANTS_H_
View File
@ -0,0 +1,222 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_graph.h"
#include <stdio.h>
#include <utility>
#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_constants.h"
#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_tensor.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/regexp.h"
namespace tensorflow {
namespace tfprof {
GraphNode* TFGraph::CreateParentNode(const string& name) {
node_defs_.push_back(std::unique_ptr<NodeDef>(new NodeDef()));
node_defs_.back()->set_name(name);
node_defs_.back()->set_op(kTFGraphParent);
parent_nodes_[name] =
std::unique_ptr<TFNode>(new TFNode(node_defs_.back().get()));
nodes_map_[name] =
std::unique_ptr<GraphNode>(new GraphNode(parent_nodes_[name].get()));
return nodes_map_[name].get();
}
void TFGraph::AddNode(TFNode* node) {
string name = node->node_def()->name();
nodes_map_[name] = std::unique_ptr<GraphNode>(new GraphNode(node));
}
void TFGraph::Build() {
if (!roots_.empty()) return;
std::set<string> nonroots;
  // Find the root nodes (nodes that are not an input of any other node).
for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) {
GraphNode* node = it->second.get();
const std::map<string, TFNode*>& inputs = node->node->inputs();
for (auto inputs_it = inputs.cbegin(); inputs_it != inputs.cend();
inputs_it++) {
nonroots.insert(inputs_it->first);
auto child_it = nodes_map_.find(inputs_it->first);
if (child_it != nodes_map_.end()) {
node->children.push_back(child_it->second.get());
}
}
}
for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) {
if (nonroots.find(it->first) == nonroots.end()) {
roots_.push_back(it->second.get());
}
}
}
const ShowNode* TFGraph::ShowInternal(const Options& opts) {
// Search the nodes to start from.
std::vector<GraphNode*> roots = roots_;
if (opts.start_name_regexes.size() != 1 ||
opts.start_name_regexes[0] != ".*") {
std::set<string> visited;
roots = SearchRoot(roots, opts.start_name_regexes, &visited);
}
GraphNode* root = CreateParentNode(kTFProfRoot);
root->children.assign(roots.begin(), roots.end());
std::map<string, int64> account_visits;
Account({root}, opts, &account_visits);
if (opts.viz) {
printf("Visualizing feature disabled...\n");
}
std::set<string> visits;
return PrintGraph({root}, opts, 1, 0, 0, &visits)[0];
}
std::vector<GraphNode*> TFGraph::SearchRoot(
const std::vector<GraphNode*>& roots, const std::vector<string>& regexes,
std::set<string>* visited) {
std::vector<GraphNode*> res;
if (roots.empty()) {
return res;
}
for (GraphNode* root : roots) {
if (visited->find(root->name()) != visited->end()) continue;
visited->insert(root->name());
// If the parent is a start point, don't search its children.
// Note that its children can still be added as start node through
// another route.
bool match_start_node = false;
for (const string& regex : regexes) {
if (RE2::FullMatch(root->name(), regex)) {
res.push_back(root);
match_start_node = true;
break;
}
}
if (match_start_node) {
continue;
}
std::vector<GraphNode*> nroot =
SearchRoot(root->children, regexes, visited);
res.insert(res.end(), nroot.begin(), nroot.end());
}
return res;
}
std::vector<GraphNode*> TFGraph::PrintGraph(const std::vector<GraphNode*> roots,
const Options& opts, int depth,
int hidden, int last_ident,
std::set<string>* visits) {
std::vector<GraphNode*> show_nodes;
for (GraphNode* node : roots) {
if (visits->find(node->name()) != visits->end()) continue;
visits->insert(node->name());
int nhidden = hidden;
int nlast_ident = last_ident;
bool show = ShouldShow(node, opts, depth);
if (show) {
node->formatted_str.clear();
if (opts.account_displayed_op_only) {
node->ResetTotalStats();
node->AddSelfToTotalStats();
}
nhidden = 0;
nlast_ident = (hidden && opts.select.find(kShown[4]) != opts.select.end()
? last_ident + 4
: last_ident + 2);
} else {
++nhidden;
}
std::vector<GraphNode*> show_cnodes;
if (!ShouldTrim(node, opts.trim_name_regexes)) {
show_cnodes = PrintGraph(node->children, opts, depth + 1, nhidden,
nlast_ident, visits);
}
if (show) {
show_cnodes = SortNodes(show_cnodes, opts);
string children_str;
for (GraphNode* sc : show_cnodes) {
children_str += sc->formatted_str;
node->mutable_proto()->add_children()->MergeFrom(sc->proto());
if (opts.account_displayed_op_only) {
node->AggregateTotalStats(sc);
}
}
if (hidden && opts.select.find(kShown[4]) != opts.select.end()) {
node->formatted_str = strings::Printf(
"%s...hidden %d...\n", string(last_ident, ' ').c_str(), hidden);
node->formatted_str +=
strings::Printf(" %s%s\n", string(last_ident, ' ').c_str(),
node->Format(opts).c_str());
} else {
node->formatted_str =
strings::Printf("%s%s\n", string(last_ident, ' ').c_str(),
node->Format(opts).c_str());
}
if (opts.select.find(kShown[5]) != opts.select.end()) {
std::unique_ptr<TFProfTensor> tfprof_tensor;
if (LookUpCheckPoint(node->name(), &tfprof_tensor)) {
string value_str;
tfprof_tensor->Display(&value_str,
node->mutable_proto()->mutable_tensor_value());
node->formatted_str += value_str;
}
}
node->formatted_str += children_str;
show_nodes.push_back(node);
} else {
show_nodes.insert(show_nodes.end(), show_cnodes.begin(),
show_cnodes.end());
}
}
return show_nodes;
}
void TFGraph::Account(const std::vector<GraphNode*>& roots, const Options& opts,
std::map<string, int64>* visits) {
if (roots.empty()) return;
for (GraphNode* node : roots) {
if (visits->find(node->name()) != visits->end()) continue;
(*visits)[node->name()] = 1;
node->ResetTotalStats();
    // Depth-first.
Account(node->children, opts, visits);
node->account = ShouldAccount(node, opts);
if (node->account) {
node->AddSelfToTotalStats();
}
// Aggregate its children stats.
for (GraphNode* c : node->children) {
// A node can be visited from multiple parents. Only account once.
// "visits==1" is when the node is visited through depth-first search.
(*visits)[c->name()] += 1;
if ((*visits)[c->name()] > 2) continue;
node->AggregateTotalStats(c);
}
}
}
} // namespace tfprof
} // namespace tensorflow
View File
@ -0,0 +1,116 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Build a graph structure based on op inputs/outputs. The graph is a directed
// acyclic graph pointing *from outputs to inputs*.
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_TFPROF_GRAPH_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_TFPROF_GRAPH_H_
#include <deque>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_node.h"
#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_options.h"
#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_show.h"
#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_utils.h"
#include "tensorflow/contrib/tfprof/tools/tfprof/tfprof_output.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
namespace tfprof {
class GraphNode : public ShowNode {
public:
explicit GraphNode(TFNode* node) : ShowNode(node) {
mutable_proto()->set_inputs(node->inputs().size());
mutable_proto()->set_total_inputs(0);
}
void AggregateTotalStats(GraphNode* node) {
ShowNode::AggregateTotalStats(node);
mutable_proto()->set_total_inputs(proto().total_inputs() +
node->proto().total_inputs() + 1);
}
void AddSelfToTotalStats() {
ShowNode::AddSelfToTotalStats();
mutable_proto()->set_total_inputs(proto().total_inputs() +
proto().inputs());
}
void ResetTotalStats() {
ShowNode::ResetTotalStats();
mutable_proto()->set_total_inputs(0);
}
std::vector<GraphNode*> children;
};
// Organize tensorflow ops in a graph structure, pointing from output ops
// to input ops.
class TFGraph : public TFShow {
public:
explicit TFGraph(checkpoint::CheckpointReader* ckpt_reader)
: TFShow(ckpt_reader) {}
~TFGraph() override {}
void AddNode(TFNode* node) override;
void Build() override;
private:
const ShowNode* ShowInternal(const Options& opts) override;
bool ShouldShowIfExtra(ShowNode* node, const Options& opts,
int depth) override {
return true;
}
GraphNode* CreateParentNode(const string& name);
std::vector<GraphNode*> SearchRoot(const std::vector<GraphNode*>& roots,
const std::vector<string>& regexes,
std::set<string>* visited);
std::vector<GraphNode*> PrintGraph(const std::vector<GraphNode*> roots,
const Options& opts, int depth, int hidden,
int last_ident, std::set<string>* visits);
void VisualizeGraph(GraphNode* root, const Options& opts);
std::vector<GraphNode*> GenerateGraphDot(
GraphNode* root, GraphNode* last_shown, const Options& opts, int depth,
int hidden, std::set<string>* declared_nodes,
std::set<string>* declared_edges, TFProfNode* parent);
void Account(const std::vector<GraphNode*>& roots, const Options& opts,
std::map<string, int64>* visits);
std::vector<GraphNode*> roots_;
std::vector<std::unique_ptr<NodeDef>> node_defs_;
std::map<string, std::unique_ptr<TFNode>> parent_nodes_;
std::map<string, std::unique_ptr<GraphNode>> nodes_map_;
};
} // namespace tfprof
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_TFPROF_GRAPH_H_
View File
@ -0,0 +1,47 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_node.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
namespace tensorflow {
namespace tfprof {
void TFNode::AddStepStat(const string& device, const NodeExecStats* step_stat) {
if (!device.empty()) {
// This might override device from GraphDef.
device_ = device;
}
step_stat_ = step_stat;
op_start_micros_ = step_stat_->all_start_micros();
if (step_stat_->op_end_rel_micros() && step_stat_->op_start_rel_micros()) {
op_exec_micros_ =
step_stat_->op_end_rel_micros() - step_stat_->op_start_rel_micros();
}
all_spent_micros_ = step_stat_->all_end_rel_micros();
for (const auto& output : step_stat_->output()) {
if (output.has_tensor_description() &&
output.tensor_description().has_allocation_description()) {
requested_bytes_ += output.tensor_description()
.allocation_description()
.requested_bytes();
}
}
}
} // namespace tfprof
} // namespace tensorflow
View File
@ -0,0 +1,106 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_TFPROF_NODE_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_TFPROF_NODE_H_
#include <map>
#include <set>
#include <string>
#include <vector>
#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_options.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
namespace tfprof {
class TFNode {
public:
TFNode(const NodeDef* node)
: node_(node),
step_stat_(nullptr),
op_start_micros_(0),
op_exec_micros_(0),
all_spent_micros_(0),
requested_bytes_(0),
float_ops_(0) {
if (!node) return;
for (const auto& attr : node->attr()) {
// TODO(xpan): Also consider _output_shapes.
if (attr.first != "shape" || !attr.second.has_shape()) continue;
if (!shape_.empty()) {
fprintf(stderr, "Found duplicated shapes!\n");
continue;
}
std::vector<int64> shape_vec;
for (const auto& d : attr.second.shape().dim()) {
shape_vec.push_back(d.size());
}
update_shape(shape_vec);
}
op_types_.insert(node->op());
device_ = node->device();
}
TFNode() : TFNode(nullptr) {}
void AddInput(TFNode* input) { inputs_[input->node_def()->name()] = input; }
void AddOpType(const string& op_type) { op_types_.insert(op_type); }
void AddStepStat(const string& device, const NodeExecStats* step_stat);
void AddFloatOps(int64 float_ops) { float_ops_ = float_ops; }
const NodeDef* node_def() { return node_; }
const std::map<string, TFNode*>& inputs() { return inputs_; }
int64 op_start_micros() { return op_start_micros_; }
int64 op_exec_micros() { return op_exec_micros_; }
int64 all_spent_micros() { return all_spent_micros_; }
int64 requested_byptes() { return requested_bytes_; }
int64 float_ops() { return float_ops_; }
string device() { return device_; }
const std::set<string>& op_types() { return op_types_; }
const std::vector<int64>& shape() { return shape_; }
void update_shape(const std::vector<int64>& shape) { shape_ = shape; }
private:
std::map<string, TFNode*> inputs_;
const NodeDef* node_;
const NodeExecStats* step_stat_;
std::vector<int64> shape_;
std::set<string> op_types_;
string device_;
int64 op_start_micros_;
int64 op_exec_micros_;
int64 all_spent_micros_;
int64 requested_bytes_;
int64 float_ops_;
};
} // namespace tfprof
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_TFPROF_NODE_H_
View File
@ -0,0 +1,57 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_options.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
namespace tensorflow {
namespace tfprof {
string Options::ToString() const {
const string s = strings::Printf(
"%-28s%d\n"
"%-28s%lld\n"
"%-28s%lld\n"
"%-28s%lld\n"
"%-28s%lld\n"
"%-28s%s\n"
"%-28s%s\n"
"%-28s%s\n"
"%-28s%s\n"
"%-28s%s\n"
"%-28s%s\n"
"%-28s%s\n"
"%-28s%s\n"
"%-28s%s\n"
"%-28s%s\n"
"%-28s%s\n",
kOptions[0], max_depth, kOptions[1], min_bytes, kOptions[2], min_micros,
kOptions[3], min_params, kOptions[4], min_float_ops, kOptions[5],
str_util::Join(device_regexes, ",").c_str(), kOptions[6],
order_by.c_str(), kOptions[7],
str_util::Join(account_type_regexes, ",").c_str(), kOptions[8],
str_util::Join(start_name_regexes, ",").c_str(), kOptions[9],
str_util::Join(trim_name_regexes, ",").c_str(), kOptions[10],
str_util::Join(show_name_regexes, ",").c_str(), kOptions[11],
str_util::Join(hide_name_regexes, ",").c_str(), kOptions[12],
(account_displayed_op_only ? "true" : "false"), kOptions[13],
str_util::Join(select, ",").c_str(), kOptions[14],
(viz ? "true" : "false"), kOptions[15], dump_to_file.c_str());
return s;
}
} // namespace tfprof
} // namespace tensorflow
View File
@ -0,0 +1,119 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_TFPROF_OPTIONS_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_TFPROF_OPTIONS_H_
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
namespace tfprof {
static const char* const kOptions[] = {
"-max_depth",
"-min_bytes",
"-min_micros",
"-min_params",
"-min_float_ops",
"-device_regexes",
"-order_by",
"-account_type_regexes",
"-start_name_regexes",
"-trim_name_regexes",
"-show_name_regexes",
"-hide_name_regexes",
"-account_displayed_op_only",
"-select",
"-viz",
"-dump_to_file",
};
static const char* const kOrderBy[] = {
"name", "bytes", "micros", "params", "float_ops",
};
// Append Only.
static const char* const kShown[] = {
"bytes", "micros", "params", "float_ops",
"num_hidden_ops", "tensor_value", "device", "op_types",
};
static const char* const kCmds[] = {
"scope", "graph", "set", "help",
};
struct Options {
public:
virtual ~Options() {}
Options(int max_depth, tensorflow::int64 min_bytes,
tensorflow::int64 min_micros, tensorflow::int64 min_params,
tensorflow::int64 min_float_ops,
const std::vector<string>& device_regexes, const string& order_by,
const std::vector<string>& account_type_regexes,
const std::vector<string>& start_name_regexes,
const std::vector<string>& trim_name_regexes,
const std::vector<string>& show_name_regexes,
const std::vector<string>& hide_name_regexes,
bool account_displayed_op_only, const std::vector<string>& select,
bool viz, const string& dump_to_file = "")
: max_depth(max_depth),
min_bytes(min_bytes),
min_micros(min_micros),
min_params(min_params),
min_float_ops(min_float_ops),
device_regexes(device_regexes),
order_by(order_by),
account_type_regexes(account_type_regexes),
start_name_regexes(start_name_regexes),
trim_name_regexes(trim_name_regexes),
show_name_regexes(show_name_regexes),
hide_name_regexes(hide_name_regexes),
account_displayed_op_only(account_displayed_op_only),
select(select.begin(), select.end()),
viz(viz),
dump_to_file(dump_to_file) {}
string ToString() const;
int max_depth;
tensorflow::int64 min_bytes;
tensorflow::int64 min_micros;
tensorflow::int64 min_params;
tensorflow::int64 min_float_ops;
std::vector<string> device_regexes;
string order_by;
std::vector<string> account_type_regexes;
std::vector<string> start_name_regexes;
std::vector<string> trim_name_regexes;
std::vector<string> show_name_regexes;
std::vector<string> hide_name_regexes;
bool account_displayed_op_only;
std::set<string> select;
bool viz;
string dump_to_file;
};
} // namespace tfprof
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TFPROF_TOOLS_TFPROF_INTERNAL_TFPROF_OPTIONS_H_
View File
@ -0,0 +1,191 @@
/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_scope.h"
#include <stdio.h>
#include <utility>
#include "tensorflow/c/c_api.h"
#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_constants.h"
#include "tensorflow/contrib/tfprof/tools/tfprof/internal/tfprof_tensor.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/regexp.h"
namespace tensorflow {
namespace tfprof {
ScopeNode* TFScope::CreateParentNode(const string& name) {
if (nodes_map_.find(name) != nodes_map_.end()) {
return nodes_map_[name].get();
}
node_defs_.push_back(std::unique_ptr<NodeDef>(new NodeDef()));
node_defs_.back()->set_name(name);
node_defs_.back()->set_op(kTFScopeParent);
parent_nodes_[name] =
std::unique_ptr<TFNode>(new TFNode(node_defs_.back().get()));
nodes_map_[name] =
std::unique_ptr<ScopeNode>(new ScopeNode(parent_nodes_[name].get()));
return nodes_map_[name].get();
}
void TFScope::AddNode(TFNode* node) {
string name = node->node_def()->name();
if (nodes_map_.find(node->node_def()->name()) == nodes_map_.end()) {
nodes_map_[name] = std::unique_ptr<ScopeNode>(new ScopeNode(node));
}
auto last_slash = name.find_last_of("/");
while (last_slash != name.npos) {
name = name.substr(0, last_slash);
if (nodes_map_.find(name) == nodes_map_.end()) {
CHECK(CreateParentNode(name));
}
last_slash = name.find_last_of("/");
}
}
void TFScope::Build() {
if (!roots_.empty()) return;
  // Find the roots, which are nodes whose names contain no "/".
for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) {
ScopeNode* node = it->second.get();
auto last_slash = node->name().find_last_of("/");
if (last_slash == string::npos) {
roots_.push_back(node);
} else {
const string prefix = node->name().substr(0, last_slash);
nodes_map_[prefix]->children.push_back(node);
}
}
}
const ShowNode* TFScope::ShowInternal(const Options& opts) {
// Search from roots recursively to find start node, if start_name_regexes
// is specified.
std::vector<ScopeNode*> roots = roots_;
if (opts.start_name_regexes.size() != 1 ||
opts.start_name_regexes[0] != ".*") {
roots = SearchRoot(roots, opts.start_name_regexes);
}
ScopeNode* root = CreateParentNode(kTFProfRoot);
root->children.assign(roots.begin(), roots.end());
Account({root}, opts);
root = PrintScope({root}, opts, 1, 0)[0];
return root;
}
std::vector<ScopeNode*> TFScope::SearchRoot(
std::vector<ScopeNode*> roots, const std::vector<string>& regexes) {
std::vector<ScopeNode*> res;
if (roots.empty()) {
return res;
}
for (ScopeNode* root : roots) {
bool match_start_node = false;
for (const string& regex : regexes) {
if (RE2::FullMatch(root->name(), regex)) {
res.push_back(root);
match_start_node = true;
break;
}
}
if (match_start_node) {
// Found a start node at this branch, no need to continue.
continue;
}
std::vector<ScopeNode*> nroots = SearchRoot(root->children, regexes);
res.insert(res.end(), nroots.begin(), nroots.end());
}
return res;
}
std::vector<ScopeNode*> TFScope::PrintScope(const std::vector<ScopeNode*> roots,
const Options& opts, int depth,
int last_ident) {
std::vector<ScopeNode*> show_nodes;
for (ScopeNode* node : roots) {
int nlast_ident = last_ident;
bool show = ShouldShow(node, opts, depth);
if (show) {
node->formatted_str.clear();
if (opts.account_displayed_op_only) {
node->ResetTotalStats();
node->AddSelfToTotalStats();
}
nlast_ident += 2;
}
std::vector<ScopeNode*> show_cnodes;
if (!ShouldTrim(node, opts.trim_name_regexes)) {
show_cnodes = PrintScope(node->children, opts, depth + 1, nlast_ident);
}
if (show) {
show_cnodes = SortNodes(show_cnodes, opts);
string children_str;
for (ScopeNode* sc : show_cnodes) {
children_str += sc->formatted_str;
node->mutable_proto()->add_children()->MergeFrom(sc->proto());
if (opts.account_displayed_op_only) {
node->AggregateTotalStats(sc);
}
}
node->formatted_str =
strings::Printf("%s%s\n", string(last_ident, ' ').c_str(),
node->Format(opts).c_str());
if (opts.select.find(kShown[5]) != opts.select.end()) {
std::unique_ptr<TFProfTensor> tfprof_tensor;
if (LookUpCheckPoint(node->name(), &tfprof_tensor)) {
string value_str;
tfprof_tensor->Display(&value_str,
node->mutable_proto()->mutable_tensor_value());
node->formatted_str += value_str;
}
}
node->formatted_str += children_str;
show_nodes.push_back(node);
} else {
show_nodes.insert(show_nodes.end(), show_cnodes.begin(),
show_cnodes.end());
}
}
return show_nodes;
}
void TFScope::Account(const std::vector<ScopeNode*>& roots,
const Options& opts) {
if (roots.empty()) return;
for (ScopeNode* node : roots) {
node->ResetTotalStats();
Account(node->children, opts);
node->account = ShouldAccount(node, opts);
if (node->account) {
node->AddSelfToTotalStats();
}
for (ScopeNode* c : node->children) {
node->AggregateTotalStats(c);
}
}
}
} // namespace tfprof
} // namespace tensorflow
Some files were not shown because too many files have changed in this diff.