Merge changes from github.

PiperOrigin-RevId: 202585094
This commit is contained in:
Mingxing Tan 2018-06-28 19:13:20 -07:00 committed by TensorFlower Gardener
parent 3cee10e61c
commit 1e7b0e4ad6
145 changed files with 6298 additions and 1705 deletions

.gitignore vendored
View File

@ -16,6 +16,7 @@ __pycache__
cmake_build/ cmake_build/
.idea/** .idea/**
/build/ /build/
[Bb]uild/
/tensorflow/core/util/version_info.cc /tensorflow/core/util/version_info.cc
/tensorflow/python/framework/fast_tensor_util.cpp /tensorflow/python/framework/fast_tensor_util.cpp
Pods Pods

View File

@ -943,6 +943,35 @@ def set_tf_cudnn_version(environ_cp):
write_action_env_to_bazelrc('TF_CUDNN_VERSION', tf_cudnn_version) write_action_env_to_bazelrc('TF_CUDNN_VERSION', tf_cudnn_version)
def is_cuda_compatible(lib, cuda_ver, cudnn_ver):
"""Check compatibility between given library and cudnn/cudart libraries."""
ldd_bin = which('ldd') or '/usr/bin/ldd'
ldd_out = run_shell([ldd_bin, lib], True)
ldd_out = ldd_out.split(os.linesep)
cudnn_pattern = re.compile('.*libcudnn.so\\.?(.*) =>.*$')
cuda_pattern = re.compile('.*libcudart.so\\.?(.*) =>.*$')
cudnn = None
cudart = None
cudnn_ok = True # assume no cudnn dependency by default
cuda_ok = True # assume no cuda dependency by default
for line in ldd_out:
if 'libcudnn.so' in line:
cudnn = cudnn_pattern.search(line)
cudnn_ok = False
elif 'libcudart.so' in line:
cudart = cuda_pattern.search(line)
cuda_ok = False
if cudnn and len(cudnn.group(1)):
cudnn = convert_version_to_int(cudnn.group(1))
if cudart and len(cudart.group(1)):
cudart = convert_version_to_int(cudart.group(1))
if cudnn is not None:
cudnn_ok = (cudnn == cudnn_ver)
if cudart is not None:
cuda_ok = (cudart == cuda_ver)
return cudnn_ok and cuda_ok
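As an aside, a minimal, hypothetical sketch (not part of this commit) of how the two regular expressions above pull version strings out of typical `ldd` output; the sample lines and versions below are made up for illustration:

```python
import re

# Hypothetical ldd output lines for a library that links cuDNN 7 and CUDA 9.0.
sample_ldd_lines = [
    "\tlibcudnn.so.7 => /usr/lib/x86_64-linux-gnu/libcudnn.so.7 (0x00007f...)",
    "\tlibcudart.so.9.0 => /usr/local/cuda/lib64/libcudart.so.9.0 (0x00007f...)",
]
cudnn_pattern = re.compile('.*libcudnn.so\\.?(.*) =>.*$')
cuda_pattern = re.compile('.*libcudart.so\\.?(.*) =>.*$')
for line in sample_ldd_lines:
    match = cudnn_pattern.search(line) or cuda_pattern.search(line)
    if match:
        # Prints "7" for the cuDNN line and "9.0" for the CUDA runtime line;
        # is_cuda_compatible() then converts these to ints and compares them
        # against TF_CUDNN_VERSION / TF_CUDA_VERSION.
        print(match.group(1))
```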
def set_tf_tensorrt_install_path(environ_cp): def set_tf_tensorrt_install_path(environ_cp):
"""Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION. """Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION.
@ -959,8 +988,8 @@ def set_tf_tensorrt_install_path(environ_cp):
raise ValueError('Currently TensorRT is only supported on Linux platform.') raise ValueError('Currently TensorRT is only supported on Linux platform.')
# Ask user whether to add TensorRT support. # Ask user whether to add TensorRT support.
if str(int(get_var( if str(int(get_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT',
environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False))) != '1': False))) != '1':
return return
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
@ -973,47 +1002,29 @@ def set_tf_tensorrt_install_path(environ_cp):
# Result returned from "read" will be used unexpanded. That make "~" # Result returned from "read" will be used unexpanded. That make "~"
# unusable. Going through one more level of expansion to handle that. # unusable. Going through one more level of expansion to handle that.
trt_install_path = os.path.realpath( trt_install_path = os.path.realpath(os.path.expanduser(trt_install_path))
os.path.expanduser(trt_install_path))
def find_libs(search_path): def find_libs(search_path):
"""Search for libnvinfer.so in "search_path".""" """Search for libnvinfer.so in "search_path"."""
fl = set() fl = set()
if os.path.exists(search_path) and os.path.isdir(search_path): if os.path.exists(search_path) and os.path.isdir(search_path):
fl.update([os.path.realpath(os.path.join(search_path, x)) fl.update([
for x in os.listdir(search_path) if 'libnvinfer.so' in x]) os.path.realpath(os.path.join(search_path, x))
for x in os.listdir(search_path)
if 'libnvinfer.so' in x
])
return fl return fl
possible_files = find_libs(trt_install_path) possible_files = find_libs(trt_install_path)
possible_files.update(find_libs(os.path.join(trt_install_path, 'lib'))) possible_files.update(find_libs(os.path.join(trt_install_path, 'lib')))
possible_files.update(find_libs(os.path.join(trt_install_path, 'lib64'))) possible_files.update(find_libs(os.path.join(trt_install_path, 'lib64')))
def is_compatible(tensorrt_lib, cuda_ver, cudnn_ver):
"""Check the compatibility between tensorrt and cudnn/cudart libraries."""
ldd_bin = which('ldd') or '/usr/bin/ldd'
ldd_out = run_shell([ldd_bin, tensorrt_lib]).split(os.linesep)
cudnn_pattern = re.compile('.*libcudnn.so\\.?(.*) =>.*$')
cuda_pattern = re.compile('.*libcudart.so\\.?(.*) =>.*$')
cudnn = None
cudart = None
for line in ldd_out:
if 'libcudnn.so' in line:
cudnn = cudnn_pattern.search(line)
elif 'libcudart.so' in line:
cudart = cuda_pattern.search(line)
if cudnn and len(cudnn.group(1)):
cudnn = convert_version_to_int(cudnn.group(1))
if cudart and len(cudart.group(1)):
cudart = convert_version_to_int(cudart.group(1))
return (cudnn == cudnn_ver) and (cudart == cuda_ver)
cuda_ver = convert_version_to_int(environ_cp['TF_CUDA_VERSION']) cuda_ver = convert_version_to_int(environ_cp['TF_CUDA_VERSION'])
cudnn_ver = convert_version_to_int(environ_cp['TF_CUDNN_VERSION']) cudnn_ver = convert_version_to_int(environ_cp['TF_CUDNN_VERSION'])
nvinfer_pattern = re.compile('.*libnvinfer.so.?(.*)$') nvinfer_pattern = re.compile('.*libnvinfer.so.?(.*)$')
highest_ver = [0, None, None] highest_ver = [0, None, None]
for lib_file in possible_files: for lib_file in possible_files:
if is_compatible(lib_file, cuda_ver, cudnn_ver): if is_cuda_compatible(lib_file, cuda_ver, cudnn_ver):
matches = nvinfer_pattern.search(lib_file) matches = nvinfer_pattern.search(lib_file)
if len(matches.groups()) == 0: if len(matches.groups()) == 0:
continue continue
@ -1029,12 +1040,13 @@ def set_tf_tensorrt_install_path(environ_cp):
# Try another alternative from ldconfig. # Try another alternative from ldconfig.
ldconfig_bin = which('ldconfig') or '/sbin/ldconfig' ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
ldconfig_output = run_shell([ldconfig_bin, '-p']) ldconfig_output = run_shell([ldconfig_bin, '-p'])
search_result = re.search( search_result = re.search('.*libnvinfer.so\\.?([0-9.]*).* => (.*)',
'.*libnvinfer.so\\.?([0-9.]*).* => (.*)', ldconfig_output) ldconfig_output)
if search_result: if search_result:
libnvinfer_path_from_ldconfig = search_result.group(2) libnvinfer_path_from_ldconfig = search_result.group(2)
if os.path.exists(libnvinfer_path_from_ldconfig): if os.path.exists(libnvinfer_path_from_ldconfig):
if is_compatible(libnvinfer_path_from_ldconfig, cuda_ver, cudnn_ver): if is_cuda_compatible(libnvinfer_path_from_ldconfig, cuda_ver,
cudnn_ver):
trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig) trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig)
tf_tensorrt_version = search_result.group(1) tf_tensorrt_version = search_result.group(1)
break break
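For illustration only, a small sketch (not part of the commit) of what the `ldconfig -p` fallback above extracts; the sample entry is hypothetical but follows the usual `ldconfig -p` format:

```python
import re

# Hypothetical entry from `ldconfig -p` output.
sample = "\tlibnvinfer.so.4 (libc6,x86-64) => /usr/lib/x86_64-linux-gnu/libnvinfer.so.4"
match = re.search('.*libnvinfer.so\\.?([0-9.]*).* => (.*)', sample)
if match:
    tf_tensorrt_version = match.group(1)  # "4"
    libnvinfer_path = match.group(2)      # resolved path to libnvinfer.so.4
    print(tf_tensorrt_version, libnvinfer_path)
```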

View File

@ -154,6 +154,12 @@ config_setting(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
) )
config_setting(
name = "linux_s390x",
values = {"cpu": "s390x"},
visibility = ["//visibility:public"],
)
config_setting( config_setting(
name = "debug", name = "debug",
values = { values = {
@ -459,6 +465,15 @@ filegroup(
tf_cc_shared_object( tf_cc_shared_object(
name = "libtensorflow_framework.so", name = "libtensorflow_framework.so",
framework_so = [], framework_so = [],
linkopts = select({
"//tensorflow:darwin": [],
"//tensorflow:windows": [],
"//tensorflow:windows_msvc": [],
"//conditions:default": [
"-Wl,--version-script", # This line must be directly followed by the version_script.lds file
"$(location //tensorflow:tf_framework_version_script.lds)",
],
}),
linkstatic = 1, linkstatic = 1,
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
@ -468,6 +483,7 @@ tf_cc_shared_object(
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
"//tensorflow/core:lib_internal_impl", "//tensorflow/core:lib_internal_impl",
"//tensorflow/stream_executor:stream_executor_impl", "//tensorflow/stream_executor:stream_executor_impl",
"//tensorflow:tf_framework_version_script.lds",
] + tf_additional_binary_deps(), ] + tf_additional_binary_deps(),
) )
@ -571,3 +587,13 @@ py_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = ["//tensorflow/python:no_contrib"], deps = ["//tensorflow/python:no_contrib"],
) )
cc_library(
name = "grpc",
deps = ["@grpc"],
)
cc_library(
name = "grpc++",
deps = ["@grpc//:grpc++"],
)

View File

@ -2068,7 +2068,8 @@ TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
TF_Graph* graph, const TF_Buffer* graph_def, TF_Graph* graph, const TF_Buffer* graph_def,
const TF_ImportGraphDefOptions* options, TF_Status* status) { const TF_ImportGraphDefOptions* options, TF_Status* status) {
GraphDef def; GraphDef def;
if (!def.ParseFromArray(graph_def->data, graph_def->length)) { if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
graph_def->length)) {
status->status = InvalidArgument("Invalid GraphDef"); status->status = InvalidArgument("Invalid GraphDef");
return nullptr; return nullptr;
} }
@ -2098,7 +2099,8 @@ void TF_GraphImportGraphDefWithReturnOutputs(
return; return;
} }
GraphDef def; GraphDef def;
if (!def.ParseFromArray(graph_def->data, graph_def->length)) { if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
graph_def->length)) {
status->status = InvalidArgument("Invalid GraphDef"); status->status = InvalidArgument("Invalid GraphDef");
return; return;
} }

View File

@ -287,7 +287,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config)); TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config));
const int64 result_index = compile_result.aot->result_buffer_index(); const int64 result_index = compile_result.aot->result_buffer_index();
const xla::BufferSizes& temp_sizes = compile_result.aot->buffer_sizes(); const xla::BufferSizes& temp_sizes = compile_result.aot->buffer_sizes();
if (result_index < 0 || result_index > temp_sizes.size()) { if (result_index < 0 || result_index >= temp_sizes.size()) {
return errors::InvalidArgument("result index: ", result_index, return errors::InvalidArgument("result index: ", result_index,
" is outside the range of temp sizes: [0,", " is outside the range of temp sizes: [0,",
temp_sizes.size(), ")"); temp_sizes.size(), ")");
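The change above tightens an off-by-one in the bounds check: for a buffer list of size n, valid result indices are 0..n-1, so an index equal to n must be rejected. A tiny sketch of the same rule (illustrative only, not the tfcompile code):

```python
def result_index_ok(result_index, temp_sizes):
    # Valid indices lie in the half-open range [0, len(temp_sizes)).
    return 0 <= result_index < len(temp_sizes)

assert result_index_ok(2, [16, 32, 64])        # last valid index
assert not result_index_ok(3, [16, 32, 64])    # equal to the size: now rejected
```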

View File

@ -39,10 +39,10 @@ tf_cc_binary(
srcs = ["grpc_service_main.cc"], srcs = ["grpc_service_main.cc"],
deps = [ deps = [
":grpc_service", ":grpc_service",
"//tensorflow:grpc++",
"//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"@grpc//:grpc++",
], ],
) )
@ -54,6 +54,7 @@ tf_cc_test(
], ],
deps = [ deps = [
":grpc_stub", ":grpc_stub",
"//tensorflow:grpc++",
"//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:literal_test_util",
@ -61,7 +62,6 @@ tf_cc_test(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"@grpc//:grpc++",
], ],
) )
@ -71,9 +71,9 @@ cc_library(
hdrs = ["grpc_service.h"], hdrs = ["grpc_service.h"],
deps = [ deps = [
":xla_service_proto", ":xla_service_proto",
"//tensorflow:grpc++",
"//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service",
"//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/distributed_runtime/rpc:grpc_util",
"@grpc//:grpc++",
], ],
) )

View File

@ -2550,7 +2550,6 @@ cc_library(
name = "hlo_tfgraph_builder", name = "hlo_tfgraph_builder",
srcs = ["hlo_tfgraph_builder.cc"], srcs = ["hlo_tfgraph_builder.cc"],
hdrs = ["hlo_tfgraph_builder.h"], hdrs = ["hlo_tfgraph_builder.h"],
visibility = ["//tensorflow/compiler/xla/tools:__pkg__"],
deps = [ deps = [
":hlo", ":hlo",
"//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:literal_util",

View File

@ -1515,6 +1515,7 @@ bool HloInstruction::IdenticalSlowPath(
// Remaining instructions with special values. // Remaining instructions with special values.
case HloOpcode::kCall: case HloOpcode::kCall:
return eq_computations(to_apply(), other.to_apply());
case HloOpcode::kConditional: case HloOpcode::kConditional:
return eq_computations(true_computation(), other.true_computation()) && return eq_computations(true_computation(), other.true_computation()) &&
eq_computations(false_computation(), other.false_computation()); eq_computations(false_computation(), other.false_computation());

View File

@ -924,6 +924,40 @@ TEST_F(HloInstructionTest, IdenticalInstructions) {
*HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op1, op2))); *HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op1, op2)));
} }
TEST_F(HloInstructionTest, IdenticalCallInstructions) {
const char* const hlo_string = R"(
HloModule Module
subcomp1 (x: f32[]) -> f32[] {
x = f32[] parameter(0)
ROOT n = f32[] sine(x)
}
subcomp2 (x: f32[]) -> f32[] {
x = f32[] parameter(0)
ROOT n = f32[] cosine(x)
}
ENTRY entry (param: f32[]) -> (f32[], f32[], f32[]) {
p = f32[] parameter(0)
t1 = f32[] call(p), to_apply=subcomp1
t2 = f32[] call(p), to_apply=subcomp1
t3 = f32[] call(p), to_apply=subcomp2
ROOT t = (f32[], f32[], f32[]) tuple(t1, t2, t3)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseHloString(hlo_string));
auto* root = module->entry_computation()->root_instruction();
auto* t1 = root->operand(0);
auto* t2 = root->operand(1);
auto* t3 = root->operand(2);
EXPECT_TRUE(StructuralEqual(*t1, *t2));
EXPECT_FALSE(StructuralEqual(*t1, *t3));
}
TEST_F(HloInstructionTest, FunctionVisitor) { TEST_F(HloInstructionTest, FunctionVisitor) {
// Verify the function visitor HloInstruction::Accept visits all instructions // Verify the function visitor HloInstruction::Accept visits all instructions
// from a root properly given the following graph: // from a root properly given the following graph:

View File

@ -120,7 +120,10 @@ py_test(
name = "decorators_test", name = "decorators_test",
srcs = ["decorators_test.py"], srcs = ["decorators_test.py"],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
tags = ["no_windows"], tags = [
"no_pip",
"no_windows",
],
deps = [ deps = [
":converters", ":converters",
"//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/contrib/autograph/core:test_lib",

View File

@ -51,7 +51,7 @@ def for_stmt(iter_, extra_test, body, init_state):
Args: Args:
iter_: The entity being iterated over. iter_: The entity being iterated over.
extra_test: Callable with the state as arguments, and boolean return type. extra_test: Callable with the state as arguments, and boolean return type.
An additionnal loop condition. An additional loop condition.
body: Callable with the iterate and the state as arguments, and body: Callable with the iterate and the state as arguments, and
state as return type. The actual loop body. state as return type. The actual loop body.
init_state: Tuple containing the initial state. init_state: Tuple containing the initial state.

View File

@ -286,7 +286,7 @@ class Forward(object):
# TODO(alexbw): see if we can simplify by visiting breadth-first # TODO(alexbw): see if we can simplify by visiting breadth-first
def visit(self, node): def visit(self, node):
"""Depth-first walking the CFG, applying dataflow information propagtion.""" """Depth-first walking the CFG, applying dataflow info propagation."""
# node.value is None only for the exit CfgNode. # node.value is None only for the exit CfgNode.
if not node.value: if not node.value:
return return

View File

@ -218,7 +218,7 @@ class Base(gast.NodeTransformer):
# TODO(mdan): Once we have error tracing, we may be able to just go to SSA. # TODO(mdan): Once we have error tracing, we may be able to just go to SSA.
def apply_to_single_assignments(self, targets, values, apply_fn): def apply_to_single_assignments(self, targets, values, apply_fn):
"""Applies a fuction to each individual assignment. """Applies a function to each individual assignment.
This function can process a possibly-unpacked (e.g. a, b = c, d) assignment. This function can process a possibly-unpacked (e.g. a, b = c, d) assignment.
It tries to break down the unpacking if possible. In effect, it has the same It tries to break down the unpacking if possible. In effect, it has the same
@ -246,7 +246,7 @@ class Base(gast.NodeTransformer):
targets field of an ast.Assign node. targets field of an ast.Assign node.
values: an AST node. values: an AST node.
apply_fn: a function of a single argument, which will be called with the apply_fn: a function of a single argument, which will be called with the
respective nodes of each single assignment. The signaure is respective nodes of each single assignment. The signature is
apply_fn(target, value), no return value. apply_fn(target, value), no return value.
""" """
if not isinstance(targets, (list, tuple)): if not isinstance(targets, (list, tuple)):

View File

@ -336,40 +336,14 @@ endif()
# MKL Support # MKL Support
if (tensorflow_ENABLE_MKL_SUPPORT) if (tensorflow_ENABLE_MKL_SUPPORT)
add_definitions(-DINTEL_MKL -DEIGEN_USE_VML) add_definitions(-DINTEL_MKL -DEIGEN_USE_VML)
include(mkl)
list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkl_STATIC_LIBRARIES})
list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkl_copy_shared_to_destination)
include_directories(${mkl_INCLUDE_DIRS})
if (WIN32)
find_path(MKL_HOME_PLATFORM mkl
PATHS ${MKL_HOME} ${MKL_HOME}/../ ${MKL_HOME}/../../
$ENV{MKLROOT} $ENV{MKLROOT}/../ $ENV{MKLROOT}/../../
PATH_SUFFIXES windows)
set(MKL_INCLUDE_DIRS ${MKL_HOME_PLATFORM}/mkl/include)
set(MKL_LINK_DIRS
${MKL_HOME_PLATFORM}/mkl/lib/intel64
${MKL_HOME_PLATFORM}/tbb/lib/intel64/vc_mt
${MKL_HOME_PLATFORM}/compiler/lib/intel64
${MKL_HOME_PLATFORM}/mkl/tools/builder/lib)
set(MKL_REDIST_DLL_DIRS
${MKL_HOME_PLATFORM}/redist/intel64/mkl
${MKL_HOME_PLATFORM}/redist/intel64/tbb/vc_mt
${MKL_HOME_PLATFORM}/redist/intel64/compiler)
list(APPEND tensorflow_EXTERNAL_LIBRARIES
mkl_intel_lp64_dll mkl_sequential_dll mkl_core_dll mkl_rt mkl_cdll_intel64)
endif()
if (UNIX)
# Fix me: complete the path on linux
find_path(MKL_HOME_PLATFORM mkl
HINTS ${MKL_HOME} ${MKL_HOME}/../ ${MKL_HOME}/../../
$ENV{MKLROOT} $ENV{MKLROOT}/../ $ENV{MKLROOT}/../../
PATH_SUFFIXES linux)
set(MKL_INCLUDE_DIRS ${MKL_HOME_PLATFORM}/mkl/include)
set(MKL_LINK_DIRS) # incompleted
set(MKL_REDIST_SO_DIRS) # incompleted
endif()
include_directories(${MKL_INCLUDE_DIRS})
link_directories(${MKL_LINK_DIRS})
if (tensorflow_ENABLE_MKLDNN_SUPPORT) if (tensorflow_ENABLE_MKLDNN_SUPPORT)
include(mkldnn) include(mkldnn)
list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkldnn_STATIC_LIBRARIES}) list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkldnn_STATIC_LIBRARIES})
list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkldnn) list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkldnn_copy_shared_to_destination)
include_directories(${mkldnn_INCLUDE_DIRS}) include_directories(${mkldnn_INCLUDE_DIRS})
else (tensorflow_ENABLE_MKLDNN_SUPPORT) else (tensorflow_ENABLE_MKLDNN_SUPPORT)
add_definitions(-DINTEL_MKL_ML) add_definitions(-DINTEL_MKL_ML)

View File

@ -16,15 +16,15 @@ include (ExternalProject)
set(double_conversion_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/double_conversion/src/double_conversion) set(double_conversion_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/double_conversion/src/double_conversion)
set(double_conversion_URL https://github.com/google/double-conversion.git) set(double_conversion_URL https://github.com/google/double-conversion.git)
set(double_conversion_TAG 5664746) set(double_conversion_TAG 3992066a95b823efc8ccc1baf82a1cfc73f6e9b8)
set(double_conversion_BUILD ${double_conversion_INCLUDE_DIR}) set(double_conversion_BUILD ${double_conversion_INCLUDE_DIR})
set(double_conversion_LIBRARIES ${double_conversion_BUILD}/double-conversion/libdouble-conversion.so) set(double_conversion_LIBRARIES ${double_conversion_BUILD}/double-conversion/libdouble-conversion.so)
set(double_conversion_INCLUDES ${double_conversion_BUILD}) set(double_conversion_INCLUDES ${double_conversion_BUILD})
if(WIN32) if(WIN32)
set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/double-conversion/$(Configuration)/double-conversion.lib) set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/$(Configuration)/double-conversion.lib)
else() else()
set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/double-conversion/libdouble-conversion.a) set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/libdouble-conversion.a)
endif() endif()
set(double_conversion_HEADERS set(double_conversion_HEADERS

View File

@ -0,0 +1,68 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
include (ExternalProject)
# NOTE: Different from mkldnn.cmake, this file is meant to download mkl libraries
set(mkl_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/include)
set(mkl_BIN_DIRS ${CMAKE_CURRENT_BINARY_DIR}/mkl/bin)
set(mkl_WIN mklml_win_2018.0.3.20180406.zip) # match for v0.14
set(mkl_MAC mklml_mac_2018.0.3.20180406.tgz)
set(mkl_LNX mklml_lnx_2018.0.3.20180406.tgz)
set(mkl_TAG v0.14)
set(mkl_URL https://github.com/intel/mkl-dnn/releases)
if (WIN32)
set(mkl_DOWNLOAD_URL ${mkl_URL}/download/${mkl_TAG}/${mkl_WIN})
list(APPEND mkl_STATIC_LIBRARIES
${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/mklml.lib)
list(APPEND mkl_STATIC_LIBRARIES
${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libiomp5md.lib)
list(APPEND mkl_SHARED_LIBRARIES
${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/mklml.dll)
list(APPEND mkl_SHARED_LIBRARIES
${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libiomp5md.dll)
elseif (UNIX)
set(mkl_DOWNLOAD_URL ${mkl_URL}/download/${mkl_TAG}/${mkl_LNX})
list(APPEND mkl_SHARED_LIBRARIES
${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libiomp5.so)
list(APPEND mkl_SHARED_LIBRARIES
${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libmklml_gnu.so)
list(APPEND mkl_SHARED_LIBRARIES
${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libmklml_intel.so)
elseif (APPLE)
set(mkl_DOWNLOAD_URL ${mkl_URL}/download/${mkl_TAG}/${mkl_MAC})
#TODO need more information
endif ()
ExternalProject_Add(mkl
PREFIX mkl
URL ${mkl_DOWNLOAD_URL}
DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND "")
# put mkl dynamic libraries in one bin directory
add_custom_target(mkl_create_destination_dir
COMMAND ${CMAKE_COMMAND} -E make_directory ${mkl_BIN_DIRS}
DEPENDS mkl)
add_custom_target(mkl_copy_shared_to_destination DEPENDS mkl_create_destination_dir)
foreach(dll_file ${mkl_SHARED_LIBRARIES})
add_custom_command(TARGET mkl_copy_shared_to_destination PRE_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${dll_file} ${mkl_BIN_DIRS})
endforeach()

View File

@ -22,8 +22,11 @@ set(mkldnn_TAG 3063b2e4c943983f6bf5f2fb9a490d4a998cd291)
if(WIN32) if(WIN32)
if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*")
set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/Release/mkldnn.lib) set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/Release/mkldnn.lib)
set(mkldnn_SHARED_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/Release/mkldnn.dll)
set(mkldnn_BUILD ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/Release)
else() else()
set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/mkldnn.lib) set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/mkldnn.lib)
set(mkldnn_SHARED_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/mkldnn.dll)
endif() endif()
else() else()
set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/libmkldnn.a) set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/libmkldnn.a)
@ -31,6 +34,7 @@ endif()
ExternalProject_Add(mkldnn ExternalProject_Add(mkldnn
PREFIX mkldnn PREFIX mkldnn
DEPENDS mkl
GIT_REPOSITORY ${mkldnn_URL} GIT_REPOSITORY ${mkldnn_URL}
GIT_TAG ${mkldnn_TAG} GIT_TAG ${mkldnn_TAG}
DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
@ -40,5 +44,11 @@ ExternalProject_Add(mkldnn
CMAKE_CACHE_ARGS CMAKE_CACHE_ARGS
-DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_BUILD_TYPE:STRING=Release
-DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF
-DMKLINC:STRING=${MKL_INCLUDE_DIRS} -DMKLINC:STRING=${mkl_INCLUDE_DIRS}
) )
# since mkldnn depends on mkl, copy the mkldnn.dll together with mklml.dll to mkl_bin_dirs
add_custom_target(mkldnn_copy_shared_to_destination DEPENDS mkldnn)
add_custom_command(TARGET mkldnn_copy_shared_to_destination PRE_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${mkldnn_SHARED_LIBRARIES} ${mkl_BIN_DIRS})

View File

@ -755,26 +755,65 @@ set(api_init_list_file "${tensorflow_source_dir}/api_init_files_list.txt")
file(WRITE "${api_init_list_file}" "${api_init_files}") file(WRITE "${api_init_list_file}" "${api_init_files}")
# Run create_python_api.py to generate __init__.py files. # Run create_python_api.py to generate __init__.py files.
add_custom_command(
OUTPUT ${api_init_files}
DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops
# tensorflow/__init__.py depends on files generated in this step. So, remove it while
# this step is running since the files aren't there yet.
COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py
# Run create_python_api.py to generate API init files.
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE}
"${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py"
"--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py"
"--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow"
"--package=tensorflow.python"
"--apiname=tensorflow"
"${api_init_list_file}"
COMMENT "Generating __init__.py files for Python API."
WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python"
)
### TODO
# In order to download and compile MKL/MKL-DNN automatically in the cmake script, the MKL-built libraries should be added to the system path
# so they can be loaded by the python executor. However, `add_custom_command` has an issue with `COMMAND ${CMAKE_COMMAND} -E env PATH=`, where
# arguments containing multiple paths (such as D:/;D:/mkl) are parsed into separate strings without the semicolon, so the command fails to
# recognize the paths. As CUDA isn't built with MKL, the MKL build directory is the only path passed to this command to work around that issue.
# To avoid overriding the CUDA and system paths in other circumstances, an `if-else` branch is used here to handle this problem;
# it should be removed once the path issue is resolved.
###
if (tensorflow_ENABLE_MKL_SUPPORT)
# add mkl dist dlls to system path for python
# TODO: In the current cmake version, PY_RUNTIME_ENV behaves strangely with multiple paths,
# so we have to specify only one path in it to work around the issue. We need this if/else
# to avoid overwriting the CUDA environment.
set(PY_RUNTIME_ENV ${mkl_BIN_DIRS})
add_custom_command(
OUTPUT ${api_init_files}
DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops
# tensorflow/__init__.py depends on files generated in this step. So, remove it while
# this step is running since the files aren't there yet.
COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py
# Run create_python_api.py to generate API init files.
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python PATH=${PY_RUNTIME_ENV} ${PYTHON_EXECUTABLE}
"${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py"
"--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py"
"--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow"
"--package=tensorflow.python"
"--apiname=tensorflow"
"${api_init_list_file}"
COMMENT "Generating __init__.py files for Python API."
WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python"
VERBATIM
)
else (tensorflow_ENABLE_MKL_SUPPORT)
add_custom_command(
OUTPUT ${api_init_files}
DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops
# tensorflow/__init__.py depends on files generated in this step. So, remove it while
# this step is running since the files aren't there yet.
COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py
# Run create_python_api.py to generate API init files.
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE}
"${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py"
"--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py"
"--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow"
"--package=tensorflow.python"
"--apiname=tensorflow"
"${api_init_list_file}"
COMMENT "Generating __init__.py files for Python API."
WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python"
)
endif (tensorflow_ENABLE_MKL_SUPPORT)
add_custom_target(tf_python_api SOURCES ${api_init_files}) add_custom_target(tf_python_api SOURCES ${api_init_files})
add_dependencies(tf_python_api tf_python_ops) add_dependencies(tf_python_api tf_python_ops)

View File

@ -145,3 +145,8 @@ install(DIRECTORY ${tensorflow_source_dir}/third_party/eigen3/
# unsupported Eigen directory # unsupported Eigen directory
install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/unsupported/Eigen/ install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/unsupported/Eigen/
DESTINATION include/unsupported/Eigen) DESTINATION include/unsupported/Eigen)
# mkl
if (tensorflow_ENABLE_MKL_SUPPORT)
install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/include/
DESTINATION include/mkl)
endif (tensorflow_ENABLE_MKL_SUPPORT)

View File

@ -46,7 +46,7 @@ document.
Imagine that we want to constrain the recall of a binary classifier to be at Imagine that we want to constrain the recall of a binary classifier to be at
least 90%. Since the recall is proportional to the number of true positive least 90%. Since the recall is proportional to the number of true positive
classifications, which itself is a sum of indicator functions, this constraint classifications, which itself is a sum of indicator functions, this constraint
is non-differentible, and therefore cannot be used in a problem that will be is non-differentiable, and therefore cannot be used in a problem that will be
optimized using a (stochastic) gradient-based algorithm. optimized using a (stochastic) gradient-based algorithm.
For this and similar problems, TFCO supports so-called *proxy constraints*, For this and similar problems, TFCO supports so-called *proxy constraints*,
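To make the non-differentiability concrete, here is a small NumPy sketch (illustrative only, not the TFCO API): the exact true-positive count is a sum of indicators, which has zero gradient almost everywhere, while a smooth surrogate such as a sigmoid yields usable gradients; the scores below are hypothetical:

```python
import numpy as np

scores = np.array([2.0, -0.5, 1.2, -1.8])           # hypothetical scores on positive examples
exact_tp = np.sum(scores > 0)                        # indicator sum: piecewise constant, no gradient signal
proxy_tp = np.sum(1.0 / (1.0 + np.exp(-scores)))     # smooth surrogate: differentiable in the scores
print(exact_tp, proxy_tp)
```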

View File

@ -169,8 +169,8 @@ def _project_stochastic_matrix_wrt_euclidean_norm(matrix):
del old_inactive # Needed by the condition, but not the body. del old_inactive # Needed by the condition, but not the body.
iteration += 1 iteration += 1
scale = (1.0 - standard_ops.reduce_sum( scale = (1.0 - standard_ops.reduce_sum(
matrix, axis=0, keep_dims=True)) / standard_ops.maximum( matrix, axis=0, keepdims=True)) / standard_ops.maximum(
1.0, standard_ops.reduce_sum(inactive, axis=0, keep_dims=True)) 1.0, standard_ops.reduce_sum(inactive, axis=0, keepdims=True))
matrix += scale * inactive matrix += scale * inactive
new_inactive = standard_ops.to_float(matrix > 0) new_inactive = standard_ops.to_float(matrix > 0)
matrix *= new_inactive matrix *= new_inactive
@ -206,10 +206,10 @@ def _project_log_stochastic_matrix_wrt_kl_divergence(log_matrix):
# For numerical reasons, make sure that the largest matrix element is zero # For numerical reasons, make sure that the largest matrix element is zero
# before exponentiating. # before exponentiating.
log_matrix -= standard_ops.reduce_max(log_matrix, axis=0, keep_dims=True) log_matrix -= standard_ops.reduce_max(log_matrix, axis=0, keepdims=True)
log_matrix -= standard_ops.log( log_matrix -= standard_ops.log(
standard_ops.reduce_sum( standard_ops.reduce_sum(
standard_ops.exp(log_matrix), axis=0, keep_dims=True)) standard_ops.exp(log_matrix), axis=0, keepdims=True))
return log_matrix return log_matrix
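For reference, a small NumPy sketch (not TF code) of what the `keepdims=True` reductions above do: the reduced axis is kept with size 1 so the result broadcasts back against the original matrix, which is what the projection relies on:

```python
import numpy as np

matrix = np.array([[0.2, 0.7],
                   [0.8, 0.3]])
col_sums = np.sum(matrix, axis=0, keepdims=True)   # shape (1, 2) instead of (2,)
print(col_sums)                                    # [[1. 1.]]
print(matrix / col_sums)                           # broadcasts cleanly over the columns
```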

View File

@ -58,6 +58,7 @@ class SlideDatasetTest(test.TestCase):
[t.shape.as_list() for t in get_next]) [t.shape.as_list() for t in get_next])
with self.test_session() as sess: with self.test_session() as sess:
# stride < window_size.
# Slide over a finite input, where the window_size divides the # Slide over a finite input, where the window_size divides the
# total number of elements. # total number of elements.
sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 7}) sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 7})
@ -71,11 +72,9 @@ class SlideDatasetTest(test.TestCase):
result_component[j]) result_component[j])
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next) sess.run(get_next)
# Slide over a finite input, where the window_size does not # Slide over a finite input, where the window_size does not
# divide the total number of elements. # divide the total number of elements.
sess.run(init_op, feed_dict={count: 20, window_size: 17, stride: 9}) sess.run(init_op, feed_dict={count: 20, window_size: 17, stride: 9})
num_batches = (20 * 7 - 17) // 9 + 1 num_batches = (20 * 7 - 17) // 9 + 1
for i in range(num_batches): for i in range(num_batches):
result = sess.run(get_next) result = sess.run(get_next)
@ -86,6 +85,41 @@ class SlideDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError): with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next) sess.run(get_next)
# stride == window_size.
sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 14})
num_batches = 20 * 7 // 14
for i in range(num_batches):
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(14):
self.assertAllEqual(component[(i*14 + j) % 7]**2,
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# stride > window_size.
sess.run(init_op, feed_dict={count: 20, window_size: 10, stride: 14})
num_batches = 20 * 7 // 14
for i in range(num_batches):
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(10):
self.assertAllEqual(component[(i*14 + j) % 7]**2,
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Drop the last batch which is smaller than window_size.
sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 19})
num_batches = (20 * 7 - 7) // 19 # = 19 * 7 // 19
for i in range(num_batches):
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(14):
self.assertAllEqual(component[(i*19 + j) % 7]**2,
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Slide over a finite input, which is less than window_size, # Slide over a finite input, which is less than window_size,
# should fail straight away. # should fail straight away.
sess.run(init_op, feed_dict={count: 1, window_size: 10, stride: 4}) sess.run(init_op, feed_dict={count: 1, window_size: 10, stride: 4})
@ -108,10 +142,6 @@ class SlideDatasetTest(test.TestCase):
# Invalid stride should be an initialization time error. # Invalid stride should be an initialization time error.
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 0}) sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 0})
with self.assertRaises(errors.InvalidArgumentError):
sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 3})
with self.assertRaises(errors.InvalidArgumentError):
sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 5})
def assertSparseValuesEqual(self, a, b): def assertSparseValuesEqual(self, a, b):
self.assertAllEqual(a.indices, b.indices) self.assertAllEqual(a.indices, b.indices)

View File

@ -86,7 +86,7 @@ def sliding_window_batch(window_size, stride=1):
elements in the sliding window. elements in the sliding window.
stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
steps moving the sliding window forward for one iteration. The default steps moving the sliding window forward for one iteration. The default
is `1`. It must be in `[1, window_size)`. is `1`. It must be positive.
Returns: Returns:
A `Dataset` transformation function, which can be passed to A `Dataset` transformation function, which can be passed to
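A hedged usage sketch of the relaxed stride semantics described above, assuming the `tf.contrib.data.sliding_window_batch` export available on this branch; the values are illustrative:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10)
# A stride larger than window_size is now allowed: windows start every `stride` elements.
dataset = dataset.apply(
    tf.contrib.data.sliding_window_batch(window_size=3, stride=5))

iterator = dataset.make_one_shot_iterator()
next_window = iterator.get_next()
with tf.Session() as sess:
    print(sess.run(next_window))   # [0 1 2]
    print(sess.run(next_window))   # [5 6 7]; elements 3-4 and 8-9 never appear in a window
```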

View File

@ -0,0 +1,909 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "nmt_with_attention.ipynb",
"version": "0.3.2",
"views": {},
"default_view": {},
"provenance": [
{
"file_id": "1C4fpM7_7IL8ZzF7Gc5abywqQjeQNS2-U",
"timestamp": 1527858391290
},
{
"file_id": "1pExo6aUuw0S6MISFWoinfJv0Ftm9V4qv",
"timestamp": 1527776041613
}
],
"private_outputs": true,
"collapsed_sections": [],
"toc_visible": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"metadata": {
"id": "AOpGoE2T-YXS",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"##### Copyright 2018 The TensorFlow Authors.\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\").\n",
"\n",
"# Neural Machine Translation with Attention\n",
"\n",
"<table align=\"left\"><td>\n",
"<a target=\"_blank\" href=\"https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\">\n",
" <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a> \n",
"</td><td>\n",
"<a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on Github</a></td></table>"
]
},
{
"metadata": {
"id": "CiwtNgENbx2g",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"This notebook trains a sequence to sequence (seq2seq) model for Spanish to English translation using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). This is an advanced example that assumes some knowledge of sequence to sequence models.\n",
"\n",
"After training the model in this notebook, you will be able to input a Spanish sentence, such as *\"¿todavia estan en casa?\"*, and return the English translation: *\"are you still at home?\"*\n",
"\n",
"The translation quality is reasonable for a toy example, but the generated attention plot is perhaps more interesting. This shows which parts of the input sentence has the model's attention while translating:\n",
"\n",
"<img src=\"https://tensorflow.org/images/spanish-english.png\" alt=\"spanish-english attention plot\">\n",
"\n",
"Note: This example takes approximately 10 mintues to run on a single P100 GPU."
]
},
{
"metadata": {
"id": "tnxXKDjq3jEL",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"from __future__ import absolute_import, division, print_function\n",
"\n",
"# Import TensorFlow >= 1.9 and enable eager execution\n",
"import tensorflow as tf\n",
"\n",
"tf.enable_eager_execution()\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"import unicodedata\n",
"import re\n",
"import numpy as np\n",
"import os\n",
"import time\n",
"\n",
"print(tf.__version__)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "wfodePkj3jEa",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Download and prepare the dataset\n",
"\n",
"We'll use a language dataset provided by http://www.manythings.org/anki/. This dataset contains language translation pairs in the format:\n",
"\n",
"```\n",
"May I borrow this book?\t¿Puedo tomar prestado este libro?\n",
"```\n",
"\n",
"There are a variety of languages available, but we'll use the English-Spanish dataset. For convenience, we've hosted a copy of this dataset on Google Cloud, but you can also download your own copy. After downloading the dataset, here are the steps we'll take to prepare the data:\n",
"\n",
"1. Add a *start* and *end* token to each sentence.\n",
"2. Clean the sentences by removing special characters.\n",
"3. Create a word index and reverse word index (dictionaries mapping from word → id and id → word).\n",
"4. Pad each sentence to a maximum length."
]
},
{
"metadata": {
"id": "kRVATYOgJs1b",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"# Download the file\n",
"path_to_zip = tf.keras.utils.get_file(\n",
" 'spa-eng.zip', origin='http://download.tensorflow.org/data/spa-eng.zip', \n",
" extract=True)\n",
"\n",
"path_to_file = os.path.dirname(path_to_zip)+\"/spa-eng/spa.txt\""
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "rd0jw-eC3jEh",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"# Converts the unicode file to ascii\n",
"def unicode_to_ascii(s):\n",
" return ''.join(c for c in unicodedata.normalize('NFD', s)\n",
" if unicodedata.category(c) != 'Mn')\n",
"\n",
"\n",
"def preprocess_sentence(w):\n",
" w = unicode_to_ascii(w.lower().strip())\n",
" \n",
" # creating a space between a word and the punctuation following it\n",
" # eg: \"he is a boy.\" => \"he is a boy .\" \n",
" # Reference:- https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation\n",
" w = re.sub(r\"([?.!,¿])\", r\" \\1 \", w)\n",
" w = re.sub(r'[\" \"]+', \" \", w)\n",
" \n",
" # replacing everything with space except (a-z, A-Z, \".\", \"?\", \"!\", \",\")\n",
" w = re.sub(r\"[^a-zA-Z?.!,¿]+\", \" \", w)\n",
" \n",
" w = w.rstrip().strip()\n",
" \n",
" # adding a start and an end token to the sentence\n",
" # so that the model know when to start and stop predicting.\n",
" w = '<start> ' + w + ' <end>'\n",
" return w"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "OHn4Dct23jEm",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"# 1. Remove the accents\n",
"# 2. Clean the sentences\n",
"# 3. Return word pairs in the format: [ENGLISH, SPANISH]\n",
"def create_dataset(path, num_examples):\n",
" lines = open(path, encoding='UTF-8').read().strip().split('\\n')\n",
" \n",
" word_pairs = [[preprocess_sentence(w) for w in l.split('\\t')] for l in lines[:num_examples]]\n",
" \n",
" return word_pairs"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "9xbqO7Iie9bb",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"# This class creates a word -> index mapping (e.g,. \"dad\" -> 5) and vice-versa \n",
"# (e.g., 5 -> \"dad\") for each language,\n",
"class LanguageIndex():\n",
" def __init__(self, lang):\n",
" self.lang = lang\n",
" self.word2idx = {}\n",
" self.idx2word = {}\n",
" self.vocab = set()\n",
" \n",
" self.create_index()\n",
" \n",
" def create_index(self):\n",
" for phrase in self.lang:\n",
" self.vocab.update(phrase.split(' '))\n",
" \n",
" self.vocab = sorted(self.vocab)\n",
" \n",
" self.word2idx['<pad>'] = 0\n",
" for index, word in enumerate(self.vocab):\n",
" self.word2idx[word] = index + 1\n",
" \n",
" for word, index in self.word2idx.items():\n",
" self.idx2word[index] = word"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "eAY9k49G3jE_",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"def max_length(tensor):\n",
" return max(len(t) for t in tensor)\n",
"\n",
"\n",
"def load_dataset(path, num_examples):\n",
" # creating cleaned input, output pairs\n",
" pairs = create_dataset(path, num_examples)\n",
"\n",
" # index language using the class defined above \n",
" inp_lang = LanguageIndex(sp for en, sp in pairs)\n",
" targ_lang = LanguageIndex(en for en, sp in pairs)\n",
" \n",
" # Vectorize the input and target languages\n",
" \n",
" # Spanish sentences\n",
" input_tensor = [[inp_lang.word2idx[s] for s in sp.split(' ')] for en, sp in pairs]\n",
" \n",
" # English sentences\n",
" target_tensor = [[targ_lang.word2idx[s] for s in en.split(' ')] for en, sp in pairs]\n",
" \n",
" # Calculate max_length of input and output tensor\n",
" # Here, we'll set those to the longest sentence in the dataset\n",
" max_length_inp, max_length_tar = max_length(input_tensor), max_length(target_tensor)\n",
" \n",
" # Padding the input and output tensor to the maximum length\n",
" input_tensor = tf.keras.preprocessing.sequence.pad_sequences(input_tensor, \n",
" maxlen=max_length_inp,\n",
" padding='post')\n",
" \n",
" target_tensor = tf.keras.preprocessing.sequence.pad_sequences(target_tensor, \n",
" maxlen=max_length_tar, \n",
" padding='post')\n",
" \n",
" return input_tensor, target_tensor, inp_lang, targ_lang, max_length_inp, max_length_tar"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "GOi42V79Ydlr",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"### Limit the size of the dataset to experiment faster (optional)\n",
"\n",
"Training on the complete dataset of >100,000 sentences will take a long time. To train faster, we can limit the size of the dataset to 30,000 sentences (of course, translation quality degrades with less data):"
]
},
{
"metadata": {
"id": "cnxC7q-j3jFD",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"# Try experimenting with the size of that dataset\n",
"num_examples = 30000\n",
"input_tensor, target_tensor, inp_lang, targ_lang, max_length_inp, max_length_targ = load_dataset(path_to_file, num_examples)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "4QILQkOs3jFG",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"# Creating training and validation sets using an 80-20 split\n",
"input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.2)\n",
"\n",
"# Show length\n",
"len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "rgCLkfv5uO3d",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"### Create a tf.data dataset"
]
},
{
"metadata": {
"id": "TqHsArVZ3jFS",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"BUFFER_SIZE = len(input_tensor_train)\n",
"BATCH_SIZE = 64\n",
"embedding_dim = 256\n",
"units = 1024\n",
"vocab_inp_size = len(inp_lang.word2idx)\n",
"vocab_tar_size = len(targ_lang.word2idx)\n",
"\n",
"dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)\n",
"dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(BATCH_SIZE))"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "TNfHIF71ulLu",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Write the encoder and decoder model\n",
"\n",
"Here, we'll implement an encoder-decoder model with attention which you can read about in the TensorFlow [Neural Machine Translation (seq2seq) tutorial](https://www.tensorflow.org/tutorials/seq2seq). This example uses a more recent set of APIs. This notebook implements the [attention equations](https://www.tensorflow.org/tutorials/seq2seq#background_on_the_attention_mechanism) from the seq2seq tutorial. The following diagram shows that each input words is assigned a weight by the attention mechanism which is then used by the decoder to predict the next word in the sentence.\n",
"\n",
"<img src=\"https://www.tensorflow.org/images/seq2seq/attention_mechanism.jpg\" width=\"500\" alt=\"attention mechanism\">\n",
"\n",
"The input is put through an encoder model which gives us the encoder output of shape *(batch_size, max_length, hidden_size)* and the encoder hidden state of shape *(batch_size, hidden_size)*. \n",
"\n",
"Here are the equations that are implemented:\n",
"\n",
"<img src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_0.jpg\" alt=\"attention equation 0\" width=\"800\">\n",
"<img src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_1.jpg\" alt=\"attention equation 1\" width=\"800\">\n",
"\n",
"We're using *Bahdanau attention*. Lets decide on notation before writing the simplified form:\n",
"\n",
"* FC = Fully connected (dense) layer\n",
"* EO = Encoder output\n",
"* H = hidden state\n",
"* X = input to the decoder\n",
"\n",
"And the pseudo-code:\n",
"\n",
"* `score = FC(tanh(FC(EO) + FC(H)))`\n",
"* `attention weights = softmax(score, axis = 1)`. Softmax by default is applied on the last axis but here we want to apply it on the *1st axis*, since the shape of score is *(batch_size, max_length, hidden_size)*. `Max_length` is the length of our input. Since we are trying to assign a weight to each input, softmax should be applied on that axis.\n",
"* `context vector = sum(attention weights * EO, axis = 1)`. Same reason as above for choosing axis as 1.\n",
"* `embedding output` = The input to the decoder X is passed through an embedding layer.\n",
"* `merged vector = concat(embedding output, context vector)`\n",
"* This merged vector is then given to the GRU\n",
" \n",
"The shapes of all the vectors at each step have been specified in the comments in the code:"
]
},
{
"metadata": {
"id": "avyJ_4VIUoHb",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"def gru(units):\n",
" # If you have a GPU, we recommend using CuDNNGRU(provides a 3x speedup than GRU)\n",
" # the code automatically does that.\n",
" if tf.test.is_gpu_available():\n",
" return tf.keras.layers.CuDNNGRU(units, \n",
" return_sequences=True, \n",
" return_state=True, \n",
" recurrent_initializer='glorot_uniform')\n",
" else:\n",
" return tf.keras.layers.GRU(units, \n",
" return_sequences=True, \n",
" return_state=True, \n",
" recurrent_activation='sigmoid', \n",
" recurrent_initializer='glorot_uniform')"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "nZ2rI24i3jFg",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"class Encoder(tf.keras.Model):\n",
" def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):\n",
" super(Encoder, self).__init__()\n",
" self.batch_sz = batch_sz\n",
" self.enc_units = enc_units\n",
" self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
" self.gru = gru(self.enc_units)\n",
" \n",
" def call(self, x, hidden):\n",
" x = self.embedding(x)\n",
" output, state = self.gru(x, initial_state = hidden) \n",
" return output, state\n",
" \n",
" def initialize_hidden_state(self):\n",
" return tf.zeros((self.batch_sz, self.enc_units))"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "yJ_B3mhW3jFk",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"class Decoder(tf.keras.Model):\n",
" def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):\n",
" super(Decoder, self).__init__()\n",
" self.batch_sz = batch_sz\n",
" self.dec_units = dec_units\n",
" self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
" self.gru = gru(self.dec_units)\n",
" self.fc = tf.keras.layers.Dense(vocab_size)\n",
" \n",
" # used for attention\n",
" self.W1 = tf.keras.layers.Dense(self.dec_units)\n",
" self.W2 = tf.keras.layers.Dense(self.dec_units)\n",
" self.V = tf.keras.layers.Dense(1)\n",
" \n",
" def call(self, x, hidden, enc_output):\n",
" # enc_output shape == (batch_size, max_length, hidden_size)\n",
" \n",
" # hidden shape == (batch_size, hidden size)\n",
" # hidden_with_time_axis shape == (batch_size, 1, hidden size)\n",
" # we are doing this to perform addition to calculate the score\n",
" hidden_with_time_axis = tf.expand_dims(hidden, 1)\n",
" \n",
" # score shape == (batch_size, max_length, hidden_size)\n",
" score = tf.nn.tanh(self.W1(enc_output) + self.W2(hidden_with_time_axis))\n",
" \n",
" # attention_weights shape == (batch_size, max_length, 1)\n",
" # we get 1 at the last axis because we are applying score to self.V\n",
" attention_weights = tf.nn.softmax(self.V(score), axis=1)\n",
" \n",
" # context_vector shape after sum == (batch_size, hidden_size)\n",
" context_vector = attention_weights * enc_output\n",
" context_vector = tf.reduce_sum(context_vector, axis=1)\n",
" \n",
" # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n",
" x = self.embedding(x)\n",
" \n",
" # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n",
" x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n",
" \n",
" # passing the concatenated vector to the GRU\n",
" output, state = self.gru(x)\n",
" \n",
" # output shape == (batch_size * max_length, hidden_size)\n",
" output = tf.reshape(output, (-1, output.shape[2]))\n",
" \n",
" # output shape == (batch_size * max_length, vocab)\n",
" x = self.fc(output)\n",
" \n",
" return x, state, attention_weights\n",
" \n",
" def initialize_hidden_state(self):\n",
" return tf.zeros((self.batch_sz, self.dec_units))"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "P5UY8wko3jFp",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)\n",
"decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "_ch_71VbIRfK",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Define the optimizer and the loss function"
]
},
{
"metadata": {
"id": "WmTHr5iV3jFr",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"optimizer = tf.train.AdamOptimizer()\n",
"\n",
"\n",
"def loss_function(real, pred):\n",
" mask = 1 - np.equal(real, 0)\n",
" loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n",
" return tf.reduce_mean(loss_)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "hpObfY22IddU",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Training\n",
"\n",
"1. Pass the *input* through the *encoder* which return *encoder output* and the *encoder hidden state*.\n",
"2. The encoder output, encoder hidden state and the decoder input (which is the *start token*) is passed to the decoder.\n",
"3. The decoder returns the *predictions* and the *decoder hidden state*.\n",
"4. The decoder hidden state is then passed back into the model and the predictions are used to calculate the loss.\n",
"5. Use *teacher forcing* to decide the next input to the decoder.\n",
"6. *Teacher forcing* is the technique where the *target word* is passed as the *next input* to the decoder.\n",
"7. The final step is to calculate the gradients and apply it to the optimizer and backpropagate."
]
},
{
"metadata": {
"id": "ddefjBMa3jF0",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"EPOCHS = 10\n",
"\n",
"for epoch in range(EPOCHS):\n",
" start = time.time()\n",
" \n",
" hidden = encoder.initialize_hidden_state()\n",
" total_loss = 0\n",
" \n",
" for (batch, (inp, targ)) in enumerate(dataset):\n",
" loss = 0\n",
" \n",
" with tf.GradientTape() as tape:\n",
" enc_output, enc_hidden = encoder(inp, hidden)\n",
" \n",
" dec_hidden = enc_hidden\n",
" \n",
" dec_input = tf.expand_dims([targ_lang.word2idx['<start>']] * BATCH_SIZE, 1) \n",
" \n",
" # Teacher forcing - feeding the target as the next input\n",
" for t in range(1, targ.shape[1]):\n",
" # passing enc_output to the decoder\n",
" predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)\n",
" \n",
" loss += loss_function(targ[:, t], predictions)\n",
" \n",
" # using teacher forcing\n",
" dec_input = tf.expand_dims(targ[:, t], 1)\n",
" \n",
" total_loss += (loss / int(targ.shape[1]))\n",
" \n",
" variables = encoder.variables + decoder.variables\n",
" \n",
" gradients = tape.gradient(loss, variables)\n",
" \n",
" optimizer.apply_gradients(zip(gradients, variables), tf.train.get_or_create_global_step())\n",
"\n",
" if batch % 100 == 0:\n",
" print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,\n",
" batch,\n",
" loss.numpy() / int(targ.shape[1])))\n",
" \n",
" print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n",
" total_loss/len(input_tensor)))\n",
" print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "mU3Ce8M6I3rz",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Translate\n",
"\n",
"* The evaluate function is similar to the training loop, except we don't use *teacher forcing* here. The input to the decoder at each time step is its previous predictions along with the hidden state and the encoder output.\n",
"* Stop predicting when the model predicts the *end token*.\n",
"* And store the *attention weights for every time step*.\n",
"\n",
"Note: The encoder output is calculated only once for one input."
]
},
{
"metadata": {
"id": "EbQpyYs13jF_",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"def evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):\n",
" attention_plot = np.zeros((max_length_targ, max_length_inp))\n",
" \n",
" sentence = preprocess_sentence(sentence)\n",
"\n",
" inputs = [inp_lang.word2idx[i] for i in sentence.split(' ')]\n",
" inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs], maxlen=max_length_inp, padding='post')\n",
" inputs = tf.convert_to_tensor(inputs)\n",
" \n",
" result = ''\n",
"\n",
" hidden = [tf.zeros((1, units))]\n",
" enc_out, enc_hidden = encoder(inputs, hidden)\n",
"\n",
" dec_hidden = enc_hidden\n",
" dec_input = tf.expand_dims([targ_lang.word2idx['<start>']], 0)\n",
"\n",
" for t in range(max_length_targ):\n",
" predictions, dec_hidden, attention_weights = decoder(dec_input, dec_hidden, enc_out)\n",
" \n",
" # storing the attention weigths to plot later on\n",
" attention_weights = tf.reshape(attention_weights, (-1, ))\n",
" attention_plot[t] = attention_weights.numpy()\n",
"\n",
" predicted_id = tf.multinomial(tf.exp(predictions), num_samples=1)[0][0].numpy()\n",
"\n",
" result += targ_lang.idx2word[predicted_id] + ' '\n",
"\n",
" if targ_lang.idx2word[predicted_id] == '<end>':\n",
" return result, sentence, attention_plot\n",
" \n",
" # the predicted ID is fed back into the model\n",
" dec_input = tf.expand_dims([predicted_id], 0)\n",
"\n",
" return result, sentence, attention_plot"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "s5hQWlbN3jGF",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"# function for plotting the attention weights\n",
"def plot_attention(attention, sentence, predicted_sentence):\n",
" fig = plt.figure(figsize=(10,10))\n",
" ax = fig.add_subplot(1, 1, 1)\n",
" ax.matshow(attention, cmap='viridis')\n",
" \n",
" fontdict = {'fontsize': 14}\n",
" \n",
" ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)\n",
" ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)\n",
"\n",
" plt.show()"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "sl9zUHzg3jGI",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"def translate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):\n",
" result, sentence, attention_plot = evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)\n",
" \n",
" print('Input: {}'.format(sentence))\n",
" print('Predicted translation: {}'.format(result))\n",
" \n",
" attention_plot = attention_plot[:len(result.split(' ')), :len(sentence.split(' '))]\n",
" plot_attention(attention_plot, sentence.split(' '), result.split(' '))"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "WrAM0FDomq3E",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"translate('hace mucho frio aqui.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "zSx2iM36EZQZ",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"translate('esta es mi vida.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "A3LLCx3ZE0Ls",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"translate('¿todavia estan en casa?', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "DUQVLVqUE1YW",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"# wrong translation\n",
"translate('trata de averiguarlo.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "RTe5P5ioMJwN",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Next steps\n",
"\n",
"* [Download a different dataset](http://www.manythings.org/anki/) to experiment with translations, for example, English to German, or English to French.\n",
"* Experiment with training on a larger dataset, or using more epochs\n"
]
}
]
}

View File

@ -24,6 +24,7 @@ from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples
from tensorflow.contrib.gan.python import train as tfgan_train from tensorflow.contrib.gan.python import train as tfgan_train
from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator.canned import head from tensorflow.python.estimator.canned import head
from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import metrics as metrics_lib
@ -182,7 +183,10 @@ class GANHead(head._Head): # pylint: disable=protected-access
if mode == model_fn_lib.ModeKeys.PREDICT: if mode == model_fn_lib.ModeKeys.PREDICT:
return model_fn_lib.EstimatorSpec( return model_fn_lib.EstimatorSpec(
mode=model_fn_lib.ModeKeys.PREDICT, mode=model_fn_lib.ModeKeys.PREDICT,
predictions=gan_model.generated_data) predictions=gan_model.generated_data,
export_outputs={
'predict': export_output.PredictOutput(gan_model.generated_data)
})
elif mode == model_fn_lib.ModeKeys.EVAL: elif mode == model_fn_lib.ModeKeys.EVAL:
gan_loss = self.create_loss( gan_loss = self.create_loss(
features=None, mode=mode, logits=gan_model, labels=None) features=None, mode=mode, logits=gan_model, labels=None)
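For context, the added `export_outputs` entry is what lets an exported `SavedModel` expose the generated data as a named serving signature. A minimal sketch of the same pattern using the public `tf.estimator` API (the `model_fn`, tensor, and signature names here are illustrative, not part of this change):

```python
import tensorflow as tf

def model_fn(features, labels, mode):
  # Stand-in for gan_model.generated_data in the change above.
  generated = tf.layers.dense(features['noise'], units=64)
  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=generated,
        export_outputs={
            # 'predict' is added alongside the default serving signature key.
            'predict': tf.estimator.export.PredictOutput(generated)
        })
  # TRAIN/EVAL branches omitted in this sketch.
```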

View File

@ -26,8 +26,11 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import training from tensorflow.python.training import training
_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
def dummy_loss(gan_model, add_summaries=True): # pylint:disable=unused-argument def dummy_loss(gan_model, add_summaries=True): # pylint:disable=unused-argument
return math_ops.reduce_sum(gan_model.discriminator_real_outputs - return math_ops.reduce_sum(gan_model.discriminator_real_outputs -
@ -71,13 +74,15 @@ class GANHeadTest(test.TestCase):
return {} return {}
def _test_modes_helper(self, mode): def _test_modes_helper(self, mode):
self.gan_head.create_estimator_spec( return self.gan_head.create_estimator_spec(
features=None, features=None,
mode=mode, mode=mode,
logits=get_gan_model()) logits=get_gan_model())
def test_modes_predict(self): def test_modes_predict(self):
self._test_modes_helper(model_fn_lib.ModeKeys.PREDICT) spec = self._test_modes_helper(model_fn_lib.ModeKeys.PREDICT)
self.assertItemsEqual((_DEFAULT_SERVING_KEY, 'predict'),
spec.export_outputs.keys())
def test_modes_eval(self): def test_modes_eval(self):
self._test_modes_helper(model_fn_lib.ModeKeys.EVAL) self._test_modes_helper(model_fn_lib.ModeKeys.EVAL)

View File

@ -57,7 +57,7 @@ Status GdrServer::Init() {
new GdrWorker(env, remote_memory_manager_.get())); new GdrWorker(env, remote_memory_manager_.get()));
}; };
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
GrpcServer::Init(nullptr, rendezvous_mgr_func, worker_func)); GrpcServer::Init(nullptr, rendezvous_mgr_func, nullptr, worker_func));
return remote_memory_manager_->Init(); return remote_memory_manager_->Init();
} }

View File

@ -26,7 +26,7 @@ namespace optimized_ops {
// Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on // Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on
// Jetson TX-2. This compiler does not support the offsetof() macro. // Jetson TX-2. This compiler does not support the offsetof() macro.
#if defined(__aarch64__) && !defined(GOOGLE_L4T) #if defined(__aarch64__) && !defined(GOOGLE_L4T)
#include <stddef.h>
// clang-format gets confused with this file and ends up formatting lines to // clang-format gets confused with this file and ends up formatting lines to
// be larger than 80 characters. Turn off here and back on at the end of the // be larger than 80 characters. Turn off here and back on at the end of the
// file. // file.

View File

@ -19,7 +19,9 @@ limitations under the License.
#include <string> #include <string>
#include <vector> #include <vector>
// Place `<locale>` before <Python.h> to avoid build failures in macOS.
#include <Python.h> #include <Python.h>
#include <locale>
// We forward declare TFLite classes here to avoid exposing them to SWIG. // We forward declare TFLite classes here to avoid exposing them to SWIG.
namespace tflite { namespace tflite {

View File

@ -29,6 +29,7 @@ py_library(
"python/training/reg_adagrad_optimizer.py", "python/training/reg_adagrad_optimizer.py",
"python/training/sign_decay.py", "python/training/sign_decay.py",
"python/training/variable_clipping_optimizer.py", "python/training/variable_clipping_optimizer.py",
"python/training/weight_decay_optimizers.py",
], ],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
@ -198,6 +199,25 @@ py_test(
], ],
) )
py_test(
name = "weight_decay_optimizers_test",
srcs = ["python/training/weight_decay_optimizers_test.py"],
srcs_version = "PY2AND3",
deps = [
":opt_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:session",
"//tensorflow/python:variables",
"//third_party/py/numpy",
],
)
tf_py_test( tf_py_test(
name = "drop_stale_gradient_optimizer_test", name = "drop_stale_gradient_optimizer_test",
srcs = ["python/training/drop_stale_gradient_optimizer_test.py"], srcs = ["python/training/drop_stale_gradient_optimizer_test.py"],

View File

@ -22,16 +22,17 @@ from __future__ import print_function
from tensorflow.contrib.opt.python.training.adamax import * from tensorflow.contrib.opt.python.training.adamax import *
from tensorflow.contrib.opt.python.training.addsign import * from tensorflow.contrib.opt.python.training.addsign import *
from tensorflow.contrib.opt.python.training.drop_stale_gradient_optimizer import * from tensorflow.contrib.opt.python.training.drop_stale_gradient_optimizer import *
from tensorflow.contrib.opt.python.training.elastic_average_optimizer import *
from tensorflow.contrib.opt.python.training.external_optimizer import * from tensorflow.contrib.opt.python.training.external_optimizer import *
from tensorflow.contrib.opt.python.training.ggt import *
from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import * from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import *
from tensorflow.contrib.opt.python.training.model_average_optimizer import *
from tensorflow.contrib.opt.python.training.moving_average_optimizer import * from tensorflow.contrib.opt.python.training.moving_average_optimizer import *
from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import * from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import *
from tensorflow.contrib.opt.python.training.nadam_optimizer import * from tensorflow.contrib.opt.python.training.nadam_optimizer import *
from tensorflow.contrib.opt.python.training.powersign import * from tensorflow.contrib.opt.python.training.powersign import *
from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import * from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import *
from tensorflow.contrib.opt.python.training.elastic_average_optimizer import * from tensorflow.contrib.opt.python.training.weight_decay_optimizers import *
from tensorflow.contrib.opt.python.training.model_average_optimizer import *
from tensorflow.contrib.opt.python.training.ggt import *
# pylint: enable=wildcard-import # pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented from tensorflow.python.util.all_util import remove_undocumented
@ -47,6 +48,10 @@ _allowed_symbols = [
'LazyAdamOptimizer', 'LazyAdamOptimizer',
'NadamOptimizer', 'NadamOptimizer',
'MovingAverageOptimizer', 'MovingAverageOptimizer',
'MomentumWOptimizer',
'AdamWOptimizer',
'DecoupledWeightDecayExtension',
'extend_with_decoupled_weight_decay',
'ScipyOptimizerInterface', 'ScipyOptimizerInterface',
'VariableClippingOptimizer', 'VariableClippingOptimizer',
'MultitaskOptimizerWrapper', 'MultitaskOptimizerWrapper',

View File

@ -0,0 +1,362 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class to make optimizers weight decay ready."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import adam
from tensorflow.python.training import momentum as momentum_opt
from tensorflow.python.training import optimizer
from tensorflow.python.util.tf_export import tf_export
class DecoupledWeightDecayExtension(object):
"""This class allows to extend optimizers with decoupled weight decay.
It implements the decoupled weight decay described by Loshchilov & Hutter
(https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is
decoupled from the optimization steps w.r.t. the loss function.
For SGD variants, this simplifies hyperparameter search since it decouples
the settings of weight decay and learning rate.
For adaptive gradient algorithms, it regularizes variables with large
gradients more than L2 regularization would, which was shown to yield better
training loss and generalization error in the paper above.
This class alone is not an optimizer but rather extends existing
optimizers with decoupled weight decay. We explicitly define the two examples
used in the above paper (SGDW and AdamW), but in general this can extend
any OptimizerX by using
`extend_with_decoupled_weight_decay(OptimizerX)`; the returned class takes
the weight decay as its first constructor argument.
In order for it to work, it must be the first class the Optimizer with
weight decay inherits from, e.g.
```python
class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer):
def __init__(self, weight_decay, *args, **kwargs):
super(AdamWOptimizer, self).__init__(weight_decay, *args, **kwargs)
```
Note that this extension decays weights BEFORE applying the update based
on the gradient, i.e. this extension only has the desired behaviour for
optimizers which do not depend on the value of 'var' in the update step!
"""
def __init__(self, weight_decay, **kwargs):
"""Construct the extension class that adds weight decay to an optimizer.
Args:
weight_decay: A `Tensor` or a floating point value, the factor by which
a variable is decayed in the update step.
**kwargs: Optional keyword arguments that are passed on to the constructor
of the base optimizer.
"""
self._decay_var_list = None # is set in minimize or apply_gradients
self._weight_decay = weight_decay
# The tensors are initialized in call to _prepare
self._weight_decay_tensor = None
super(DecoupledWeightDecayExtension, self).__init__(**kwargs)
def minimize(self, loss, global_step=None, var_list=None,
gate_gradients=optimizer.Optimizer.GATE_OP,
aggregation_method=None, colocate_gradients_with_ops=False,
name=None, grad_loss=None, decay_var_list=None):
"""Add operations to minimize `loss` by updating `var_list` with decay.
This function is the same as Optimizer.minimize except that it allows one to
specify the variables that should be decayed using decay_var_list.
If decay_var_list is None, all variables in var_list are decayed.
For more information see the documentation of Optimizer.minimize.
Args:
loss: A `Tensor` containing the value to minimize.
global_step: Optional `Variable` to increment by one after the
variables have been updated.
var_list: Optional list or tuple of `Variable` objects to update to
minimize `loss`. Defaults to the list of variables collected in
the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
gate_gradients: How to gate the computation of gradients. Can be
`GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
aggregation_method: Specifies the method used to combine gradient terms.
Valid values are defined in the class `AggregationMethod`.
colocate_gradients_with_ops: If True, try colocating gradients with
the corresponding op.
name: Optional name for the returned operation.
grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
decay_var_list: Optional list of decay variables.
Returns:
An Operation that updates the variables in `var_list`. If `global_step`
was not `None`, that operation also increments `global_step`.
"""
self._decay_var_list = set(decay_var_list) if decay_var_list else False
return super(DecoupledWeightDecayExtension, self).minimize(
loss, global_step=global_step, var_list=var_list,
gate_gradients=gate_gradients, aggregation_method=aggregation_method,
colocate_gradients_with_ops=colocate_gradients_with_ops, name=name,
grad_loss=grad_loss)
def apply_gradients(self, grads_and_vars, global_step=None, name=None,
decay_var_list=None):
"""Apply gradients to variables and decay the variables.
This function is the same as Optimizer.apply_gradients except that it
allows one to specify the variables that should be decayed using
decay_var_list. If decay_var_list is None, all variables in var_list
are decayed.
For more information see the documentation of Optimizer.apply_gradients.
Args:
grads_and_vars: List of (gradient, variable) pairs as returned by
`compute_gradients()`.
global_step: Optional `Variable` to increment by one after the
variables have been updated.
name: Optional name for the returned operation. Default to the
name passed to the `Optimizer` constructor.
decay_var_list: Optional list of decay variables.
Returns:
An `Operation` that applies the specified gradients. If `global_step`
was not None, that operation also increments `global_step`.
"""
self._decay_var_list = set(decay_var_list) if decay_var_list else False
return super(DecoupledWeightDecayExtension, self).apply_gradients(
grads_and_vars, global_step=global_step, name=name)
def _prepare(self):
weight_decay = self._weight_decay
if callable(weight_decay):
weight_decay = weight_decay()
self._weight_decay_tensor = ops.convert_to_tensor(
weight_decay, name="weight_decay")
# Call the optimizer's _prepare function.
super(DecoupledWeightDecayExtension, self)._prepare()
def _decay_weights_op(self, var):
if not self._decay_var_list or var in self._decay_var_list:
return var.assign_sub(self._weight_decay * var, self._use_locking)
return control_flow_ops.no_op()
def _decay_weights_sparse_op(self, var, indices, scatter_add):
if not self._decay_var_list or var in self._decay_var_list:
return scatter_add(var, indices, -self._weight_decay * var,
self._use_locking)
return control_flow_ops.no_op()
# Here, we overwrite the apply functions that the base optimizer calls.
# super().apply_x resolves to the apply_x function of the BaseOptimizer.
def _apply_dense(self, grad, var):
with ops.control_dependencies([self._decay_weights_op(var)]):
return super(DecoupledWeightDecayExtension, self)._apply_dense(grad, var)
def _resource_apply_dense(self, grad, var):
with ops.control_dependencies([self._decay_weights_op(var)]):
return super(DecoupledWeightDecayExtension, self)._resource_apply_dense(
grad, var)
def _apply_sparse(self, grad, var):
scatter_add = state_ops.scatter_add
decay_op = self._decay_weights_sparse_op(var, grad.indices, scatter_add)
with ops.control_dependencies([decay_op]):
return super(DecoupledWeightDecayExtension, self)._apply_sparse(
grad, var)
def _resource_scatter_add(self, x, i, v, _=None):
# last argument allows for one overflow argument, to have the same function
# signature as state_ops.scatter_add
with ops.control_dependencies(
[resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
return x.value()
def _resource_apply_sparse(self, grad, var, indices):
scatter_add = self._resource_scatter_add
decay_op = self._decay_weights_sparse_op(var, indices, scatter_add)
with ops.control_dependencies([decay_op]):
return super(DecoupledWeightDecayExtension, self)._resource_apply_sparse(
grad, var, indices)
def extend_with_decoupled_weight_decay(base_optimizer):
"""Factory function returning an optimizer class with decoupled weight decay.
Returns an optimizer class. An instance of the returned class computes the
update step of `base_optimizer` and additionally decays the weights.
E.g., the class returned by
`extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)` is equivalent to
`tf.contrib.opt.AdamWOptimizer`.
The API of the new optimizer class slightly differs from the API of the
base optimizer:
- The first argument to the constructor is the weight decay rate.
- `minimize` and `apply_gradients` accept the optional keyword argument
`decay_var_list`, which specifies the variables that should be decayed.
If `None`, all variables that are optimized are decayed.
Usage example:
```python
# MyAdamW is a new class
MyAdamW = extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)
# Create a MyAdamW object
optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001)
sess.run(optimizer.minimize(loss, decay_var_list=[var1, var2]))
```
Note that this extension decays weights BEFORE applying the update based
on the gradient, i.e. this extension only has the desired behaviour for
optimizers which do not depend on the value of 'var' in the update step!
Args:
base_optimizer: An optimizer class that inherits from tf.train.Optimizer.
Returns:
A new optimizer class that inherits from DecoupledWeightDecayExtension
and base_optimizer.
"""
class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension,
base_optimizer):
"""Base_optimizer with decoupled weight decay.
This class computes the update step of `base_optimizer` and
additionally decays the variable with the weight decay being decoupled from
the optimization steps w.r.t. to the loss function, as described by
Loshchilov & Hutter (https://arxiv.org/pdf/1711.05101.pdf).
For SGD variants, this simplifies hyperparameter search since
it decouples the settings of weight decay and learning rate.
For adaptive gradient algorithms, it regularizes variables with large
gradients more than L2 regularization would, which was shown to yield
better training loss and generalization error in the paper above.
"""
def __init__(self, weight_decay, *args, **kwargs):
# super delegation is necessary here
# pylint: disable=useless-super-delegation
super(OptimizerWithDecoupledWeightDecay, self).__init__(
weight_decay, *args, **kwargs)
# pylint: enable=useless-super-delegation
return OptimizerWithDecoupledWeightDecay
@tf_export("contrib.opt.MomentumWOptimizer")
class MomentumWOptimizer(DecoupledWeightDecayExtension,
momentum_opt.MomentumOptimizer):
"""Optimizer that implements the Momentum algorithm with weight_decay.
This is an implementation of the SGDW optimizer described in "Fixing
Weight Decay Regularization in Adam" by Loshchilov & Hutter
(https://arxiv.org/abs/1711.05101)
([pdf](https://arxiv.org/pdf/1711.05101.pdf)).
It computes the update step of `train.MomentumOptimizer` and additionally
decays the variable. Note that this is different from adding
L2 regularization on the variables to the loss. Decoupling the weight decay
from other hyperparameters (in particular the learning rate) simplifies
hyperparameter search.
For further information see the documentation of the Momentum Optimizer.
Note that this optimizer can also be instantiated as
```python
MomentumW = extend_with_decoupled_weight_decay(tf.train.MomentumOptimizer)
optimizer = MomentumW(weight_decay, learning_rate, momentum)
```
"""
def __init__(self, weight_decay, learning_rate, momentum,
use_locking=False, name="MomentumW", use_nesterov=False):
"""Construct a new MomentumW optimizer.
For further information see the documentation of the Momentum Optimizer.
Args:
weight_decay: A `Tensor` or a floating point value. The weight decay.
learning_rate: A `Tensor` or a floating point value. The learning rate.
momentum: A `Tensor` or a floating point value. The momentum.
use_locking: If `True` use locks for update operations.
name: Optional name prefix for the operations created when applying
gradients. Defaults to "MomentumW".
use_nesterov: If `True` use Nesterov Momentum.
See [Sutskever et al., 2013](
http://jmlr.org/proceedings/papers/v28/sutskever13.pdf).
This implementation always computes gradients at the value of the
variable(s) passed to the optimizer. Using Nesterov Momentum makes the
variable(s) track the values called `theta_t + mu*v_t` in the paper.
@compatibility(eager)
When eager execution is enabled, learning_rate, weight_decay and momentum
can each be a callable that takes no arguments and returns the actual value
to use. This can be useful for changing these values across different
invocations of optimizer functions.
@end_compatibility
"""
super(MomentumWOptimizer, self).__init__(
weight_decay, learning_rate=learning_rate, momentum=momentum,
use_locking=use_locking, name=name, use_nesterov=use_nesterov)
@tf_export("contrib.opt.AdamWOptimizer")
class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer):
"""Optimizer that implements the Adam algorithm with weight decay.
This is an implementation of the AdamW optimizer described in "Fixing
Weight Decay Regularization in Adam" by Loshchilov & Hutter
(https://arxiv.org/abs/1711.05101)
([pdf](https://arxiv.org/pdf/1711.05101.pdf)).
It computes the update step of `train.AdamOptimizer` and additionally decays
the variable. Note that this is different from adding L2 regularization on
the variables to the loss: it regularizes variables with large
gradients more than L2 regularization would, which was shown to yield better
training loss and generalization error in the paper above.
For further information see the documentation of the Adam Optimizer.
Note that this optimizer can also be instantiated as
```python
extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)(weight_decay=weight_decay)
```
"""
def __init__(self, weight_decay, learning_rate=0.001, beta1=0.9, beta2=0.999,
epsilon=1e-8, use_locking=False, name="AdamW"):
"""Construct a new AdamW optimizer.
For further information see the documentation of the Adam Optimizer.
Args:
weight_decay: A `Tensor` or a floating point value. The weight decay.
learning_rate: A Tensor or a floating point value. The learning rate.
beta1: A float value or a constant float tensor.
The exponential decay rate for the 1st moment estimates.
beta2: A float value or a constant float tensor.
The exponential decay rate for the 2nd moment estimates.
epsilon: A small constant for numerical stability. This epsilon is
"epsilon hat" in the Kingma and Ba paper (in the formula just before
Section 2.1), not the epsilon in Algorithm 1 of the paper.
use_locking: If True use locks for update operations.
name: Optional name for the operations created when applying gradients.
Defaults to "Adam".
"""
super(AdamWOptimizer, self).__init__(
weight_decay, learning_rate=learning_rate, beta1=beta1, beta2=beta2,
epsilon=epsilon, use_locking=use_locking, name=name)
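A minimal usage sketch for the new optimizers (assuming TF 1.x graph mode with `tf.contrib` available; the placeholder, variable, and loss below are illustrative):

```python
import tensorflow as tf

# Toy model: one weight matrix and a simple quadratic loss.
x = tf.placeholder(tf.float32, [None, 10])
w = tf.get_variable('w', shape=[10, 1])
loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))

# AdamW: the decoupled weight decay factor is the first constructor argument.
opt = tf.contrib.opt.AdamWOptimizer(weight_decay=0.01, learning_rate=1e-3)

# Only variables in decay_var_list are decayed; omit it to decay all of them.
train_op = opt.minimize(loss, decay_var_list=[w])
```

The same pattern applies to `MomentumWOptimizer`, and `extend_with_decoupled_weight_decay` can wrap any other `tf.train.Optimizer` subclass in the same way.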

View File

@ -0,0 +1,188 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for optimizers with weight decay."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.opt.python.training import weight_decay_optimizers
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import adam
WEIGHT_DECAY = 0.01
def adamw_update_numpy(param, g_t, t, m, v, lr=0.001, beta1=0.9,
beta2=0.999, epsilon=1e-8):
lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t)
m_t = beta1 * m + (1 - beta1) * g_t
v_t = beta2 * v + (1 - beta2) * g_t * g_t
param_t = (param - lr_t * m_t / (np.sqrt(v_t) + epsilon) -
(param * WEIGHT_DECAY))
return param_t, m_t, v_t
def momentumw_update_numpy(param, g_t, m, lr=0.001, momentum=0.9, **_):
# v, t are not needed for momentum optimizer
m = momentum * m + g_t
param_t = param - lr * m - param * WEIGHT_DECAY
return param_t, m, None
class WeightDecayOptimizerTest(test.TestCase):
def doTest(self, optimizer, update_fn, optimizer_name, slot_name,
use_resource=False, do_sparse=False):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
with self.test_session(graph=ops.Graph()):
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
if use_resource:
var0 = resource_variable_ops.ResourceVariable(
var0_np, name="var0_%d" % i)
var1 = resource_variable_ops.ResourceVariable(
var1_np, name="var1_%d" % i)
else:
var0 = variables.Variable(var0_np)
var1 = variables.Variable(var1_np)
if do_sparse:
grads0_np_indices = np.array([0, 1], dtype=np.int32)
grads0 = ops.IndexedSlices(constant_op.constant(grads0_np),
constant_op.constant(grads0_np_indices),
constant_op.constant([2]))
grads1_np_indices = np.array([0, 1], dtype=np.int32)
grads1 = ops.IndexedSlices(constant_op.constant(grads1_np),
constant_op.constant(grads1_np_indices),
constant_op.constant([2]))
else:
grads0 = constant_op.constant(grads0_np)
grads1 = constant_op.constant(grads1_np)
opt = optimizer()
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
if not context.executing_eagerly():
with ops.Graph().as_default():
# Shouldn't return non-slot variables from other graphs.
self.assertEqual(0, len(opt.variables()))
self.evaluate(variables.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Run 3 steps of the optimizer
for t in range(1, 4):
if not context.executing_eagerly():
self.evaluate(update)
elif t > 1:
opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
var0_np, m0, v0 = update_fn(var0_np, grads0_np, t=t, m=m0, v=v0)
var1_np, m1, v1 = update_fn(var1_np, grads1_np, t=t, m=m1, v=v1)
# Validate updated params
self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
if use_resource:
self.assertEqual("var0_%d/%s:0" % (i, optimizer_name),
opt.get_slot(var=var0, name=slot_name).name)
class AdamWOptimizerTest(WeightDecayOptimizerTest):
@staticmethod
def get_optimizer():
return weight_decay_optimizers.AdamWOptimizer(WEIGHT_DECAY)
def testSparse(self):
self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m",
use_resource=False, do_sparse=True)
def testResourceSparse(self):
self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m",
use_resource=True, do_sparse=True)
def testBasic(self):
self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m",
use_resource=False)
@test_util.run_in_graph_and_eager_modes(reset_test=True)
def testResourceBasic(self):
self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m",
use_resource=True)
class MomentumWOptimizerTest(WeightDecayOptimizerTest):
@staticmethod
def get_optimizer():
return weight_decay_optimizers.MomentumWOptimizer(WEIGHT_DECAY, 0.001, 0.9)
def testSparse(self):
self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW",
"momentum", use_resource=False, do_sparse=True)
def testResourceSparse(self):
self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW",
"momentum", use_resource=True, do_sparse=True)
def testBasic(self):
self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW",
"momentum", use_resource=False)
@test_util.run_in_graph_and_eager_modes(reset_test=True)
def testResourceBasic(self):
self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW",
"momentum", use_resource=True)
class ExtendWithWeightDecayTest(WeightDecayOptimizerTest):
@staticmethod
def get_optimizer():
adamw = weight_decay_optimizers.extend_with_decoupled_weight_decay(
adam.AdamOptimizer)
return adamw(WEIGHT_DECAY)
def testBasic(self):
self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m",
use_resource=False)
@test_util.run_in_graph_and_eager_modes(reset_test=True)
def testResourceBasic(self):
self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m",
use_resource=True)
if __name__ == "__main__":
test.main()
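For reference, the updates that `adamw_update_numpy` and `momentumw_update_numpy` above encode, written out (here $\lambda$ is `WEIGHT_DECAY`, $g_t$ the gradient, $\alpha$ the learning rate; the decay term is applied to the parameter itself rather than folded into the gradient):

$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2, \qquad \alpha_t = \alpha\,\frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$$

$$\text{AdamW:}\quad \theta_t = \theta_{t-1} - \alpha_t\,\frac{m_t}{\sqrt{v_t}+\epsilon} - \lambda\,\theta_{t-1} \qquad\quad \text{MomentumW:}\quad m_t = \mu\,m_{t-1} + g_t,\ \ \theta_t = \theta_{t-1} - \alpha\,m_t - \lambda\,\theta_{t-1}$$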

View File

@ -28,7 +28,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import linalg_ops
def conjugate_gradient(operator, def conjugate_gradient(operator,

View File

@ -49,7 +49,6 @@ tf_cuda_cc_test(
tf_custom_op_library( tf_custom_op_library(
name = "python/ops/_trt_engine_op.so", name = "python/ops/_trt_engine_op.so",
srcs = [ srcs = [
"ops/trt_calib_op.cc",
"ops/trt_engine_op.cc", "ops/trt_engine_op.cc",
], ],
deps = [ deps = [
@ -76,11 +75,9 @@ tf_cuda_library(
cc_library( cc_library(
name = "trt_engine_op_kernel", name = "trt_engine_op_kernel",
srcs = [ srcs = [
"kernels/trt_calib_op.cc",
"kernels/trt_engine_op.cc", "kernels/trt_engine_op.cc",
], ],
hdrs = [ hdrs = [
"kernels/trt_calib_op.h",
"kernels/trt_engine_op.h", "kernels/trt_engine_op.h",
], ],
copts = tf_copts(), copts = tf_copts(),
@ -89,20 +86,22 @@ cc_library(
":trt_logging", ":trt_logging",
":trt_plugins", ":trt_plugins",
":trt_resources", ":trt_resources",
":trt_conversion",
":utils",
"//tensorflow/core:gpu_headers_lib", "//tensorflow/core:gpu_headers_lib",
"//tensorflow/core:lib_proto_parsing", "//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:stream_executor_headers_lib", "//tensorflow/core:stream_executor_headers_lib",
"//tensorflow/core/grappler/costs:graph_properties",
] + if_tensorrt([ ] + if_tensorrt([
"@local_config_tensorrt//:nv_infer", "@local_config_tensorrt//:nv_infer",
]) + tf_custom_op_library_additional_deps(), ]) + tf_custom_op_library_additional_deps(),
# TODO(laigd) # TODO(laigd): fix this by merging header file in cc file.
alwayslink = 1, # buildozer: disable=alwayslink-with-hdrs alwayslink = 1, # buildozer: disable=alwayslink-with-hdrs
) )
tf_gen_op_libs( tf_gen_op_libs(
op_lib_names = [ op_lib_names = [
"trt_engine_op", "trt_engine_op",
"trt_calib_op",
], ],
) )
@ -122,7 +121,6 @@ tf_gen_op_wrapper_py(
name = "trt_engine_op", name = "trt_engine_op",
gen_locally = True, gen_locally = True,
deps = [ deps = [
":trt_calib_op_op_lib",
":trt_engine_op_op_lib", ":trt_engine_op_op_lib",
":trt_logging", ":trt_logging",
":trt_shape_function", ":trt_shape_function",
@ -140,7 +138,6 @@ tf_custom_op_py_library(
kernels = [ kernels = [
":trt_engine_op_kernel", ":trt_engine_op_kernel",
":trt_engine_op_op_lib", ":trt_engine_op_op_lib",
":trt_calib_op_op_lib",
":trt_shape_function", ":trt_shape_function",
], ],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
@ -191,7 +188,6 @@ tf_py_wrap_cc(
deps = [ deps = [
":trt_conversion", ":trt_conversion",
":trt_engine_op_kernel", ":trt_engine_op_kernel",
"//tensorflow/core:framework_lite",
"//third_party/python_runtime:headers", "//third_party/python_runtime:headers",
], ],
) )
@ -211,6 +207,7 @@ tf_cuda_library(
], ],
deps = [ deps = [
":trt_logging", ":trt_logging",
":utils",
"//tensorflow/core:framework_headers_lib", "//tensorflow/core:framework_headers_lib",
"//tensorflow/core:framework_lite", "//tensorflow/core:framework_lite",
"//tensorflow/core:lib_proto_parsing", "//tensorflow/core:lib_proto_parsing",
@ -237,12 +234,12 @@ tf_cuda_library(
":trt_plugins", ":trt_plugins",
":trt_logging", ":trt_logging",
":trt_resources", ":trt_resources",
":utils",
"//tensorflow/core/grappler/clusters:cluster", "//tensorflow/core/grappler/clusters:cluster",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
"//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils", "//tensorflow/core/grappler:utils",
"//tensorflow/core:framework",
"//tensorflow/core:gpu_runtime", "//tensorflow/core:gpu_runtime",
"//tensorflow/core:framework_lite", "//tensorflow/core:framework_lite",
"//tensorflow/core:graph", "//tensorflow/core:graph",
@ -343,3 +340,8 @@ py_test(
"//tensorflow/python:framework_test_lib", "//tensorflow/python:framework_test_lib",
], ],
) )
cc_library(
name = "utils",
hdrs = ["convert/utils.h"],
)

File diff suppressed because it is too large

View File

@ -30,29 +30,60 @@ namespace tensorflow {
namespace tensorrt { namespace tensorrt {
namespace convert { namespace convert {
// This method converts an already generated calibration graph which was used in struct ConversionParams {
// calibration runs to an inference graph ConversionParams()
tensorflow::Status ConvertCalibGraphToInferGraph( : input_graph_def(nullptr),
const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* new_graph_def); max_batch_size(1),
max_workspace_size_bytes(1 << 30),
output_graph_def(nullptr),
precision_mode(1),
minimum_segment_size(3),
graph_properties(nullptr),
cluster(nullptr),
is_dyn_op(false),
fixed_input_size(true),
max_cached_engines(1) {}
const tensorflow::GraphDef* input_graph_def;
const std::vector<string>* output_names;
size_t max_batch_size;
size_t max_workspace_size_bytes;
tensorflow::GraphDef* output_graph_def;
int precision_mode;
int minimum_segment_size;
const tensorflow::grappler::GraphProperties* graph_properties;
const tensorflow::grappler::Cluster* cluster;
bool is_dyn_op; // Whether to create engine on conversion or execution time
bool fixed_input_size; // Assume non-batch ranks of input tensors are fixed
int max_cached_engines; // maximum number of cached engines
std::vector<int> cached_engine_batches; // list of cached engines
};
// max_batch_size: maximum batch size which can be used for inference for // This method extracts calibration information from the resource managers
// optimization targets inference run with max batch size. // and puts them in to engine nodedefs.
// max_workspace_size_bytes: The upper bound of memory allowance for tensorflow::Status ConvertCalibGraphToInferGraph(
// engine building. const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* new_graph_def,
bool is_dyn_op);
// - max_batch_size: maximum batch size which can be used for inference for
// optimization targets inference run with max batch size.
// - max_workspace_size_bytes: The upper bound of memory allowance for engine
// building.
tensorflow::Status ConvertGraphDefToTensorRT( tensorflow::Status ConvertGraphDefToTensorRT(
const tensorflow::GraphDef& graph_def, const tensorflow::GraphDef& graph_def,
const std::vector<string>& output_names, size_t max_batch_size, const std::vector<string>& output_names, size_t max_batch_size,
size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def, size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def,
int precision_mode, int minimum_segment_size); int precision_mode = 1, int minimum_segment_size = 3,
bool is_dyn_op = false, int max_cached_engines = 1,
std::vector<int> cached_engine_batches = {});
// Method to call from optimization pass // Method to call from optimization pass
tensorflow::Status ConvertAfterShapes( tensorflow::Status ConvertAfterShapes(ConversionParams& params);
const tensorflow::GraphDef& graph, const std::vector<string>& output_names,
size_t max_batch_size, size_t max_workspace_size_bytes, // Return compile time TensorRT library version information.
tensorflow::GraphDef* new_graph_def, int precision_mode, std::vector<int> GetLinkedTensorRTVersion();
int minimum_segment_size,
const tensorflow::grappler::GraphProperties& graph_properties, // Return runtime time TensorRT library version information.
const tensorflow::grappler::Cluster* cluster); std::vector<int> GetLoadedTensorRTVersion();
} // namespace convert } // namespace convert
} // namespace tensorrt } // namespace tensorrt
} // namespace tensorflow } // namespace tensorflow
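For orientation, the fields of `ConversionParams` above correspond to the knobs exposed by the Python wrapper in `tensorflow.contrib.tensorrt`. A hedged sketch of driving the conversion from Python (assuming a frozen `GraphDef` named `frozen_graph_def` and an output node named `logits`; only the long-standing parameters are shown, the newer dynamic-op/caching options are omitted):

```python
import tensorflow.contrib.tensorrt as trt

# Rewrite TensorRT-compatible subgraphs of a frozen graph into TRTEngineOps.
trt_graph_def = trt.create_inference_graph(
    input_graph_def=frozen_graph_def,   # assumed: a frozen tf.GraphDef
    outputs=['logits'],                 # illustrative output node name
    max_batch_size=1,
    max_workspace_size_bytes=1 << 30,
    precision_mode='FP32',
    minimum_segment_size=3)
```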

View File

@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" #include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include <algorithm> #include <algorithm>
#include <list> #include <list>
@ -25,7 +24,9 @@ limitations under the License.
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" #include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" #include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
#include "tensorflow/core/framework/node_def.pb.h" // NOLINT #include "tensorflow/core/framework/node_def.pb.h" // NOLINT
@ -37,6 +38,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
@ -54,8 +56,11 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace tensorrt { namespace tensorrt {
namespace convert { namespace convert {
using ::tensorflow::str_util::Split;
using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat; using ::tensorflow::strings::StrCat;
namespace { namespace {
inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype, inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype,
@ -121,12 +126,10 @@ static std::vector<std::pair<int, int>> CreateSamePadding(
string GetCommonNameScope(const string& op_name_a, const string& op_name_b) { string GetCommonNameScope(const string& op_name_a, const string& op_name_b) {
size_t last_scope_separator = 0; size_t last_scope_separator = 0;
for (size_t i = 0; i < std::min(op_name_a.size(), op_name_b.size()); ++i) { const size_t min_size = std::min(op_name_a.size(), op_name_b.size());
if (op_name_a[i] != op_name_b[i]) { for (size_t i = 0; i < min_size; ++i) {
break; if (op_name_a[i] != op_name_b[i]) break;
} else if (op_name_a[i] == '/') { if (op_name_a[i] == '/') last_scope_separator = i + 1;
last_scope_separator = i + 1;
}
} }
return op_name_a.substr(0, last_scope_separator); return op_name_a.substr(0, last_scope_separator);
} }
@ -417,20 +420,6 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
} }
} }
struct InferDeleter {
template <typename T>
void operator()(T* obj) const {
if (obj) {
obj->destroy();
}
}
};
template <typename T>
inline std::shared_ptr<T> infer_object(T* obj) {
return std::shared_ptr<T>(obj, InferDeleter());
}
class Converter; class Converter;
using OpConverter = using OpConverter =
@ -444,7 +433,7 @@ class Converter {
OpConverter plugin_converter_; OpConverter plugin_converter_;
nvinfer1::INetworkDefinition* trt_network_; nvinfer1::INetworkDefinition* trt_network_;
std::list<std::vector<uint8_t>> temp_bufs_; std::list<std::vector<uint8_t>> temp_bufs_;
tensorflow::tensorrt::TRTWeightStore* weight_store_; TRTWeightStore* weight_store_;
bool fp16_; bool fp16_;
void register_op_converters(); void register_op_converters();
tensorflow::Status get_inputs(const tensorflow::NodeDef& node_def, tensorflow::Status get_inputs(const tensorflow::NodeDef& node_def,
@ -486,11 +475,11 @@ class Converter {
public: public:
explicit Converter(nvinfer1::INetworkDefinition* trt_network, explicit Converter(nvinfer1::INetworkDefinition* trt_network,
tensorflow::tensorrt::TRTWeightStore* ws, bool fp16) TRTWeightStore* ws, bool fp16)
: trt_network_(trt_network), weight_store_(ws), fp16_(fp16) { : trt_network_(trt_network), weight_store_(ws), fp16_(fp16) {
this->register_op_converters(); this->register_op_converters();
} }
tensorflow::tensorrt::TRTWeightStore* weight_store() { return weight_store_; } TRTWeightStore* weight_store() { return weight_store_; }
TRT_ShapedWeights get_temp_weights(tensorflow::DataType type, TRT_ShapedWeights get_temp_weights(tensorflow::DataType type,
nvinfer1::Dims shape) { nvinfer1::Dims shape) {
TRT_ShapedWeights weights(type, nullptr, shape); TRT_ShapedWeights weights(type, nullptr, shape);
@ -2140,559 +2129,265 @@ void Converter::register_op_converters() {
} // namespace } // namespace
tensorflow::Status ConvertCalibrationNodeToEngineNode( tensorflow::Status ConvertGraphDefToEngine(
tensorflow::Graph& graph, tensorflow::Node* c_node) { const tensorflow::GraphDef& gdef, int precision_mode, int max_batch_size,
const auto ndef = c_node->def(); size_t max_workspace_size_bytes,
const std::vector<tensorflow::PartialTensorShape>& input_shapes,
Logger* logger, nvinfer1::IGpuAllocator* allocator,
TRTInt8Calibrator* calibrator,
TrtUniquePtrType<nvinfer1::ICudaEngine>* engine,
bool* convert_successfully) {
engine->reset();
if (convert_successfully) *convert_successfully = false;
TFAttrs attrs(ndef); // Create the builder.
std::vector<string> segment_nodes( TrtUniquePtrType<nvinfer1::IBuilder> builder(
attrs.get<std::vector<string>>("segment_nodes")); nvinfer1::createInferBuilder(*logger));
std::vector<string> output_nodes( builder->setMaxBatchSize(max_batch_size);
attrs.get<std::vector<string>>("segment_output_names")); // TODO(aaroey): use the allocator to allocate the TRT workspace.
std::vector<string> input_names( builder->setMaxWorkspaceSize(max_workspace_size_bytes);
attrs.get<std::vector<string>>("input_names"));
string res_name = attrs.get<string>("resource_name");
VLOG(1) << "Node name " << c_node->name() << " res_name " << res_name;
string engine_name = "my_trt_op";
{
const auto node_id = tensorflow::str_util::Split(res_name, "_");
engine_name += node_id.back();
}
std::map<string, tensorflow::Node*> node_maps;
for (auto n : graph.op_nodes()) {
node_maps.insert({n->name(), n});
}
std::set<int> subgraph_ids;
for (const auto internal_node : segment_nodes) {
subgraph_ids.insert(node_maps.at(internal_node)->id());
}
if (VLOG_IS_ON(2)) {
string node_names = StrCat(c_node->name(), " segment nodes= ");
for (const auto& node_name : segment_nodes) {
StrAppend(&node_names, node_name, ", ");
}
VLOG(2) << node_names;
}
VLOG(1) << "Output Nodes:";
std::vector<tensorflow::DataType> out_types;
std::vector<const tensorflow::Edge*> out_edges;
for (auto& i : output_nodes) {
auto node_port = tensorflow::str_util::Split(i, ":");
VLOG(1) << " " << i << " in graph " << node_maps.count(i);
auto out_node_name = node_port.at(0);
if (node_port.size() > 1) {
VLOG(1) << "Multi port output" << node_port.at(0) << " "
<< node_port.at(1) << " size=" << node_port.size();
}
auto node_it = node_maps.find(out_node_name);
if (node_it != node_maps.end()) {
tensorflow::Node* out_node = node_it->second;
int port = 0;
if (node_port.size() == 2) {
port = std::strtoul(node_port.at(1).c_str(), nullptr, 10);
out_types.push_back(out_node->output_type(port));
} else {
out_types.push_back(out_node->output_type(0));
}
for (auto out_edge : out_node->out_edges()) {
if (subgraph_ids.count(out_edge->dst()->id()))
continue; // skip internal edges;
if (out_edge->src_output() == port) {
out_edges.push_back(out_edge);
VLOG(1) << "OUTPUT EDGE " << out_edge->src()->name() << ":"
<< out_edge->src_output() << " -> " << out_edge->dst()->name()
<< ":" << out_edge->dst_input();
}
}
} else {
LOG(WARNING) << " couldn't find output node " << out_node_name;
}
}
if (VLOG_IS_ON(1)) {
VLOG(1) << c_node->name() << " Input Nodes:";
for (auto& i : input_names) {
VLOG(1) << " Input " << i << " in graph " << node_maps.count(i);
}
}
auto trt_rm = tensorflow::tensorrt::TRTResourceManager::instance();
auto resmgr = trt_rm->getManager("TRTCalibOps");
tensorflow::tensorrt::TRTCalibrationResource* calib_res = nullptr;
auto status = resmgr->Lookup(res_name, res_name, &calib_res);
if (!status.ok() || !calib_res->calibrator_) {
return tensorflow::errors::FailedPrecondition(
"You must run calibration"
" and inference conversion in the same process");
}
calib_res->calibrator_->setDone();
calib_res->thr_->join();
delete calib_res->thr_;
if (!calib_res->engine_) {
LOG(ERROR) << "Calibration failed!, engine does not exist. Did you run "
"calibration graph?";
return tensorflow::errors::FailedPrecondition(
"Calibration graph needs to be executed on"
" calibration data before convertsion to inference graph");
}
auto weight_rmgr = trt_rm->getManager("WeightStore");
TF_CHECK_OK(weight_rmgr->Delete<tensorflow::tensorrt::TRTWeightStore>(
res_name, res_name));
auto engine_plan = calib_res->engine_->serialize();
calib_res->engine_->destroy();
calib_res->network_->destroy();
calib_res->builder_->destroy();
calib_res->thr_ = nullptr;
calib_res->engine_ = nullptr;
calib_res->builder_ = nullptr;
tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp");
std::vector<tensorflow::NodeDefBuilder::NodeOut> income_edges;
income_edges.resize(c_node->num_inputs());
for (const auto in_edge : c_node->in_edges()) {
auto src = in_edge->src();
int dest_port = in_edge->dst_input();
VLOG(1) << "Incoming connection " << src->name() << ":"
<< in_edge->src_output() << " -> " << c_node->name() << ":"
<< dest_port;
income_edges.at(dest_port) = {src->name(), in_edge->src_output(),
c_node->input_type(dest_port)};
}
tensorflow::gtl::ArraySlice<tensorflow::NodeDefBuilder::NodeOut> input_list(
income_edges);
if (VLOG_IS_ON(2)) {
for (const auto& inp : input_list) {
VLOG(2) << " Input from inputlist " << inp.node << ":" << inp.index << " "
<< tensorflow::DataTypeString(inp.data_type);
}
}
op_builder.Input(input_list);
tensorflow::NodeDef engine_node;
const char* engine_plan_data = static_cast<const char*>(engine_plan->data());
string engine_plan_string(engine_plan_data,
engine_plan_data + engine_plan->size());
status = op_builder.Attr("serialized_engine", engine_plan_string)
.Attr("input_nodes", input_names)
.Attr("output_nodes", output_nodes)
.Attr("OutT", out_types)
.Finalize(&engine_node);
if (!status.ok()) {
LOG(ERROR) << "Engine Node creation failed";
return status;
}
auto trt_engine_node = graph.AddNode(engine_node, &status);
TF_RETURN_IF_ERROR(status);
std::map<string, int> port_map;
for (size_t t = 0; t < output_nodes.size(); t++) {
port_map.insert({output_nodes.at(t), t});
}
for (auto& i : out_edges) {
string s(i->src()->name());
if (i->src_output()) StrAppend(&s, ":", i->src_output());
int out_port = port_map.at(s);
VLOG(1) << "Connecting " << trt_engine_node->name() << ":" << out_port
<< " -> " << i->dst()->name() << ":" << i->dst_input();
TF_RETURN_IF_ERROR(
graph.UpdateEdge(trt_engine_node, out_port, i->dst(), i->dst_input()));
}
for (const auto ed : trt_engine_node->in_edges()) {
VLOG(1) << "In Edge " << ed->src()->name() << ":" << ed->src_output()
<< " -> " << ed->dst()->name() << ":" << ed->dst_input();
}
for (const auto ed : trt_engine_node->out_edges()) {
VLOG(1) << "Out Edge " << ed->src()->name() << ":" << ed->src_output()
<< " -> " << ed->dst()->name() << ":" << ed->dst_input();
}
VLOG(1) << "Segment nodes:";
for (auto& i : segment_nodes) {
VLOG(1) << " " << i << " in graph " << node_maps.count(i);
auto it = node_maps.find(i);
if (it != node_maps.end()) {
graph.RemoveNode(it->second);
}
}
graph.RemoveNode(c_node);
return tensorflow::Status::OK();
}
tensorflow::Status ReverseTopologicalSort(
const tensorrt::convert::SubGraphParams& s,
std::list<tensorflow::Node*>* order) {
std::vector<tensorflow::Node*> order_vec;
tensorflow::GetPostOrder(s.graph, &order_vec);
// Select just the subgraph
for (tensorflow::Node* node : order_vec) {
if (s.subgraph_node_ids.count(node->id())) {
      // We want topological order to construct the
      // network layer by layer
order->push_front(node);
}
}
return tensorflow::Status::OK();
}
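// A small worked example of the ordering above (hypothetical three-node
// chain A -> B -> C, all inside the subgraph): GetPostOrder emits consumers
// before producers, e.g. {C, B, A}; pushing each node onto the front of
// `order` therefore yields {A, B, C}, so every node appears after all of its
// inputs when the network is built layer by layer.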
tensorflow::Status SetInputList(
const tensorrt::convert::SubGraphParams& s,
tensorflow::NodeDefBuilder* op_builder,
const std::vector<string>* input_names,
std::vector<tensorflow::DataType>* input_dtypes) {
std::vector<tensorflow::NodeDefBuilder::NodeOut> income_edges;
VLOG(2) << "input edge size: " << input_names->size();
for (size_t i = 0; i < input_names->size(); ++i) {
VLOG(2) << "input edges: " << i << " " << input_names->at(i);
int output_idx = s.input_inds.at(i).second;
    // The input is already wired up here; it is redundant to do it again in
    // ConvertSubGraphToTensorRT (convert_graph.cc)
auto incoming_edge = tensorflow::NodeDefBuilder::NodeOut(
input_names->at(i), output_idx, input_dtypes->at(i));
income_edges.push_back(incoming_edge);
}
tensorflow::gtl::ArraySlice<tensorflow::NodeDefBuilder::NodeOut> input_list(
income_edges);
op_builder->Input(input_list);
return tensorflow::Status::OK();
}
string SubgraphNameScopeGenerator(const std::list<tensorflow::Node*>* order) {
string subgraph_name_scope;
if (!order->empty()) {
subgraph_name_scope = order->front()->name();
}
for (const tensorflow::Node* node : *order) {
subgraph_name_scope = GetCommonNameScope(subgraph_name_scope, node->name());
}
// TODO(sami,ben,jie): proper naming!
return subgraph_name_scope;
}
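// Illustrative behavior, assuming GetCommonNameScope returns the longest
// name-scope prefix shared by its two arguments: nodes named
// "model/block1/conv", "model/block1/bias" and "model/block2/relu" would
// reduce to the scope "model/", which then prefixes the my_trt_calib_op_* /
// my_trt_op* names generated by the callers below.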
tensorflow::Status ConvertSubgraph(
Converter& converter, tensorrt::convert::SubGraphParams& s,
std::list<tensorflow::Node*>* order, std::vector<string>* input_names,
std::vector<tensorflow::DataType>* input_dtypes,
std::vector<string>* output_names,
std::vector<tensorflow::DataType>* output_dtypes,
const string& engine_name) {
std::set<string> added_tensors;
for (const std::pair<int, int>& input : s.input_inds) {
VLOG(2) << "parsing input. Node id= " << input.first;
int node_id = input.first;
int output_idx = input.second;
tensorflow::Node* node = s.graph.FindNodeId(node_id);
auto node_name = node->name();
// input_names should use the node name in the graph
// here it should be the input tensor name -> matching the binding
// insert original node name without port
auto tensor_name = node_name;
if (output_idx != 0) {
tensor_name = StrCat(tensor_name, ":", output_idx);
}
VLOG(2) << "input name: " << node_name << " tensor_name: " << tensor_name
<< " idx: " << output_idx;
auto shape_inference_node_name = node_name;
auto shape_inference_output_idx = output_idx;
// rewire the shape inference to original node in the graph
if (s.output_edge_map->count(tensor_name)) {
shape_inference_node_name = s.output_edge_map->at(tensor_name).second;
shape_inference_output_idx = s.output_edge_map->at(tensor_name).first;
}
if (shape_inference_output_idx < 0) continue;
VLOG(2) << "shapeinference name: " << shape_inference_node_name
<< " idx: " << shape_inference_output_idx;
if (!s.graph_properties.HasOutputProperties(shape_inference_node_name))
return tensorflow::errors::Internal("failed to find input node: " +
shape_inference_node_name);
auto op_info_vec =
s.graph_properties.GetOutputProperties(shape_inference_node_name);
if (static_cast<int>(op_info_vec.size()) <= shape_inference_output_idx)
return tensorflow::errors::Internal(
"accessing output index of: ", shape_inference_output_idx,
", at node: ", shape_inference_node_name,
" with output entry from shape_map: ", op_info_vec.size());
auto op_info = op_info_vec.at(shape_inference_output_idx);
tensorflow::DataType tf_dtype = op_info.dtype();
nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT);
auto type_status = ConvertDType(tf_dtype, &dtype);
if (type_status != tensorflow::Status::OK()) {
LOG(WARNING) << "Type conversion failed for " << node_name;
return type_status;
}
VLOG(2) << "Accessing output index of: " << output_idx
<< ", at node: " << node_name
<< " with output entry from shape_map: " << op_info_vec.size();
// TODO(ben,jie): update TRT input format/dimension
nvinfer1::DimsCHW input_dim_pseudo_chw;
for (int i = 0; i < 3; i++) input_dim_pseudo_chw.d[i] = 1;
    // TODO(jie): TRT 3.x only supports 4-dimensional input tensors.
    // Update the code once TRT 4.0 comes out.
if (op_info.shape().dim_size() != 4) {
string err_str = "Require 4 dimensional input.";
StrAppend(&err_str, " Got ", op_info.shape().dim_size(), " ",
shape_inference_node_name);
return tensorflow::errors::Unimplemented(err_str);
}
for (int i = 1; i < op_info.shape().dim_size(); i++) {
VLOG(2) << "dimension: " << i
<< " , size: " << op_info.shape().dim(i).size();
input_dim_pseudo_chw.d[i - 1] = op_info.shape().dim(i).size();
}
// TODO(ben,jie): proper way to restore input tensor name?
auto input_tensor_name = node_name;
if (output_idx != 0) {
input_tensor_name = StrCat(node_name, ":", output_idx);
}
if (added_tensors.count(input_tensor_name)) continue;
added_tensors.insert(input_tensor_name);
input_names->push_back(input_tensor_name);
input_dtypes->push_back(tf_dtype);
nvinfer1::ITensor* input_tensor = converter.network()->addInput(
input_tensor_name.c_str(), dtype, input_dim_pseudo_chw);
if (!input_tensor)
return tensorflow::errors::InvalidArgument(
"Failed to create Input layer");
VLOG(2) << "Input tensor name :" << input_tensor_name;
if (!converter.insert_input_tensor(input_tensor_name, input_tensor))
return tensorflow::errors::AlreadyExists(
"Output tensor already exists for op: " + input_tensor_name);
}
for (const tensorflow::Node* node : *order) {
const tensorflow::NodeDef& node_def = node->def();
VLOG(2) << "Converting node: " << node_def.name() << " , " << node_def.op();
TF_RETURN_IF_ERROR(converter.convert_node(node_def));
}
VLOG(2) << "Finished conversion";
// Gather output metadata
int trt_engine_op_output_idx = 0;
added_tensors.clear();
for (const std::pair<int, int>& output : s.output_inds) {
int node_id = output.first;
int output_idx = output.second;
tensorflow::Node* node = s.graph.FindNodeId(node_id);
string op_name = node->name();
string tensor_name = op_name;
s.output_edge_map->insert(
{trt_engine_op_output_idx == 0
? engine_name
: StrCat(engine_name, ":", trt_engine_op_output_idx),
{output_idx, tensor_name}});
trt_engine_op_output_idx++;
if (output_idx != 0)
tensorflow::strings::StrAppend(&tensor_name, ":", output_idx);
VLOG(2) << "Output tensor name: " << tensor_name;
if (added_tensors.count(tensor_name)) continue;
added_tensors.insert(tensor_name);
output_names->push_back(tensor_name);
auto tensor_or_weights = converter.get_tensor(tensor_name);
if (!tensor_or_weights.is_tensor()) {
return tensorflow::errors::InvalidArgument("Output node '" + tensor_name +
"' is weights not tensor");
}
nvinfer1::ITensor* tensor = tensor_or_weights.tensor();
if (!tensor) {
return tensorflow::errors::NotFound("Output tensor not found: " +
tensor_name);
}
converter.network()->markOutput(*tensor);
tensorflow::DataType tf_dtype = node->output_type(output_idx);
output_dtypes->push_back(tf_dtype);
nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT;
TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype));
tensor->setType(trt_dtype);
}
return tensorflow::Status::OK();
}
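// Shape handling in the input loop above, with illustrative values: a TF
// input reported as [8, 28, 28, 3] passes the dim_size() == 4 check, the
// batch dimension (8) is dropped, and dims 1..3 fill input_dim_pseudo_chw,
// so the corresponding TRT input tensor is declared with dimensions
// {28, 28, 3}; the batch size is handled separately by the engine.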
tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {
// Visit nodes in reverse topological order and construct the TRT network.
// Toposort
std::list<tensorflow::Node*> order;
TF_RETURN_IF_ERROR(ReverseTopologicalSort(s, &order));
static int static_id = 0;
string subgraph_name_scope = SubgraphNameScopeGenerator(&order);
// TODO(sami,ben,jie): proper naming!
string calib_op_name =
StrCat(subgraph_name_scope, "my_trt_calib_op_", static_id);
string engine_name = StrCat(subgraph_name_scope, "my_trt_op", static_id);
static_id++;
auto trt_rmgr = tensorflow::tensorrt::TRTResourceManager::instance();
auto op_rmgr = trt_rmgr->getManager("TRTCalibOps");
auto op_res = new tensorflow::tensorrt::TRTCalibrationResource();
TF_CHECK_OK(op_rmgr->Create(calib_op_name, calib_op_name, op_res));
op_res->logger_ = new tensorflow::tensorrt::Logger();
cudaSetDevice(s.cuda_gpu_id_);
op_res->builder_ = nvinfer1::createInferBuilder(*(op_res->logger_));
op_res->allocator_ = s.allocator_;
#if NV_TENSORRT_MAJOR > 3
  op_res->builder_->setGpuAllocator(s.allocator_.get());
#endif
  if (!op_res->builder_) {
    return tensorflow::errors::Internal(
        "failed to create TensorRT builder object");
  }
  op_res->network_ = op_res->builder_->createNetwork();
  if (!op_res->network_) {
    return tensorflow::errors::Internal(
        "failed to create TensorRT network object");
  }
#if NV_TENSORRT_MAJOR > 3
  builder->setGpuAllocator(allocator);
#endif
  if (precision_mode == FP16MODE) {
    builder->setHalf2Mode(true);
  } else if (precision_mode == INT8MODE) {
    builder->setInt8Mode(true);
    builder->setInt8Calibrator(calibrator);
  }
  // Create the network.
  auto trt_network =
      TrtUniquePtrType<nvinfer1::INetworkDefinition>(builder->createNetwork());
// Build the network
auto weight_rmgr = trt_rmgr->getManager("WeightStore");
auto ws = new tensorflow::tensorrt::TRTWeightStore();
TF_CHECK_OK(weight_rmgr->Create(calib_op_name, calib_op_name, ws));
Converter converter(op_res->network_, ws, s.precision_mode == FP16MODE);
std::vector<string> input_names;
std::vector<tensorflow::DataType> input_dtypes;
std::vector<string> output_names;
std::vector<tensorflow::DataType> output_dtypes;
TF_RETURN_IF_ERROR(ConvertSubgraph(converter, s, &order, &input_names,
&input_dtypes, &output_names,
&output_dtypes, engine_name));
VLOG(2) << "Finished processing outputs";
// Build the engine
op_res->builder_->setMaxBatchSize(s.max_batch_size);
op_res->builder_->setMaxWorkspaceSize(s.max_workspace_size_bytes);
VLOG(0) << "Max batch size= " << s.max_batch_size
<< " max workspace size= " << s.max_workspace_size_bytes;
// Build the TRT op
// TODO(sami,ben,jie): proper naming!
tensorflow::NodeDefBuilder op_builder(calib_op_name, "TRTCalibOp");
TF_RETURN_IF_ERROR(SetInputList(s, &op_builder, &input_names, &input_dtypes));
std::vector<string> segment_names;
segment_names.reserve(s.subgraph_node_ids.size());
for (int i : s.subgraph_node_ids) {
auto node = s.graph.FindNodeId(i);
segment_names.push_back(node->name());
}
LOG(INFO) << "finished op preparation";
auto status = op_builder.Attr("segment_nodes", segment_names)
.Attr("input_names", input_names)
.Attr("segment_output_names", output_names)
.Attr("resource_name", calib_op_name)
.Finalize(s.trt_node);
LOG(INFO) << status.ToString();
LOG(INFO) << "finished op building";
return tensorflow::Status::OK();
}
tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
tensorrt::convert::SubGraphParams& s) {
// Visit nodes in reverse topological order and construct the TRT network.
std::list<tensorflow::Node*> order;
TF_RETURN_IF_ERROR(ReverseTopologicalSort(s, &order));
static int static_id = 0;
string subgraph_name_scope = SubgraphNameScopeGenerator(&order);
string engine_name = StrCat(subgraph_name_scope, "my_trt_op", static_id++);
tensorflow::tensorrt::Logger trt_logger;
cudaSetDevice(s.cuda_gpu_id_);
auto trt_builder = infer_object(nvinfer1::createInferBuilder(trt_logger));
if (!trt_builder) {
return tensorflow::errors::Internal(
"Failed to create TensorRT builder object");
}
#if NV_TENSORRT_MAJOR > 3
trt_builder->setGpuAllocator(s.allocator_.get());
#endif
auto trt_network = infer_object(trt_builder->createNetwork());
  if (!trt_network) {
    return tensorflow::errors::Internal(
        "Failed to create TensorRT network object");
  }
auto ws = std::unique_ptr<TRTWeightStore>(new TRTWeightStore());
auto trt_rmgr = tensorflow::tensorrt::TRTResourceManager::instance();
auto weight_rmgr = trt_rmgr->getManager("WeightStore");
auto ws = new tensorflow::tensorrt::TRTWeightStore();
TF_CHECK_OK(weight_rmgr->Create(engine_name, engine_name, ws));
  // Build the network
  Converter converter(trt_network.get(), ws, s.precision_mode == FP16MODE);
  VLOG(1) << "Starting engine conversion ";
  Converter converter(trt_network.get(), ws.get(), precision_mode == FP16MODE);
std::vector<std::pair<string, string>> output_tensors;
// Graph nodes are already topologically sorted during construction
for (const auto& node_def : gdef.node()) {
string node_name = node_def.name();
VLOG(1) << "Converting op name=" << node_name << ", op=" << node_def.op();
if (tensorflow::str_util::StartsWith(node_name, kInputPHName) &&
(node_def.op() == "Placeholder")) {
nvinfer1::DimsCHW input_dim_pseudo_chw;
for (int i = 0; i < 8; i++) input_dim_pseudo_chw.d[i] = 0;
nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT);
auto type_status =
ConvertDType(node_def.attr().at("dtype").type(), &dtype);
if (type_status != tensorflow::Status::OK()) {
LOG(WARNING) << "Type conversion failed for " << node_name;
return type_status;
}
int32 slot_number = -1;
if (!tensorflow::strings::safe_strto32(node_name.c_str() + 8,
&slot_number)) {
LOG(ERROR) << "Failed to parse slot number from " << node_name
<< " +8= " << node_name.c_str() + 8;
}
auto shape = input_shapes.at(slot_number);
if (shape.dims() > 8) {
LOG(ERROR) << "Tensor rank is greater than 8 for " << node_name
<< " at input slot " << slot_number;
return tensorflow::errors::OutOfRange(
"Input tensor rank is greater than 8");
}
if (VLOG_IS_ON(1)) {
string dim_str("dims=");
StrAppend(&dim_str, "[ ", shape.dim_size(0));
for (int i = 1; i < shape.dims(); i++) {
StrAppend(&dim_str, ", ", shape.dim_size(i));
}
StrAppend(&dim_str, " ]");
VLOG(1) << dim_str;
}
for (int i = 1; i < shape.dims(); i++) {
input_dim_pseudo_chw.d[i - 1] = shape.dim_size(i);
}
std::vector<string> input_names; input_dim_pseudo_chw.nbDims = shape.dims() - 1;
std::vector<tensorflow::DataType> input_dtypes; nvinfer1::ITensor* input_tensor = converter.network()->addInput(
std::vector<string> output_names; node_name.c_str(), dtype, input_dim_pseudo_chw);
std::vector<tensorflow::DataType> output_dtypes; if (!input_tensor) {
TF_RETURN_IF_ERROR(ConvertSubgraph(converter, s, &order, &input_names, return tensorflow::errors::InvalidArgument(
&input_dtypes, &output_names, "Failed to create Input layer tensor ", node_name,
&output_dtypes, engine_name)); " rank=", shape.dims() - 1);
}
VLOG(2) << "Finished output"; VLOG(1) << "Input tensor name :" << node_name;
if (!converter.insert_input_tensor(node_name, input_tensor)) {
// Build the engine return tensorflow::errors::AlreadyExists(
trt_builder->setMaxBatchSize(s.max_batch_size); "Output tensor already exists for op: " + node_name);
trt_builder->setMaxWorkspaceSize(s.max_workspace_size_bytes); }
VLOG(0) << "Max batch size= " << s.max_batch_size } else if (tensorflow::str_util::StartsWith(node_name, kOutputPHName) &&
<< " max workspace size= " << s.max_workspace_size_bytes; (node_def.op() == "Identity")) {
if (s.precision_mode == FP16MODE) { int32 slot_number = -1;
trt_builder->setHalf2Mode(true); if (!tensorflow::strings::safe_strto32(node_name.c_str() + 9,
VLOG(0) << "Using FP16 precision mode"; &slot_number)) {
} LOG(ERROR) << "Failed to parse slot number from " << node_name
LOG(INFO) << "starting build engine"; << " +9=" << node_name.c_str() + 9;
string engine_plan_string; }
{ if (output_tensors.size() <= slot_number) {
auto trt_engine = output_tensors.resize(slot_number + 1);
infer_object(trt_builder->buildCudaEngine(*converter.network())); }
VLOG(0) << "Built network"; output_tensors.at(slot_number) = {node_def.input(0), node_name};
if (trt_engine.get() == nullptr) { } else {
return tensorflow::errors::Internal("Engine building failure"); VLOG(2) << "Converting node: " << node_def.name() << " , "
<< node_def.op();
TF_RETURN_IF_ERROR(converter.convert_node(node_def));
} }
auto engine_plan = infer_object(trt_engine->serialize());
VLOG(0) << "Serialized engine";
const char* engine_plan_data =
static_cast<const char*>(engine_plan->data());
engine_plan_string =
string(engine_plan_data, engine_plan_data + engine_plan->size());
} }
TF_RETURN_IF_ERROR(weight_rmgr->Delete<tensorflow::tensorrt::TRTWeightStore>( for (const auto& output : output_tensors) {
engine_name, engine_name)); auto tensor_or_weights = converter.get_tensor(output.first);
LOG(INFO) << "finished engine " << engine_name << " containing " if (!tensor_or_weights.is_tensor()) {
<< s.subgraph_node_ids.size() << " nodes"; return tensorflow::errors::InvalidArgument(
"Output node '" + output.first + "' is weights not tensor");
}
nvinfer1::ITensor* tensor = tensor_or_weights.tensor();
tensor->setName(output.second.c_str());
if (!tensor) {
return tensorflow::errors::NotFound("Output tensor not found: " +
output.first);
}
VLOG(1) << "Marking output tensor " << output.first << ", as output tensor "
<< output.second;
// Build the TRT op converter.network()->markOutput(*tensor);
tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp"); }
TF_RETURN_IF_ERROR(SetInputList(s, &op_builder, &input_names, &input_dtypes)); if (convert_successfully) *convert_successfully = true;
VLOG(0) << "Finished op preparation"; // Build the engine.
VLOG(1) << "Starting engine creation";
engine->reset(builder->buildCudaEngine(*converter.network()));
if (engine->get() == nullptr) {
return tensorflow::errors::Internal("Failed to build TensorRT engine");
}
VLOG(1) << "Finished conversion";
return tensorflow::Status::OK();
}
auto status = op_builder.Attr("serialized_engine", engine_plan_string) tensorflow::Status ConvertSegmentToGraphDef(
.Attr("input_nodes", input_names) const tensorflow::Graph* graph,
.Attr("output_nodes", output_names) const tensorflow::grappler::GraphProperties& graph_properties,
.Attr("OutT", output_dtypes) const std::vector<int>& subgraph_node_ids, // In topological order
.Device(s.device_name_) std::vector<EngineConnection>* connections,
.Finalize(s.trt_node); tensorflow::GraphDef* segment_def, string* common_scope) {
std::set<string> marker_nodes;
// Update connection shapes/data types and add corresponding input/output
// nodes in the segment graphdef.
for (size_t i = 0; i < connections->size(); ++i) {
auto& connection = connections->at(i);
auto outside_node = graph->FindNodeId(connection.outside_id);
if (!outside_node) {
// This should never happen, unless the original graph is problematic.
return tensorflow::errors::NotFound(
"Cannot find node with id ", connection.outside_id, " in the graph.");
}
// Updates the shape and data types of input/output connections.
tensorflow::DataType input_type = tensorflow::DT_FLOAT;
tensorflow::PartialTensorShape partial_shape;
if (connection.is_input_edge) {
if (graph_properties.HasOutputProperties(connection.outside_node_name)) {
auto output_params =
graph_properties.GetOutputProperties(connection.outside_node_name);
auto out_shape = output_params.at(connection.outside_port);
input_type = out_shape.dtype();
std::vector<tensorflow::int64> dims;
partial_shape = out_shape.shape();
connection.outside_shape = partial_shape;
} else {
VLOG(0) << "Unknown output shape" << outside_node->name();
input_type = graph->FindNodeId(connection.outside_id)
->output_type(connection.outside_port);
}
connection.connection_type = input_type;
  VLOG(0) << status.ToString() << " finished op building for " << engine_name
          << " on device " << s.device_name_;
    } else {  // output edge
      if (graph_properties.HasInputProperties(connection.outside_node_name)) {
auto input_params =
graph_properties.GetInputProperties(connection.outside_node_name);
auto in_shape = input_params.at(connection.outside_port);
input_type = in_shape.dtype();
partial_shape = in_shape.shape();
connection.inside_shape = partial_shape;
} else {
input_type = graph->FindNodeId(connection.inside_id)
->output_type(connection.outside_port);
}
connection.connection_type = input_type;
}
// Add dummy input/output nodes to the segment graphdef.
if (connection.is_input_edge) {
const string node_name = StrCat(kInputPHName, connection.port_number);
if (marker_nodes.count(node_name)) {
VLOG(1) << "Reusing input " << node_name << " for the edge "
<< connection.outside_node_name << ":"
<< connection.outside_port << " -> "
<< connection.inside_node_name << ":" << connection.inside_port;
continue;
}
marker_nodes.insert(node_name);
auto seg_node = segment_def->add_node();
tensorflow::NodeDefBuilder builder(node_name, "Placeholder");
auto status = builder.Attr("shape", partial_shape)
.Attr("dtype", input_type)
.Finalize(seg_node);
VLOG(1) << "Constructing input " << node_name << " for the edge "
<< connection.outside_node_name << ":" << connection.outside_port
<< " -> " << connection.inside_node_name << ":"
<< connection.inside_port;
} else {
const string node_name = StrCat(kOutputPHName, connection.port_number);
if (marker_nodes.count(node_name)) {
VLOG(1) << "Reusing output " << node_name << " for the edge "
<< connection.inside_node_name << ":" << connection.inside_port
<< " -> " << connection.outside_node_name << ":"
<< connection.outside_port;
continue;
}
marker_nodes.insert(node_name);
auto seg_node = segment_def->add_node();
tensorflow::NodeDefBuilder builder(node_name, "Identity");
auto status = builder.Input(connection.inside_node_name, 0, input_type)
.Finalize(seg_node);
VLOG(1) << "Constructing output " << node_name << " for the edge "
<< connection.inside_node_name << ":" << connection.inside_port
<< " -> " << connection.outside_node_name << ":"
<< connection.outside_port;
}
} // for each connection.
std::unordered_map<int, int> old_to_new_id_map;
// Copy internal nodes to new graphdef
string local_scope = graph->FindNodeId(*subgraph_node_ids.begin())->name();
for (const auto node_id : subgraph_node_ids) {
const auto node = graph->FindNodeId(node_id);
local_scope = GetCommonNameScope(local_scope, node->name());
old_to_new_id_map[node_id] = segment_def->node_size();
auto snode = segment_def->add_node();
snode->CopyFrom(node->def());
VLOG(1) << "Copying " << snode->name() << " to subgraph";
}
// Update the inputs of the new input nodes to point to placeholder nodes.
for (int i = 0; i < connections->size(); ++i) {
auto& connection = connections->at(i);
if (!connection.is_input_edge) continue;
auto snode =
segment_def->mutable_node(old_to_new_id_map[connection.inside_id]);
const string placeholder_name =
StrCat(kInputPHName, connection.port_number);
VLOG(1) << "Updating " << snode->name() << ":" << connection.inside_port
<< " from " << snode->input(connection.inside_port) << " to "
<< placeholder_name;
snode->set_input(connection.inside_port, placeholder_name);
}
*common_scope = local_scope;
VLOG(0) << "Segment @scope '" << local_scope << "', converted to graph";
  return tensorflow::Status::OK();
}


@ -22,69 +22,112 @@ limitations under the License.
#include <utility>
#include <vector>
#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h"
#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/lib/core/status.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
namespace tensorflow {
namespace tensorrt {
static const char* kInputPHName = "InputPH_";
static const char* kOutputPHName = "OutputPH_";
namespace convert {
// TODO(aaroey): use an enum instead.
const int FP32MODE = 0;
const int FP16MODE = 1;
const int INT8MODE = 2;
struct SubGraphParams { struct EngineConnection {
SubGraphParams( EngineConnection(const string& outside, int out_id, int out_port,
tensorflow::Graph& inp_graph, const string& inside, int in_id, int in_port,
const std::set<int>& subgraph_node_id_numbers, bool input_edge, int port)
const std::vector<std::pair<int, int>>& input_indices, : outside_node_name(outside),
const std::vector<std::pair<int, int>>& output_indices, outside_id(out_id),
size_t max_supported_batch_size, size_t max_consumed_workspace_size_bytes, outside_port(out_port),
const tensorflow::grappler::GraphProperties& current_graph_properties, inside_node_name(inside),
std::unordered_map<string, std::pair<int, string>>* output_edges, inside_id(in_id),
tensorflow::NodeDef* constructed_trt_node, inside_port(in_port),
int engine_precision_mode = FP32MODE, const string& device_name = "", is_input_edge(input_edge),
std::shared_ptr<nvinfer1::IGpuAllocator> allocator = nullptr, port_number(port) {}
int cuda_gpu_id = 0)
: graph(inp_graph),
subgraph_node_ids(subgraph_node_id_numbers),
input_inds(input_indices),
output_inds(output_indices),
max_batch_size(max_supported_batch_size),
max_workspace_size_bytes(max_consumed_workspace_size_bytes),
graph_properties(current_graph_properties),
output_edge_map(output_edges),
trt_node(constructed_trt_node),
precision_mode(engine_precision_mode),
device_name_(device_name),
allocator_(allocator),
cuda_gpu_id_(cuda_gpu_id) {}
tensorflow::Graph& graph; const string outside_node_name;
const std::set<int>& subgraph_node_ids; const int outside_id;
const std::vector<std::pair<int, int>>& input_inds; // {node_id, output_idx} const int outside_port;
const std::vector<std::pair<int, int>>& output_inds; // {node_id, output_idx} tensorflow::PartialTensorShape outside_shape;
size_t max_batch_size;
size_t max_workspace_size_bytes; const string inside_node_name;
const tensorflow::grappler::GraphProperties& graph_properties; const int inside_id;
std::unordered_map<string, std::pair<int, string>>* output_edge_map; const int inside_port;
tensorflow::NodeDef* trt_node; tensorflow::PartialTensorShape inside_shape;
const int precision_mode;
const string device_name_; tensorflow::DataType connection_type;
std::shared_ptr<nvinfer1::IGpuAllocator> allocator_; bool is_input_edge;
const int cuda_gpu_id_;
// The port number of the TRT node connecting to this edge.
int port_number;
}; };
// TODO(sami): Replace references with const reference or pointers struct EngineInfo {
tensorflow::Status ConvertSubGraphToTensorRTNodeDef(SubGraphParams& params); EngineInfo()
tensorflow::Status InjectCalibrationNode(SubGraphParams& params); : engine_type(EngineType::TRTStatic),
tensorflow::Status ConvertCalibrationNodeToEngineNode(tensorflow::Graph& graph, max_workspace_size_bytes(0),
tensorflow::Node* c_node); precision_mode(FP32MODE) {}
string engine_name;
string device;
tensorflow::GraphDef segment_graph_def;
// The segment nodes that are on one side of the edges are topological sorted.
std::vector<EngineConnection> connections;
enum class EngineType { TRTStatic = 0, TRTDynamic = 1 };
EngineType engine_type;
int64 max_workspace_size_bytes;
int maximum_cached_engines;
std::vector<int> cached_engine_batches;
int precision_mode;
};
// Constructs a graphdef from the segment in the given graph. Adds placeholder
// nodes for input edges (InputPH_*) and identity nodes for output edges
// (OutputPH_*). This function needs to be called before TensorRT nodes are
// inserted, in order to correctly get sizes from the original graph.
//
// - subgraph_node_ids: the node ids of the subgraph, must be sorted in
// topological order.
// - segment_def: the output GraphDef, whose non-input/output nodedefs will be
// sorted in topological order.
tensorflow::Status ConvertSegmentToGraphDef(
const tensorflow::Graph* graph,
const tensorflow::grappler::GraphProperties& graph_properties,
const std::vector<int>& subgraph_node_ids,
std::vector<EngineConnection>* connections,
tensorflow::GraphDef* segment_def, string* common_scope);
// Converts the given subgraph to a TRT engine saved in 'engine'. Returns ok
// iff 'builder' successfully builds the engine. If the result is not ok,
// 'engine' will be set to nullptr.
// Once returned, 'builder' is not needed any more and can be safely destroyed.
//
// - convert_successfully: indicates whether the conversion to the TensorRT
//   network is successful. This is different from successfully building the
//   engine: building can still fail afterwards.
tensorflow::Status ConvertGraphDefToEngine(
const tensorflow::GraphDef& gdef, int precision_mode, int max_batch_size,
size_t max_workspace_size_bytes,
const std::vector<tensorflow::PartialTensorShape>& input_shapes,
Logger* logger, nvinfer1::IGpuAllocator* allocator,
TRTInt8Calibrator* calibrator,
TrtUniquePtrType<nvinfer1::ICudaEngine>* engine,
bool* convert_successfully);
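// A minimal call sketch (local names such as segment_gdef, input_shapes,
// trt_logger and allocator are illustrative; see the TRTEngineOp kernel in
// this change for a real call site):
//   TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
//   bool convert_ok = false;
//   auto status = ConvertGraphDefToEngine(
//       segment_gdef, FP16MODE, /*max_batch_size=*/8,
//       /*max_workspace_size_bytes=*/1 << 30, input_shapes, &trt_logger,
//       allocator, /*calibrator=*/nullptr, &engine, &convert_ok);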
} // namespace convert
} // namespace tensorrt
} // namespace tensorflow


@ -45,8 +45,24 @@ tensorflow::Status TRTOptimizationPass::Init(
if (params.count("max_batch_size")) { if (params.count("max_batch_size")) {
maximum_batch_size_ = params.at("max_batch_size").i(); maximum_batch_size_ = params.at("max_batch_size").i();
} }
if (params.count("max_workspace_size_bytes")) is_dynamic_op_ = false;
if (params.count("is_dynamic_op")) {
is_dynamic_op_ = params.at("is_dynamic_op").b();
}
if (params.count("cached_engine_batches")) {
auto batch_vec = params.at("cached_engine_batches").list();
batches_.reserve(batch_vec.i_size());
for (const auto i : batch_vec.i()) {
batches_.push_back(i);
}
}
max_cached_batches_ = 1;
if (params.count("maximum_cached_engines")) {
max_cached_batches_ = params.at("maximum_cached_engines").i();
}
if (params.count("max_workspace_size_bytes")) {
    maximum_workspace_size_ = params.at("max_workspace_size_bytes").i();
  }
  if (params.count("precision_mode")) {
    string pm = Uppercase(params.at("precision_mode").s());
    if (pm == "FP32") {
@ -175,6 +191,17 @@ tensorflow::Status TRTOptimizationPass::Optimize(
  if (VLOG_IS_ON(1)) {
    PrintDebugInfo(cluster, item);
  }
  // This is a hack to work around an optimizer issue. MetaOptimizer calls
  // optimization passes on function objects as well, but we should not modify
  // generated funcdefs! This is fragile, but we don't have any other option
  // until the framework fixes it.
if (item.id != "tf_graph") {
LOG(WARNING) << name_
<< " is probably called on funcdef! This optimizer must *NOT* "
"be called on function objects.";
*optimized_graph = item.graph;
return tensorflow::Status::OK();
}
  int max_dim = -1;
  if (item.feed.size()) {
    for (const auto& f : item.feed) {
@ -204,11 +231,22 @@ tensorflow::Status TRTOptimizationPass::Optimize(
  }
  tensorflow::grappler::GraphProperties static_graph_properties(item);
  TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true));
  auto status = tensorflow::tensorrt::convert::ConvertAfterShapes(
      item.graph, item.fetch, maximum_batch_size_, maximum_workspace_size_,
      optimized_graph, precision_mode_, minimum_segment_size_,
      static_graph_properties, cluster);
  tensorflow::tensorrt::convert::ConversionParams cp;
  cp.input_graph_def = &item.graph;
  cp.output_names = &item.fetch;
  cp.max_batch_size = maximum_batch_size_;
cp.max_workspace_size_bytes = maximum_workspace_size_;
cp.output_graph_def = optimized_graph;
cp.precision_mode = precision_mode_;
cp.minimum_segment_size = minimum_segment_size_;
cp.graph_properties = &static_graph_properties;
cp.cluster = cluster;
cp.is_dyn_op = is_dynamic_op_;
cp.cached_engine_batches = batches_;
cp.max_cached_engines = max_cached_batches_;
auto status = tensorflow::tensorrt::convert::ConvertAfterShapes(cp);
  VLOG(2) << optimized_graph->DebugString();
VLOG(1) << "Returning from " << name_;
  return status;
}


@ -61,6 +61,9 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer {
  int minimum_segment_size_;
  int precision_mode_;
  int maximum_batch_size_;
bool is_dynamic_op_;
std::vector<int> batches_;
int max_cached_batches_;
  int64_t maximum_workspace_size_;
};


@ -0,0 +1,37 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_
#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_
#include <memory>
namespace tensorflow {
namespace tensorrt {
template <typename T>
struct TrtDestroyer {
void operator()(T* t) {
if (t) t->destroy();
}
};
template <typename T>
using TrtUniquePtrType = std::unique_ptr<T, TrtDestroyer<T>>;
} // namespace tensorrt
} // namespace tensorflow
#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_
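// Usage sketch: TrtDestroyer makes std::unique_ptr release TensorRT objects
// via their destroy() method instead of delete, e.g. (builder creation shown
// for illustration only):
//   TrtUniquePtrType<nvinfer1::IBuilder> builder(
//       nvinfer1::createInferBuilder(logger));
//   // builder->destroy() runs automatically when `builder` goes out of scope.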


@ -14,8 +14,16 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h" #include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h"
#include <algorithm>
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
@ -25,144 +33,556 @@ limitations under the License.
#include "cuda/include/cuda_runtime_api.h" #include "cuda/include/cuda_runtime_api.h"
namespace tensorflow { namespace tensorflow {
static ::tensorflow::tensorrt::Logger logger;
using IRuntime = nvinfer1::IRuntime;
using Dims = nvinfer1::Dims;
namespace tensorrt { namespace tensorrt {
static Logger logger;
using ::nvinfer1::IRuntime;
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) { // A helper class to call done() when destructed for asynchronous execution.
// read serialized_engine // Helps simultaneous execution of native and TRT engines.
OP_REQUIRES_OK(context, class AsyncHelper : public tensorflow::core::RefCounted {
context->GetAttr("serialized_engine", &serialized_engine_)); public:
AsyncHelper(tensorflow::AsyncOpKernel::DoneCallback done) { done_ = done; }
~AsyncHelper() override { done_(); }
// register input output node name in trt_sub_graph private:
OP_REQUIRES_OK(context, context->GetAttr("input_nodes", &input_nodes_)); tensorflow::AsyncOpKernel::DoneCallback done_;
OP_REQUIRES_OK(context, context->GetAttr("output_nodes", &output_nodes_)); };
#define TYPECASE(dt, X, Y) \
case dt: { \
return (void*)X->flat<tensorflow::EnumToDataType<dt>::Type>().data(); \
}
void* GetTensorAddress(const Tensor* tensor_ptr) {
auto tensor_type = tensor_ptr->dtype();
switch (tensor_type) {
TYPECASE(tensorflow::DT_FLOAT, tensor_ptr, dest_ptr);
TYPECASE(tensorflow::DT_HALF, tensor_ptr, dest_ptr);
TYPECASE(tensorflow::DT_INT8, tensor_ptr, dest_ptr);
default: {
LOG(ERROR) << "Unsupported Data type "
<< tensorflow::DataTypeString(tensor_type);
return nullptr;
}
}
} }
void TRTEngineOp::Compute(OpKernelContext* context) { tensorflow::Status TRTEngineOp::ConstructFunctionHandle(OpKernelContext* ctx) {
// TODO(samikama) runtime should be taken from a resourcemanager as well. VLOG(1) << "Constructing function handle";
// Only engine should be in the op and context and runtime should be taken auto lib = ctx->function_library();
// from resourcemanager if (lib == nullptr) {
return tensorflow::errors::Internal("Context function library is null");
if (!trt_execution_context_ptr_) {
IRuntime* infer = nvinfer1::createInferRuntime(logger);
#if NV_TENSORRT_MAJOR > 3
auto device = context->device();
auto dev_allocator =
device->GetAllocator(tensorflow::AllocatorAttributes());
if (!dev_allocator) {
LOG(FATAL) << "Can't find device allocator for gpu device "
<< device->name();
}
allocator_ = std::make_shared<TRTDeviceAllocator>(dev_allocator);
infer->setGpuAllocator(allocator_.get());
#endif
trt_engine_ptr_.reset(infer->deserializeCudaEngine(
serialized_engine_.c_str(), serialized_engine_.size(),
PluginFactoryTensorRT::GetInstance()));
trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext());
// Runtime is safe to delete after engine creation
infer->destroy();
serialized_engine_.clear();
} }
int num_binding = context->num_inputs() + context->num_outputs(); auto fdef = lib->GetFunctionLibraryDefinition()->Find(funcdef_name_);
std::vector<void*> buffers(num_binding); if (fdef == nullptr) {
return tensorflow::errors::Internal("Native FunctionDef ", funcdef_name_,
" can't be found in function library");
}
tensorflow::FunctionLibraryRuntime::InstantiateOptions inst_ops;
inst_ops.overlay_lib = nullptr;
inst_ops.state_handle = "";
inst_ops.target = ctx->device()->name();
native_func_ = 0;
auto status = lib->Instantiate(funcdef_name_, AttrSlice(&fdef->attr()),
inst_ops, &native_func_);
if (!status.ok()) {
LOG(ERROR) << " Instantiating native function " << funcdef_name_
<< " failed!";
}
return status;
}
size_t binding_index; TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
int num_batch = 0; : AsyncOpKernel(context) {
for (int i = 0; i < context->num_inputs(); i++) { // read serialized_engine
// Grab the input tensor OP_REQUIRES_OK(context,
binding_index = trt_engine_ptr_->getBindingIndex(input_nodes_[i].c_str()); context->GetAttr("serialized_segment", &serialized_segment_));
OP_REQUIRES_OK(context,
context->GetAttr("workspace_size_bytes", &workspace_size_));
OP_REQUIRES_OK(context, context->GetAttr("static_engine", &static_engine_));
if (!static_engine_) {
if (!segment_graph_.ParseFromString(serialized_segment_)) {
LOG(ERROR) << "Parsing segment graph failed!";
context->SetStatus(tensorflow::errors::InvalidArgument(
"Failed to parse segment graphdef!"));
return;
}
serialized_segment_.resize(0);
}
VLOG(1) << "Constructing " << name();
string precision_string;
OP_REQUIRES_OK(context,
context->GetAttr("precision_mode", &precision_string));
string calibration_data;
OP_REQUIRES_OK(context,
context->GetAttr("calibration_data", &calibration_data));
OP_REQUIRES_OK(context,
context->GetAttr("segment_funcdef_name", &funcdef_name_));
if (precision_string == "FP32") {
precision_mode_ = convert::FP32MODE;
} else if (precision_string == "FP16") {
precision_mode_ = convert::FP16MODE;
} else if (precision_string == "INT8") {
precision_mode_ = convert::INT8MODE;
}
calibration_mode_ =
(precision_mode_ == convert::INT8MODE && calibration_data.size() == 0);
if (calibration_data.size()) {
calibrator_.reset(new TRTInt8Calibrator(calibration_data));
calibration_data.resize(0);
}
native_func_ = tensorflow::kInvalidHandle;
OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count",
&max_cached_engines_));
OP_REQUIRES_OK(context,
context->GetAttr("fixed_input_size", &fixed_input_size_));
OP_REQUIRES_OK(context, context->GetAttr("cached_engine_batches",
&cached_engine_batches_));
std::sort(cached_engine_batches_.begin(), cached_engine_batches_.end());
if (VLOG_IS_ON(1)) {
string s("Engine Batches= ");
for (auto i : cached_engine_batches_) {
StrAppend(&s, i, " ");
}
VLOG(1) << s;
}
}
const Tensor& input_tensor = context->input(i); void TRTEngineOp::ExecuteNativeSegment(tensorflow::OpKernelContext* ctx,
const TensorShape& input_shape = input_tensor.shape(); AsyncHelper* helper) {
if (i == 0) { if (!calibration_mode_) {
num_batch = input_shape.dim_size(0); VLOG(1) << "Executing native engine";
if (num_batch > trt_engine_ptr_->getMaxBatchSize()) { }
LOG(FATAL) << "input tensor batch larger than max_batch_size: " std::vector<Tensor> inputs;
<< trt_engine_ptr_->getMaxBatchSize(); std::vector<Tensor>* outputs = new std::vector<Tensor>();
} if (native_func_ == tensorflow::kInvalidHandle) {
} else if (num_batch != input_shape.dim_size(0)) { auto status = ConstructFunctionHandle(ctx);
LOG(FATAL) << "input data inconsistent batch size"; if (!status.ok()) {
LOG(ERROR) << "Couldn't construct function handle " << funcdef_name_;
ctx->SetStatus(status);
return;
}
}
auto lib = ctx->function_library();
tensorflow::FunctionLibraryRuntime::Options opts;
opts.step_id = ctx->step_id();
opts.rendezvous = ctx->rendezvous();
opts.cancellation_manager = ctx->cancellation_manager();
opts.runner = ctx->runner();
for (int i = 0; i < ctx->num_inputs(); i++) {
inputs.push_back(ctx->input(i));
}
helper->Ref(); // Increment count for calculating native graph
VLOG(1) << "Executing native segment " << name();
lib->Run(opts, native_func_, inputs, outputs,
[ctx, outputs, helper](const tensorflow::Status& s) {
tensorflow::core::ScopedUnref sc(helper);
VLOG(1) << "Native Segment completed";
if (!s.ok()) {
ctx->SetStatus(s);
return;
}
for (size_t t = 0; t < outputs->size(); ++t) {
ctx->set_output(t, outputs->at(t));
}
delete outputs;
});
}
void TRTEngineOp::ExecuteCalibration(tensorflow::OpKernelContext* ctx,
AsyncHelper* helper) {
helper->Ref();
tensorflow::core::ScopedUnref sc(helper);
// TODO(aaroey): remove the ResourceMgr singleton.
auto trt_rm = TRTResourceManager::instance();
auto res_mgr = trt_rm->getManager("TRTCalibration");
TRTCalibrationResource* calib_res = nullptr;
auto status = res_mgr->LookupOrCreate(
funcdef_name_, "Calibrator", &calib_res,
{[ctx, this](TRTCalibrationResource** cr) -> tensorflow::Status {
return this->AllocateCalibrationResources(ctx, cr);
}});
if (!status.ok()) {
ctx->SetStatus(status);
return;
}
int num_inputs = ctx->num_inputs();
// Pass input data to calibrator
std::unordered_map<string, void*> input_data;
for (int i = 0; i < num_inputs; i++) {
const Tensor& t = ctx->input(i);
void* data_address = GetTensorAddress(&t);
if (data_address == nullptr) {
ctx->SetStatus(tensorflow::errors::InvalidArgument(
"Unsupported data type encountered in input ", i));
return;
}
// Check the allocated buffer is sufficient for input
const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx);
CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes());
input_data.emplace(StrCat(kInputPHName, i), data_address);
}
VLOG(2) << "Filled map for sending";
// copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
const cudaStream_t* stream = CHECK_NOTNULL(
reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
->stream()
->implementation()
->CudaStreamMemberHack()));
calib_res->calibrator_->setBatch(input_data, *stream);
VLOG(2) << "Passed calibration data";
ExecuteNativeSegment(ctx, helper);
}
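// The calibration feed above keys each tensor by its placeholder name:
// input i is registered under StrCat(kInputPHName, i), e.g. "InputPH_0",
// "InputPH_1", ..., which matches the Placeholder nodes added by
// ConvertSegmentToGraphDef and the device buffers recorded in
// AllocateCalibrationResources below.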
int TRTEngineOp::GetEngineBatch(tensorflow::OpKernelContext* ctx) {
int num_batch = ctx->input(0).shape().dim_size(0);
int smallest_engine = 0;
for (const auto i : cached_engine_batches_) {
if (i >= num_batch) {
smallest_engine = i;
break; break;
} }
auto dtype = trt_engine_ptr_->getBindingDataType(binding_index); }
// TODO(sami): Need an LRU here
if (smallest_engine == 0) {
if (max_cached_engines_ > cached_engine_batches_.size()) {
smallest_engine = num_batch;
cached_engine_batches_.push_back(num_batch);
VLOG(1) << "Running with batch size " << num_batch;
} else {
string s("Engine buffer is full. buffer limit= ");
StrAppend(&s, max_cached_engines_, ", current entries= ");
for (auto i : cached_engine_batches_) StrAppend(&s, i, ", ");
StrAppend(&s, "Requested batch= ", num_batch);
LOG(ERROR) << s;
ctx->SetStatus(tensorflow::errors::ResourceExhausted(
"Requested batch size is not available and engine cache is full"));
return -1;
}
}
return smallest_engine;
}
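// Worked example with an illustrative cache state: if cached_engine_batches_
// is {1, 8, 16} and the input batch is 5, the loop picks 8 (the smallest
// cached batch size that can hold the request); a batch of 20 matches
// nothing, so it is either appended to the cache (if there is room under
// max_cached_engines_) or rejected with ResourceExhausted.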
void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
tensorflow::AsyncOpKernel::DoneCallback done) {
auto helper = new AsyncHelper(done);
tensorflow::core::ScopedUnref sc(helper);
if (calibration_mode_) {
ExecuteCalibration(ctx, helper);
return;
}
const int smallest_engine = GetEngineBatch(ctx);
if (smallest_engine < 0) return; // GetEngineBatch already set the status.
const int num_batch = ctx->input(0).shape().dim_size(0);
auto& engine_ctx_pair = GetEngine(smallest_engine, ctx);
auto& trt_engine_ptr = engine_ctx_pair.first;
if (!trt_engine_ptr) {
LOG(WARNING) << "Engine retrieval for batch size " << num_batch
<< " failed Running native segment";
ExecuteNativeSegment(ctx, helper);
return;
}
const int num_binding = ctx->num_inputs() + ctx->num_outputs();
std::vector<void*> buffers(num_binding);
for (int i = 0; i < ctx->num_inputs(); i++) {
const string inp_name = StrCat(kInputPHName, i);
const size_t binding_index =
trt_engine_ptr->getBindingIndex(inp_name.c_str());
const Tensor& input_tensor = ctx->input(i);
const TensorShape& input_shape = input_tensor.shape();
if (num_batch != input_shape.dim_size(0)) {
LOG(ERROR) << "input data inconsistent batch size";
ctx->SetStatus(tensorflow::errors::FailedPrecondition(
"Different batch sizes between input tensors"));
return;
}
auto dtype = trt_engine_ptr->getBindingDataType(binding_index);
switch (dtype) { switch (dtype) {
case nvinfer1::DataType::kFLOAT: case nvinfer1::DataType::kFLOAT:
buffers[binding_index] = (void*)(input_tensor.flat<float>().data()); buffers[binding_index] = (void*)(input_tensor.flat<float>().data());
break; break;
case nvinfer1::DataType::kHALF: case nvinfer1::DataType::kHALF:
LOG(FATAL) << "half size is not supported yet!"; LOG(ERROR) << "FP16 inputs are not supported yet!";
break; ctx->SetStatus(tensorflow::errors::InvalidArgument(
"FP16 inputs are not supported!"));
return;
case nvinfer1::DataType::kINT8: case nvinfer1::DataType::kINT8:
LOG(FATAL) << "int8 is not supported yet!"; LOG(ERROR) << "INT8 inputs are not supported yet!";
break; ctx->SetStatus(tensorflow::errors::InvalidArgument(
"INT8 inputs are not supported!"));
return;
default: default:
LOG(FATAL) << "Unknown data type: " << int(dtype); LOG(ERROR) << "Unknown TRT data type: " << int(dtype);
break; ctx->SetStatus(tensorflow::errors::InvalidArgument(
"Unknown output TRT data type! ", static_cast<int>(dtype)));
return;
} }
} }
for (int i = 0; i < static_cast<int>(output_nodes_.size()); i++) { for (int i = 0; i < ctx->num_outputs(); i++) {
// This is bad that we have to reallocate output buffer every run.
// Create an output tensor // Create an output tensor
binding_index = trt_engine_ptr_->getBindingIndex(output_nodes_[i].c_str()); const string output_name = StrCat(kOutputPHName, i);
const size_t binding_index =
trt_engine_ptr->getBindingIndex(output_name.c_str());
Tensor* output_tensor = nullptr; Tensor* output_tensor = nullptr;
TensorShape output_shape; TensorShape output_shape;
if (binding_index != -1) { if (binding_index != -1) {
auto dims = trt_engine_ptr_->getBindingDimensions(binding_index); auto dims = trt_engine_ptr->getBindingDimensions(binding_index);
std::vector<int> trt_shape(dims.nbDims + 1); std::vector<int> trt_shape(dims.nbDims + 1);
trt_shape[0] = num_batch; trt_shape[0] = num_batch;
for (int j = 0; j < dims.nbDims; j++) trt_shape[j + 1] = dims.d[j]; for (int j = 0; j < dims.nbDims; j++) trt_shape[j + 1] = dims.d[j];
OP_REQUIRES_OK(context, OP_REQUIRES_OK(
TensorShapeUtils::MakeShape( ctx, TensorShapeUtils::MakeShape(trt_shape.data(), trt_shape.size(),
trt_shape.data(), trt_shape.size(), &output_shape)); &output_shape));
} else { } else {
LOG(FATAL) << "output node not found, at " << output_nodes_[i]; LOG(ERROR) << "output node not found, at " << output_name;
break; ctx->SetStatus(tensorflow::errors::Internal("output ", output_name,
" couldn't be found!"));
return;
} }
auto status = ctx->allocate_output(i, output_shape, &output_tensor);
OP_REQUIRES_OK(context, if (!status.ok()) {
context->allocate_output(i, output_shape, &output_tensor)); LOG(ERROR) << "Allocating output failed with " << status;
auto dtype = trt_engine_ptr_->getBindingDataType(binding_index); ctx->SetStatus(status);
return;
}
auto dtype = trt_engine_ptr->getBindingDataType(binding_index);
switch (dtype) { switch (dtype) {
case nvinfer1::DataType::kFLOAT: case nvinfer1::DataType::kFLOAT:
buffers[binding_index] = buffers[binding_index] =
reinterpret_cast<void*>(output_tensor->flat<float>().data()); reinterpret_cast<void*>(output_tensor->flat<float>().data());
break; break;
case nvinfer1::DataType::kHALF: case nvinfer1::DataType::kHALF:
LOG(FATAL) << "half size is not supported yet!"; LOG(ERROR) << "half size is not supported yet!";
break; ctx->SetStatus(tensorflow::errors::InvalidArgument(
"Half outputs are not supported!"));
return;
case nvinfer1::DataType::kINT8: case nvinfer1::DataType::kINT8:
LOG(FATAL) << "int8 is not supported yet!"; LOG(ERROR) << "int8 is not supported yet!";
break; ctx->SetStatus(tensorflow::errors::InvalidArgument(
"INT8 outputs are not supported!"));
return;
default: default:
LOG(FATAL) << "Unknown data type: " << int(dtype); LOG(ERROR) << "Unknown TRT data type: " << static_cast<int>(dtype);
break; ctx->SetStatus(tensorflow::errors::InvalidArgument(
"Unsupported output data type! ", static_cast<int>(dtype)));
return;
} }
} }
// copied from cuda_kernel_helper since it seems only valid in *.cu.cc files // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
const cudaStream_t* stream = CHECK_NOTNULL( const cudaStream_t* stream = CHECK_NOTNULL(
reinterpret_cast<const cudaStream_t*>(context->op_device_context() reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
->stream() ->stream()
->implementation() ->implementation()
->CudaStreamMemberHack())); ->CudaStreamMemberHack()));
// TODO(jie): trt enqueue does not return error // TODO(jie): trt enqueue does not return error
auto ret = trt_execution_context_ptr_->enqueue(num_batch, &buffers[0], auto& trt_execution_context_ptr = engine_ctx_pair.second;
*stream, nullptr); auto ret = trt_execution_context_ptr->enqueue(num_batch, &buffers[0], *stream,
VLOG(2) << "enqueue returns: " << ret; nullptr);
if (!ret) {
LOG(ERROR) << "Failed to enqueue batch for TRT engine: " << name();
ctx->SetStatus(tensorflow::errors::Internal(
"Failed to enqueue batch for TRT engine: ", name()));
}
  // sync should be done by TF.
}
TRTEngineOp::~TRTEngineOp() {
  // Order matters!
  trt_execution_context_ptr_.reset();
  trt_engine_ptr_.reset();
  // We need to manually destroy the engine and execution context before
  // the allocator is destructed.
  for (auto& eng : engine_map_) {
    eng.second.first.reset();
    eng.second.second.reset();
  }
  allocator_.reset();
}
nvinfer1::IGpuAllocator* TRTEngineOp::GetAllocator(OpKernelContext* ctx) {
if (allocator_) return allocator_.get();
auto device = ctx->device();
auto alloc = device->GetAllocator(tensorflow::AllocatorAttributes());
if (!alloc) {
LOG(ERROR) << "Can't find device allocator for gpu device "
<< device->name();
ctx->SetStatus(tensorflow::errors::Internal(
"Can't get device allocator for device ", device->name()));
return nullptr;
}
allocator_.reset(new TRTDeviceAllocator(alloc));
return allocator_.get();
}
TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size,
OpKernelContext* ctx) {
static EngineCtxPair null_pair = {
TrtUniquePtrType<nvinfer1::ICudaEngine>(nullptr),
TrtUniquePtrType<nvinfer1::IExecutionContext>(nullptr)};
// TODO(sami): This method needs to be re-written to use resource manager and
// with LRU mechanism option.
tensorflow::mutex_lock lock(engine_mutex_);
if (static_engine_) {
if (engine_map_.size()) {
if (engine_map_.begin()->first >= batch_size) {
return engine_map_.begin()->second;
}
return null_pair;
}
TrtUniquePtrType<IRuntime> infer(nvinfer1::createInferRuntime(logger));
#if NV_TENSORRT_MAJOR > 3
auto allocator = GetAllocator(ctx);
if (allocator == nullptr) {
// GetAllocator already set the Status.
return null_pair;
}
infer->setGpuAllocator(allocator);
#endif
TrtUniquePtrType<nvinfer1::ICudaEngine> static_engine(
infer->deserializeCudaEngine(serialized_segment_.c_str(),
serialized_segment_.size(), nullptr));
auto raw_static_engine = static_engine.get();
const auto max_batch_size = raw_static_engine->getMaxBatchSize();
engine_map_[max_batch_size] = {
std::move(static_engine),
TrtUniquePtrType<nvinfer1::IExecutionContext>(
raw_static_engine->createExecutionContext())};
// Runtime is safe to delete after engine creation
serialized_segment_.clear();
if (max_batch_size < batch_size) return null_pair;
return engine_map_.at(max_batch_size);
} // static_engine_
// Handle the dynamic engine case.
auto engine_it = engine_map_.find(batch_size);
if (engine_it == engine_map_.end() &&
engine_map_.size() < (size_t)max_cached_engines_) {
nvinfer1::IGpuAllocator* allocator = nullptr;
#if NV_TENSORRT_MAJOR > 3
allocator = GetAllocator(ctx);
if (allocator == nullptr) {
// GetAllocator already set the Status.
return null_pair;
}
#endif
std::vector<tensorflow::PartialTensorShape> shapes;
for (int i = 0; i < ctx->num_inputs(); ++i) {
shapes.emplace_back(ctx->input(i).shape());
}
TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
bool convert_successfully = false;
VLOG(0) << name() << " Constructing a new engine with batch size "
<< batch_size;
// Up to this point, calibrator_ can never be empty, since otherwise it
// means calibration_mode_ is true and this path won't get executed.
auto status = convert::ConvertGraphDefToEngine(
segment_graph_, precision_mode_, batch_size, workspace_size_, shapes,
&logger, allocator, calibrator_.get(), &engine, &convert_successfully);
if (!status.ok()) {
if (convert_successfully) {
      // This means it failed to build the engine even when the network was built
// successfully, probably due to internal issues. In this case we don't
// retry in the future.
engine_map_[batch_size] = {nullptr, nullptr};
}
LOG(ERROR) << "Engine creation for batch size " << batch_size
<< " failed " << status;
ctx->SetStatus(tensorflow::errors::Internal("Engine creation failed!"));
return null_pair;
}
VLOG(1) << "Conversion is done";
TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
engine->createExecutionContext());
engine_map_[batch_size] = {std::move(engine), std::move(exec_context)};
}
return engine_map_.at(batch_size);
}
tensorflow::Status TRTEngineOp::AllocateCalibrationResources(
tensorflow::OpKernelContext* ctx, TRTCalibrationResource** cr) {
auto cres = new TRTCalibrationResource();
*cr = cres;
// Get the allocator.
auto alloc = ctx->device()->GetAllocator(tensorflow::AllocatorAttributes());
if (!alloc) {
LOG(WARNING) << "Can't get device allocator will not be able to "
"allocate memory from TensorFlow memory pool";
cres->allocator_.reset(new TRTCudaAllocator);
} else {
cres->allocator_.reset(new TRTDeviceAllocator(alloc));
}
// Get the input shapes.
const int batch_size = ctx->input(0).dim_size(0);
const int num_inputs = ctx->num_inputs();
std::vector<tensorflow::PartialTensorShape> shapes;
dev_tensors_.resize(num_inputs);
VLOG(1) << " Constructing calibrator";
for (int i = 0; i < num_inputs; i++) {
// allocate workspace on device for inputs
const tensorflow::Tensor& t = ctx->input(i);
shapes.emplace_back(t.shape());
Tensor* device_tensor;
TF_RETURN_IF_ERROR(ctx->allocate_persistent(
t.dtype(), t.shape(), &dev_tensors_.at(i), &device_tensor));
CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes());
void* device_address = GetTensorAddress(device_tensor);
if (device_address == nullptr) {
return tensorflow::errors::InvalidArgument(
"Unsupported data type encountered in input ", i);
}
device_buffers_.emplace(
StrCat(kInputPHName, i),
std::pair<void*, size_t>(device_address, device_tensor->TotalBytes()));
}
cres->calibrator_.reset(
new TRTInt8Calibrator(device_buffers_, batch_size, name()));
const string label(name());
auto segment_graph = &segment_graph_;
const int cuda_gpu_id = ctx->device()->tensorflow_gpu_device_info()->gpu_id;
if (cuda_gpu_id < 0) {
LOG(ERROR) << "Can't get gpu_device_info from context->device()";
return tensorflow::errors::InvalidArgument(
"Context->device doesn't contain device info!");
}
const int64 workspace_size_bytes = workspace_size_;
cres->thr_.reset(new std::thread([cres, label, segment_graph, shapes,
cuda_gpu_id, workspace_size_bytes]() {
VLOG(0) << "Starting calibration thread on device " << cuda_gpu_id
<< ", Calibration Resource @ " << cres;
auto err = cudaSetDevice(cuda_gpu_id);
if (err != cudaSuccess) {
// TODO(aaroey): should return error here.
LOG(ERROR) << "Couldn't set cuda device to " << cuda_gpu_id
<< " in calibration thread";
}
    // ConvertGraphDefToEngine() will try to build the engine. This thread
    // will loop inside buildCudaEngine() consuming the calibration data
    // that is set by the TF op, and drive the builder until the calibrator
    // returns false. The engine is discarded after the calibration table is
    // generated.
//
    // TODO(aaroey): maybe set the max batch size using the python calibration
    // wrapper class.
auto s = convert::ConvertGraphDefToEngine(
*segment_graph, convert::INT8MODE, cres->calibrator_->getBatchSize(),
workspace_size_bytes, shapes, &cres->logger_, cres->allocator_.get(),
cres->calibrator_.get(), &cres->engine_,
/*convert_successfully=*/nullptr);
if (!s.ok()) {
LOG(ERROR) << "Calibration failed: " << s;
cres->calibrator_->setDone(); // Ignore further pushes
}
VLOG(1) << "Calibration loop terminated " << label;
}));
VLOG(1) << "initialized calibrator resource";
return tensorflow::Status::OK();
}
REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp); REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp);
} // namespace tensorrt } // namespace tensorrt
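The dynamic path in GetEngine() above keeps one (engine, context) pair per batch size and builds a new engine on a cache miss until max_cached_engines_ is reached; failed builds are cached as null pairs so they are not retried. A minimal Python sketch of that caching policy (illustrative only, hypothetical names, not the TensorFlow implementation):

# Sketch of the per-batch-size engine cache used by the dynamic path.
class EngineCache(object):

  def __init__(self, max_cached_engines, build_fn):
    self.max_cached_engines = max_cached_engines
    self.build_fn = build_fn   # returns (engine, context) or None on failure
    self.engines = {}          # batch_size -> (engine, context)

  def get(self, batch_size):
    if batch_size in self.engines:
      return self.engines[batch_size]
    if len(self.engines) >= self.max_cached_engines:
      return (None, None)      # caller falls back to the native segment
    built = self.build_fn(batch_size)
    # Cache failures as (None, None) so the build is not retried.
    self.engines[batch_size] = built if built else (None, None)
    return self.engines[batch_size]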

View File

@ -19,9 +19,14 @@ limitations under the License.
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" #include "tensorflow/contrib/tensorrt/resources/trt_allocator.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/mutex.h"
#if GOOGLE_CUDA #if GOOGLE_CUDA
#if GOOGLE_TENSORRT #if GOOGLE_TENSORRT
@ -30,32 +35,95 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace tensorrt { namespace tensorrt {
class Logger; class TRTInt8Calibrator;
class TRTCalibrationResource;
class AsyncHelper;
// TODO(Sami): Remove this file? // TODO(Sami): Remove this file?
class TRTEngineOp : public OpKernel {
// This op can construct a TRT engine on the fly; if engine construction
// fails, it executes the equivalent subgraph as a TensorFlow function.
class TRTEngineOp : public AsyncOpKernel {
public: public:
explicit TRTEngineOp(OpKernelConstruction* context); explicit TRTEngineOp(OpKernelConstruction* context);
void Compute(OpKernelContext* context) override; void ComputeAsync(OpKernelContext* context,
AsyncOpKernel::DoneCallback done) override;
~TRTEngineOp(); ~TRTEngineOp();
private: private:
template <typename T> // Execute calibration
struct Destroyer { void ExecuteCalibration(OpKernelContext* ctx, AsyncHelper* helper);
void operator()(T* d) { d->destroy(); }
}; // Construct a function handle for executing native funcdef graph
Status ConstructFunctionHandle(OpKernelContext* ctx);
// Execute replaced native segment as function Op.
void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper);
// Allocate necessary resources for calibration
Status AllocateCalibrationResources(OpKernelContext* ctx,
TRTCalibrationResource** cr);
template <typename T>
using destroyed_ptr = std::unique_ptr<T, Destroyer<T>>;
destroyed_ptr<nvinfer1::ICudaEngine> trt_engine_ptr_;
// TODO(samikama): context should go to a resource manager! // TODO(samikama): context should go to a resource manager!
destroyed_ptr<nvinfer1::IExecutionContext> trt_execution_context_ptr_; typedef std::pair<TrtUniquePtrType<nvinfer1::ICudaEngine>,
TrtUniquePtrType<nvinfer1::IExecutionContext>>
EngineCtxPair;
EngineCtxPair& GetEngine(int batch_size, OpKernelContext* ctx);
// Return engine batch closest to input batch.
int GetEngineBatch(OpKernelContext* ctx);
nvinfer1::IGpuAllocator* GetAllocator(OpKernelContext* ctx);
// map to keep engines and their execution context for given batch size.
std::unordered_map<int, EngineCtxPair> engine_map_;
std::vector<string> input_nodes_; std::vector<string> input_nodes_;
std::vector<string> output_nodes_; std::vector<string> output_nodes_;
std::shared_ptr<nvinfer1::IGpuAllocator> allocator_;
string serialized_engine_; // keep device allocator for TRT.
std::unique_ptr<TRTDeviceAllocator> allocator_;
// Serialized protobuf segment or TRT engine, depending on static_engine_ flag.
string serialized_segment_;
// Name of the function for TF native execution of the segment.
string funcdef_name_;
// GraphDef representation of the segment.
GraphDef segment_graph_;
// Lookup table for temporary staging areas of input tensors for calibration.
std::unordered_map<string, std::pair<void*, size_t>> device_buffers_;
// Temporary staging areas for calibration inputs.
std::vector<PersistentTensor> dev_tensors_;
// Engine Precision mode.
int precision_mode_;
// Whether engine is constructed during the conversion or needs to be
// constructed from protobuf segment.
bool static_engine_;
// Whether to calibrate INT8 engine.
bool calibration_mode_;
// Whether non-batch ranks of the inputs are assumed to be fixed or not for
// engine construction.
bool fixed_input_size_;
// Batches of the cached engines
std::vector<int> cached_engine_batches_;
// Maximum number of cached engines
int max_cached_engines_;
int64 workspace_size_;
mutex engine_mutex_;
FunctionLibraryRuntime::Handle native_func_;
// The finalized calibrator for inference.
std::unique_ptr<TRTInt8Calibrator> calibrator_;
}; };
} // namespace tensorrt } // namespace tensorrt

View File

@ -28,11 +28,19 @@ extern Status TRTEngineOpShapeInference(InferenceContext* c);
} }
REGISTER_OP("TRTEngineOp") REGISTER_OP("TRTEngineOp")
.Attr("serialized_engine: string") .Attr("serialized_segment: string")
.Attr("input_nodes: list(string)") .Attr("input_shapes: list(shape)")
.Attr("output_nodes: list(string)") .Attr("output_shapes: list(shape)")
.Attr("InT: list({float32})") .Attr("segment_funcdef_name: string")
.Attr("OutT: list({float32})") .Attr("InT: list({int8,float16,float32})")
.Attr("OutT: list({int8,float16,float32})")
.Attr("static_engine: bool = true")
.Attr("fixed_input_size: bool = true")
.Attr("cached_engine_batches: list(int) = []")
.Attr("max_cached_engines_count: int = 1")
.Attr("workspace_size_bytes: int")
.Attr("precision_mode: {'FP32', 'FP16', 'INT8', 'INT8CALIB'}")
.Attr("calibration_data: string = ''")
.Input("in_tensor: InT") .Input("in_tensor: InT")
.Output("out_tensor: OutT") .Output("out_tensor: OutT")
.SetShapeFn(shape_inference::TRTEngineOpShapeInference); .SetShapeFn(shape_inference::TRTEngineOpShapeInference);
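With these attributes in place, a converted GraphDef can be inspected directly to see how each TRTEngineOp was configured. A short sketch (graph_def is assumed to come from create_inference_graph; the attribute names are the ones registered above):

def summarize_trt_engines(graph_def):
  """Print the configuration of every TRTEngineOp in a converted GraphDef."""
  for n in graph_def.node:
    if n.op != "TRTEngineOp":
      continue
    print(n.name,
          "precision=%s" % n.attr["precision_mode"].s,
          "static_engine=%s" % n.attr["static_engine"].b,
          "max_cached=%d" % n.attr["max_cached_engines_count"].i,
          "has_calibration_data=%s" % bool(n.attr["calibration_data"].s))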

View File

@ -21,6 +21,8 @@ from __future__ import print_function
# pylint: disable=unused-import,line-too-long # pylint: disable=unused-import,line-too-long
import six as _six import six as _six
from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert
from tensorflow.contrib.tensorrt.wrap_conversion import get_linked_tensorrt_version
from tensorflow.contrib.tensorrt.wrap_conversion import get_loaded_tensorrt_version
from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert
from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2
@ -29,7 +31,9 @@ from tensorflow.python.framework import errors_impl as _impl
from tensorflow.python.framework import meta_graph from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.grappler import tf_optimizer from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.platform import tf_logging
from tensorflow.python.util import compat from tensorflow.python.util import compat
# pylint: enable=unused-import,line-too-long # pylint: enable=unused-import,line-too-long
@ -40,7 +44,10 @@ def create_inference_graph(input_graph_def,
max_batch_size=1, max_batch_size=1,
max_workspace_size_bytes=2 << 20, max_workspace_size_bytes=2 << 20,
precision_mode="FP32", precision_mode="FP32",
minimum_segment_size=3): minimum_segment_size=3,
is_dynamic_op=False,
maximum_cached_engines=1,
cached_engine_batches=[]):
"""Python wrapper for the TRT transformation. """Python wrapper for the TRT transformation.
Args: Args:
@ -51,6 +58,10 @@ def create_inference_graph(input_graph_def,
precision_mode: one of 'FP32', 'FP16' and 'INT8' precision_mode: one of 'FP32', 'FP16' and 'INT8'
minimum_segment_size: the minimum number of nodes required for a subgraph to minimum_segment_size: the minimum number of nodes required for a subgraph to
be replaced by TRTEngineOp. be replaced by TRTEngineOp.
is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
network and engine at run time.
maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
cached_engine_batches: batch sizes used to pre-create cached engines.
Returns: Returns:
New GraphDef with TRTEngineOps placed in graph replacing subgraphs. New GraphDef with TRTEngineOps placed in graph replacing subgraphs.
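For example, a call requesting dynamic engines passes the new arguments straight through the wrapper; the frozen GraphDef, output names and sizes below are placeholders for the user's model, and the trt module alias follows the test script later in this change:

trt_graph = trt.create_inference_graph(
    input_graph_def=frozen_graph_def,  # frozen GraphDef of the model
    outputs=["output"],
    max_batch_size=8,
    max_workspace_size_bytes=1 << 25,
    precision_mode="FP16",
    minimum_segment_size=3,
    is_dynamic_op=True,                # build TRT engines at run time
    maximum_cached_engines=2,          # at most two engines per TRTEngineOp
    cached_engine_batches=[1, 8])      # pre-create engines for these batches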
@ -65,6 +76,30 @@ def create_inference_graph(input_graph_def,
"It should be one of {}").format( "It should be one of {}").format(
precision_mode, "{'FP32', 'FP16', 'INT8'}")) precision_mode, "{'FP32', 'FP16', 'INT8'}"))
mode = supported_precision_modes[precision_mode.upper()] mode = supported_precision_modes[precision_mode.upper()]
compiled_version = get_linked_tensorrt_version()
loaded_version = get_loaded_tensorrt_version()
version_mismatch = False
if loaded_version[0] < compiled_version[0]:
tf_logging.error(
"TensorRT version mismatch. Tensorflow was compiled against " +
"TensorRT %s but library loaded from environment is TensorRT %s" %
(".".join([str(x) for x in compiled_version]),
".".join([str(x) for x in loaded_version])) +
". Please make sure that correct version of TensorRT " +
"is available in the system and added to ldconfig or LD_LIBRARY_PATH"
)
raise RuntimeError("Incompatible TensorRT library version")
for i in zip(loaded_version, compiled_version):
if i[0] != i[1]:
tf_logging.warn("TensorRT mismatch. Compiled against version " +
"%s, but loaded %s. Things may not work" %
(".".join([str(x) for x in compiled_version]),
".".join([str(x) for x in loaded_version])))
version_mismatch = True
break
if not version_mismatch:
tf_logging.info("Running against TensorRT version %s" % ".".join(
[str(x) for x in loaded_version]))
def py2bytes(inp): def py2bytes(inp):
return inp return inp
@ -100,7 +135,9 @@ def create_inference_graph(input_graph_def,
# pair or strings where first one is encoded status and the second # pair or strings where first one is encoded status and the second
# one is the transformed graphs protobuf string. # one is the transformed graphs protobuf string.
out = trt_convert(input_graph_def_str, out_names, max_batch_size, out = trt_convert(input_graph_def_str, out_names, max_batch_size,
max_workspace_size_bytes, mode, minimum_segment_size) max_workspace_size_bytes, mode, minimum_segment_size,
is_dynamic_op, maximum_cached_engines,
cached_engine_batches)
status = to_string(out[0]) status = to_string(out[0])
output_graph_def_string = out[1] output_graph_def_string = out[1]
del input_graph_def_str # Save some memory del input_graph_def_str # Save some memory
@ -120,11 +157,12 @@ def create_inference_graph(input_graph_def,
return output_graph_def return output_graph_def
def calib_graph_to_infer_graph(calibration_graph_def): def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
"""Convert an existing calibration graph to inference graph. """Convert an existing calibration graph to inference graph.
Args: Args:
calibration_graph_def: the calibration GraphDef object with calibration data calibration_graph_def: the calibration GraphDef object with calibration data
    is_dynamic_op: whether to create dynamic TRT engines from the calibration graph
Returns: Returns:
New GraphDef with TRTEngineOps placed in graph replacing calibration nodes. New GraphDef with TRTEngineOps placed in graph replacing calibration nodes.
Raises: Raises:
@ -141,9 +179,16 @@ def calib_graph_to_infer_graph(calibration_graph_def):
to_string = py2string to_string = py2string
else: else:
to_string = py3string to_string = py3string
is_calib_graph = False
for n in calibration_graph_def.node:
if n.op == "TRTEngineOp":
is_calib_graph = is_calib_graph or not n.attr["calibration_data"].s
if not is_calib_graph:
tf_logging.error(
"Not a calib graph. Doesn't seem to contain any calibration nodes.")
return None
graph_str = calibration_graph_def.SerializeToString() graph_str = calibration_graph_def.SerializeToString()
out = calib_convert(graph_str) out = calib_convert(graph_str, is_dynamic_op)
status = to_string(out[0]) status = to_string(out[0])
output_graph_def_string = out[1] output_graph_def_string = out[1]
del graph_str # Save some memory del graph_str # Save some memory
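End to end, INT8 calibration with these wrappers is a two-step flow: convert once with precision_mode="INT8", run the resulting calibration graph on representative inputs so the calibrators collect their tables, then convert the calibrated graph into the final inference graph. A hedged sketch (run_graph stands for any helper that feeds real data through a graph, as in the test script later in this change):

calib_graph = trt.create_inference_graph(
    input_graph_def=frozen_graph_def, outputs=["output"],
    max_batch_size=8, max_workspace_size_bytes=1 << 25,
    precision_mode="INT8", minimum_segment_size=3)
for batch in representative_batches:  # real data drives the calibrators
  run_graph(calib_graph, batch)
int8_graph = trt.calib_graph_to_infer_graph(calib_graph, is_dynamic_op=False)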

View File

@ -50,7 +50,7 @@ TRTDeviceAllocator::TRTDeviceAllocator(tensorflow::Allocator* allocator)
} }
void TRTDeviceAllocator::free(void* memory) { void TRTDeviceAllocator::free(void* memory) {
VLOG(2) << "Deallocating " << memory; VLOG(2) << "Deallocating @ " << memory;
allocator_->DeallocateRaw(memory); allocator_->DeallocateRaw(memory);
} }

View File

@ -16,7 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_ #ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_
#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_ #define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_
#include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/allocator.h"
@ -52,7 +51,9 @@ class TRTDeviceAllocator : public nvinfer1::IGpuAllocator {
// Allocator implementation wrapping TF device allocators. // Allocator implementation wrapping TF device allocators.
public: public:
TRTDeviceAllocator(tensorflow::Allocator* allocator); TRTDeviceAllocator(tensorflow::Allocator* allocator);
virtual ~TRTDeviceAllocator() {} virtual ~TRTDeviceAllocator() {
VLOG(1) << "Destroying allocator attached to " << allocator_->Name();
}
void* allocate(uint64_t size, uint64_t alignment, uint32_t flags) override; void* allocate(uint64_t size, uint64_t alignment, uint32_t flags) override;
void free(void* memory) override; void free(void* memory) override;

View File

@ -16,7 +16,6 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" #include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h"
#include <atomic> #include <atomic>
#include <chrono>
#include <unordered_map> #include <unordered_map>
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
@ -37,15 +36,22 @@ TRTInt8Calibrator::TRTInt8Calibrator(
: batch_size_(batch_size), : batch_size_(batch_size),
done_(false), done_(false),
dev_buffers_(dev_buffers), dev_buffers_(dev_buffers),
calib_running_(false), calib_running_(true),
batch_is_set_(false), batch_is_set_(false),
engine_name_(engine_name) {} engine_name_(engine_name) {}
TRTInt8Calibrator::TRTInt8Calibrator(const string& calib_data)
: batch_size_(0),
done_(false),
calib_running_(false),
batch_is_set_(false),
calibration_table_(calib_data) {}
bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data, bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data,
const cudaStream_t stream) { const cudaStream_t stream) {
tensorflow::mutex_lock lock(cond_mtx_); tensorflow::mutex_lock lock(cond_mtx_);
while ((calib_running_ || batch_is_set_) && // wait while calibration is running.
!done_) { // wait while calibration is running while ((calib_running_ || batch_is_set_) && !done_) {
cond_.wait(lock); cond_.wait(lock);
} }
if (done_) return false; if (done_) return false;
@ -59,8 +65,6 @@ bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data,
} }
const auto& d = devptr->second; const auto& d = devptr->second;
// TODO(aaroey): we should not use sync copy on default stream. Make sure
// stream->ThenMemcpy() is used in future PRs.
// TODO(sami,aaroey): Need to figure out a way to ensure synchronization // TODO(sami,aaroey): Need to figure out a way to ensure synchronization
// between stream, perhaps using a tensor? // between stream, perhaps using a tensor?
auto status = cudaMemcpyAsync(d.first, it.second, d.second, auto status = cudaMemcpyAsync(d.first, it.second, d.second,
@ -84,13 +88,11 @@ bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
tensorflow::mutex_lock lock(cond_mtx_); tensorflow::mutex_lock lock(cond_mtx_);
calib_running_ = false; calib_running_ = false;
cond_.notify_all(); cond_.notify_all();
while ((!batch_is_set_ && !done_)) { // wait until new batch arrives // wait until new batch arrives
while ((!batch_is_set_ && !done_)) {
cond_.wait(lock); cond_.wait(lock);
}
if (done_) {
return false;
} }
if (done_) return false;
for (int i = 0; i < num_bindings; i++) { for (int i = 0; i < num_bindings; i++) {
auto it = dev_buffers_.find(names[i]); auto it = dev_buffers_.find(names[i]);
@ -107,7 +109,9 @@ bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
} }
const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) { const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) {
return nullptr; if (calibration_table_.empty()) return nullptr;
length = calibration_table_.size();
return calibration_table_.data();
} }
void TRTInt8Calibrator::setDone() { void TRTInt8Calibrator::setDone() {
@ -117,7 +121,11 @@ void TRTInt8Calibrator::setDone() {
} }
void TRTInt8Calibrator::writeCalibrationCache(const void* ptr, void TRTInt8Calibrator::writeCalibrationCache(const void* ptr,
std::size_t length) {} std::size_t length) {
calibration_table_ = string((const char*)ptr, length);
VLOG(1) << "Got calibration data for " << engine_name_ << " @" << ptr
<< " length=" << length;
}
TRTInt8Calibrator::~TRTInt8Calibrator() { TRTInt8Calibrator::~TRTInt8Calibrator() {
VLOG(1) << "Destroying calibrator for " << engine_name_; VLOG(1) << "Destroying calibrator for " << engine_name_;
} }
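The calibrator is effectively a one-slot producer/consumer queue guarded by a condition variable: setBatch() (called by the TF op) blocks until the previous batch has been consumed, getBatch() (called from TensorRT's builder loop) blocks until a new batch is published, and setDone() wakes both sides so they can exit. A simplified Python sketch of the same handshake (illustrative only; it omits the calib_running_ flag and the device copies of the C++ code):

import threading

class OneSlotBatchQueue(object):
  """Minimal producer/consumer handshake mirroring TRTInt8Calibrator."""

  def __init__(self):
    self.cond = threading.Condition()
    self.batch = None
    self.batch_is_set = False
    self.done = False

  def set_batch(self, batch):        # producer: the TF op feeding data
    with self.cond:
      while self.batch_is_set and not self.done:
        self.cond.wait()
      if self.done:
        return False
      self.batch, self.batch_is_set = batch, True
      self.cond.notify_all()
      return True

  def get_batch(self):               # consumer: the TensorRT builder loop
    with self.cond:
      while not self.batch_is_set and not self.done:
        self.cond.wait()
      if self.done:
        return None
      self.batch_is_set = False
      self.cond.notify_all()
      return self.batch

  def set_done(self):                # no more batches; release both sides
    with self.cond:
      self.done = True
      self.cond.notify_all()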

View File

@ -39,29 +39,48 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
TRTInt8Calibrator( TRTInt8Calibrator(
const std::unordered_map<string, std::pair<void*, size_t>>& dev_buffers, const std::unordered_map<string, std::pair<void*, size_t>>& dev_buffers,
int batch_size, string engine_name); int batch_size, string engine_name);
TRTInt8Calibrator(const string& calibration_data);
~TRTInt8Calibrator();
int getBatchSize() const override; int getBatchSize() const override;
bool getBatch(void* bindings[], const char* names[], bool getBatch(void* bindings[], const char* names[],
int num_bindings) override; int num_bindings) override;
bool setBatch(const std::unordered_map<string, void*>& data, bool setBatch(const std::unordered_map<string, void*>& data,
const cudaStream_t stream); const cudaStream_t stream);
void setDone(); void setDone();
// If not null, calibration is skipped.
const void* readCalibrationCache(std::size_t& length) override; const void* readCalibrationCache(std::size_t& length) override;
void writeCalibrationCache(const void* ptr, std::size_t length) override; void writeCalibrationCache(const void* ptr, std::size_t length) override;
~TRTInt8Calibrator();
const string& getCalibrationTableAsString() { return calibration_table_; }
private: private:
const int batch_size_; const int batch_size_;
tensorflow::mutex cond_mtx_; // mutex for condition_variable
tensorflow::condition_variable cond_; // condition variable to implement // mutex for condition_variable
// producer-consumer queue for tensorflow::mutex cond_mtx_;
// calibration
// condition variable to implement producer-consumer queue for calibration
tensorflow::condition_variable cond_;
// Is calibration finished?
bool done_; bool done_;
const std::unordered_map<string, std::pair<void*, size_t>>
dev_buffers_; // map to keep tensorrt input buffers and sizes keyed with // Map to keep tensorrt input buffers and sizes keyed with buffer names
// buffer names const std::unordered_map<string, std::pair<void*, size_t>> dev_buffers_;
bool calib_running_; bool calib_running_;
bool batch_is_set_; bool batch_is_set_;
string engine_name_; string engine_name_;
string calibration_table_;
}; };
} // namespace tensorrt } // namespace tensorrt

View File

@ -22,6 +22,7 @@ limitations under the License.
#include <thread> #include <thread>
#include <vector> #include <vector>
#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" #include "tensorflow/contrib/tensorrt/resources/trt_allocator.h"
#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" #include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h"
@ -34,50 +35,48 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace tensorrt { namespace tensorrt {
class TRTCalibrationResource : public tensorflow::ResourceBase { class TRTCalibrationResource : public tensorflow::ResourceBase {
public: public:
TRTCalibrationResource()
: calibrator_(nullptr),
builder_(nullptr),
network_(nullptr),
engine_(nullptr),
logger_(nullptr),
thr_(nullptr) {}
~TRTCalibrationResource() { ~TRTCalibrationResource() {
VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString(); VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString();
builder_.reset();
engine_.reset();
// We need to manually destroy the builder and engine before the allocator
// is destroyed.
allocator_.reset();
} }
string DebugString() override { string DebugString() override {
std::stringstream oss; std::stringstream oss;
oss << " Calibrator = " << std::hex << calibrator_ << std::dec << std::endl using std::dec;
<< " Builder = " << std::hex << builder_ << std::dec << std::endl using std::endl;
<< " Network = " << std::hex << network_ << std::dec << std::endl using std::hex;
<< " Engine = " << std::hex << engine_ << std::dec << std::endl oss << " Calibrator = " << hex << calibrator_.get() << dec << endl
<< " Logger = " << std::hex << logger_ << std::dec << std::endl << " Builder = " << hex << builder_.get() << dec << endl
<< " Allocator = " << std::hex << allocator_.get() << std::dec << " Engine = " << hex << engine_.get() << dec << endl
<< std::endl << " Logger = " << hex << &logger_ << dec << endl
<< " Thread = " << std::hex << thr_ << std::dec << std::endl; << " Allocator = " << hex << allocator_.get() << dec << endl
<< " Thread = " << hex << thr_.get() << dec << endl;
return oss.str(); return oss.str();
} }
TRTInt8Calibrator* calibrator_; std::unique_ptr<TRTInt8Calibrator> calibrator_;
nvinfer1::IBuilder* builder_; TrtUniquePtrType<nvinfer1::IBuilder> builder_;
nvinfer1::INetworkDefinition* network_; TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
nvinfer1::ICudaEngine* engine_; std::unique_ptr<nvinfer1::IGpuAllocator> allocator_;
std::shared_ptr<nvinfer1::IGpuAllocator> allocator_; tensorflow::tensorrt::Logger logger_;
tensorflow::tensorrt::Logger* logger_;
// TODO(sami): Use threadpool threads! // TODO(sami): Use threadpool threads!
std::thread* thr_; std::unique_ptr<std::thread> thr_;
}; };
class TRTWeightStore : public tensorflow::ResourceBase { class TRTWeightStore {
public: public:
TRTWeightStore() {} TRTWeightStore() {}
virtual ~TRTWeightStore() { VLOG(1) << "Destroying store" << DebugString(); } virtual ~TRTWeightStore() { VLOG(1) << "Destroying store" << DebugString(); }
string DebugString() override { string DebugString() {
std::stringstream oss; std::stringstream oss;
size_t len_bytes = 0; size_t len_bytes = 0;
for (const auto& v : store_) { for (const auto& v : store_) {

View File

@ -29,8 +29,9 @@ namespace tensorflow {
namespace tensorrt { namespace tensorrt {
namespace segment { namespace segment {
// vector of segments, each entry contains a device name and a set of nodes in // Vector of segments, each entry contains a set of node names and a device name
// segment // in the segment.
// TODO(aaroey): use node pointer instead of node name.
using SegmentNodesVector = std::vector<std::pair<std::set<string>, string>>; using SegmentNodesVector = std::vector<std::pair<std::set<string>, string>>;
struct SegmentOptions { struct SegmentOptions {
@ -48,6 +49,8 @@ struct SegmentOptions {
// in the vector describes a subgraph by giving a set of the names of // in the vector describes a subgraph by giving a set of the names of
// all the NodeDefs in that subgraph. // all the NodeDefs in that subgraph.
// @return the status. // @return the status.
//
// TODO(aaroey): remove this method.
tensorflow::Status SegmentGraph( tensorflow::Status SegmentGraph(
const tensorflow::GraphDef& gdef, const tensorflow::GraphDef& gdef,
const std::function<bool(const tensorflow::Node*)>& candidate_fn, const std::function<bool(const tensorflow::Node*)>& candidate_fn,

View File

@ -29,61 +29,35 @@ namespace tensorflow {
namespace shape_inference { namespace shape_inference {
tensorflow::Status TRTEngineOpShapeInference(InferenceContext* context) { tensorflow::Status TRTEngineOpShapeInference(InferenceContext* context) {
tensorflow::tensorrt::Logger logger; std::vector<tensorflow::TensorShape> shapes;
string serialized_engine; for (int i = 0; i < context->num_outputs(); ++i) {
TF_RETURN_IF_ERROR(context->GetAttr("serialized_engine", &serialized_engine)); context->set_output(i, context->UnknownShape());
nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger);
nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine(
serialized_engine.c_str(), serialized_engine.size(),
tensorrt::PluginFactoryTensorRT::GetInstance());
int num_batch = -1;
std::vector<::tensorflow::DataType> input_type;
TF_RETURN_IF_ERROR(context->GetAttr("InT", &input_type));
for (size_t i = 0; i < context->num_inputs(); i++) {
// Check if input shape is legit
auto input_shape = context->input(i);
for (int j = 0; j < context->Rank(input_shape); j++) {
auto dim_handler = context->Dim(input_shape, j);
if (j == 0) {
if (i == 0) {
num_batch = context->Value(dim_handler);
} else if (num_batch != context->Value(dim_handler)) {
// TODO(jie): TensorRT engine requires consistent batch between inputs
// tensors. Segmenter should be aware of this.
LOG(FATAL) << "TensorRT engine requires consistent batch size";
}
}
}
} }
auto status = context->GetAttr("input_shapes", &shapes);
  // Arrange input here   // It is ok not to have shapes
std::vector<string> input_nodes; if (!status.ok()) return Status::OK();
TF_RETURN_IF_ERROR(context->GetAttr("input_nodes", &input_nodes)); if ((int)shapes.size() != context->num_inputs()) return Status::OK();
bool different_input = false;
// Arrange output here for (int i = 0; i < context->num_inputs(); ++i) {
std::vector<string> output_nodes; if (shapes.at(i) != context->input_tensor(i)->shape())
TF_RETURN_IF_ERROR(context->GetAttr("output_nodes", &output_nodes)); different_input = true;
for (size_t i = 0; i < output_nodes.size(); i++) { }
int binding_index = trt_engine->getBindingIndex(output_nodes[i].c_str()); if (different_input) return Status::OK();
ShapeHandle output_shape; shapes.resize(0);
std::vector<DimensionHandle> dim_vec; status = context->GetAttr("output_shapes", &shapes);
dim_vec.emplace_back(context->MakeDim(num_batch)); if (!status.ok()) return Status::OK();
if (binding_index != -1) { if ((int)shapes.size() != context->num_outputs()) return Status::OK();
auto dims = trt_engine->getBindingDimensions(binding_index); std::vector<ShapeHandle> shape_handles(shapes.size());
for (int j = 0; j < dims.nbDims; j++) { for (size_t i = 0; i < shapes.size(); ++i) {
dim_vec.emplace_back(context->MakeDim(dims.d[j])); status =
} context->MakeShapeFromTensorShape(shapes.at(i), &shape_handles.at(i));
} else { if (!status.ok()) return Status::OK();
LOG(FATAL) << "TensorRT engine cannot find binding: " << output_nodes[i]; }
} for (int i = 0; i < context->num_outputs(); ++i) {
output_shape = context->MakeShape(dim_vec); context->set_output(i, shape_handles.at(i));
context->set_output(i, output_shape);
} }
return Status::OK(); return Status::OK();
} }
} // namespace shape_inference } // namespace shape_inference
} // namespace tensorflow } // namespace tensorflow

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import argparse import argparse
import numpy as np import numpy as np
import six as _six
# normally we should do import tensorflow as tf and then # normally we should do import tensorflow as tf and then
# tf.placeholder, tf.constant, tf.nn.conv2d etc but # tf.placeholder, tf.constant, tf.nn.conv2d etc but
@ -35,10 +36,75 @@ from tensorflow.python.framework import dtypes as dtypes
from tensorflow.python.framework import importer as importer from tensorflow.python.framework import importer as importer
from tensorflow.python.framework import ops as ops from tensorflow.python.framework import ops as ops
from tensorflow.python.ops import array_ops as aops from tensorflow.python.ops import array_ops as aops
from tensorflow.python.ops import math_ops as mops
from tensorflow.python.ops import nn as nn from tensorflow.python.ops import nn as nn
from tensorflow.python.ops import nn_ops as nn_ops from tensorflow.python.ops import nn_ops as nn_ops
def py2bytes(inp):
return inp
def py3bytes(inp):
return inp.encode("utf-8", errors="surrogateescape")
def py2string(inp):
return inp
def py3string(inp):
return inp.decode("utf-8")
if _six.PY2:
to_bytes = py2bytes
to_string = py2string
else:
to_bytes = py3bytes
to_string = py3string
def get_multi_engine_graph_def(mode="FP32"):
"""Create a simple graph and return its graph_def."""
dtype = dtypes.float32
if mode.upper() == "FP16":
dtype = dtypes.float16
else:
pass
g = ops.Graph()
with g.as_default():
x = aops.placeholder(shape=[None, 3, 7, 5], name="input", dtype=dtype)
with g.name_scope("Global_scope"):
with g.name_scope("first_scope"):
e = cop.constant(
np.random.randn(3, 2, 3, 4), name="weights", dtype=dtype)
conv = nn.conv2d(
input=x,
filter=e,
data_format="NCHW",
strides=[1, 1, 1, 1],
padding="VALID",
name="conv")
b = cop.constant(np.random.randn(1, 4, 1, 1), name="bias1", dtype=dtype)
t = conv * b
b = cop.constant(np.random.randn(1, 4, 1, 1), name="bias2", dtype=dtype)
q = conv / b
edge = mops.sin(q)
edge1 = mops.cos(conv)
with g.name_scope("test_scope"):
de = edge + edge1
t -= edge1
q *= edge
t += q
t -= de
k = aops.squeeze(t, name="output")
print(k.dtype)
return g.as_graph_def()
def get_simple_graph_def(): def get_simple_graph_def():
"""Create a simple graph and return its graph_def.""" """Create a simple graph and return its graph_def."""
g = ops.Graph() g = ops.Graph()
@ -65,7 +131,9 @@ def get_simple_graph_def():
def execute_graph(gdef, dumm_inp): def execute_graph(gdef, dumm_inp):
"""Run given graphdef once.""" """Run given graphdef once."""
print("executing") print("executing")
gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) gpu_options = None
if trt.trt_convert.get_linked_tensorrt_version()[0] == 3:
gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
sessconfig = cpb2.ConfigProto(gpu_options=gpu_options) sessconfig = cpb2.ConfigProto(gpu_options=gpu_options)
ops.reset_default_graph() ops.reset_default_graph()
g = ops.Graph() g = ops.Graph()
@ -83,7 +151,9 @@ def execute_graph(gdef, dumm_inp):
# for calibration. For this test script it is random data. # for calibration. For this test script it is random data.
def execute_calibration(gdef, dumm_inp): def execute_calibration(gdef, dumm_inp):
"""Run given calibration graph multiple times.""" """Run given calibration graph multiple times."""
gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) gpu_options = None
if trt.trt_convert.get_linked_tensorrt_version()[0] == 3:
gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
ops.reset_default_graph() ops.reset_default_graph()
g = ops.Graph() g = ops.Graph()
with g.as_default(): with g.as_default():
@ -100,12 +170,17 @@ def execute_calibration(gdef, dumm_inp):
return val return val
def user(run_graph=execute_graph, run_calibration=execute_calibration): def user(multi_engine,
run_graph=execute_graph,
run_calibration=execute_calibration):
"""Example function that converts a graph to TFTRT graph.""" """Example function that converts a graph to TFTRT graph."""
if multi_engine:
inp_dims = (100, 24, 24, 2) inp_dims = (2, 3, 7, 5)
orig_graph = get_multi_engine_graph_def()
else:
inp_dims = (100, 24, 24, 2)
orig_graph = get_simple_graph_def() # use a frozen graph for inference
dummy_input = np.random.random_sample(inp_dims) dummy_input = np.random.random_sample(inp_dims)
orig_graph = get_simple_graph_def() # use a frozen graph for inference
# Get optimized graph # Get optimized graph
trt_graph = trt.create_inference_graph( trt_graph = trt.create_inference_graph(
input_graph_def=orig_graph, input_graph_def=orig_graph,
@ -113,8 +188,10 @@ def user(run_graph=execute_graph, run_calibration=execute_calibration):
max_batch_size=inp_dims[0], max_batch_size=inp_dims[0],
max_workspace_size_bytes=1 << 25, max_workspace_size_bytes=1 << 25,
precision_mode="FP32", # TRT Engine precision "FP32","FP16" or "INT8" precision_mode="FP32", # TRT Engine precision "FP32","FP16" or "INT8"
minimum_segment_size=2 # minimum number of nodes in an engine minimum_segment_size=2, # minimum number of nodes in an engine
) is_dynamic_op=False,
maximum_cached_engines=1,
cached_engine_batches=[])
o1 = run_graph(orig_graph, dummy_input) o1 = run_graph(orig_graph, dummy_input)
o2 = run_graph(trt_graph, dummy_input) o2 = run_graph(trt_graph, dummy_input)
o3 = run_graph(trt_graph, dummy_input) o3 = run_graph(trt_graph, dummy_input)
@ -126,40 +203,51 @@ def user(run_graph=execute_graph, run_calibration=execute_calibration):
max_batch_size=inp_dims[0], max_batch_size=inp_dims[0],
max_workspace_size_bytes=1 << 25, max_workspace_size_bytes=1 << 25,
precision_mode="FP16", # TRT Engine precision "FP32","FP16" or "INT8" precision_mode="FP16", # TRT Engine precision "FP32","FP16" or "INT8"
minimum_segment_size=2 # minimum number of nodes in an engine minimum_segment_size=2, # minimum number of nodes in an engine
) is_dynamic_op=False,
maximum_cached_engines=1,
cached_engine_batches=[])
int8_calib_gdef = trt.create_inference_graph( int8_calib_gdef = trt.create_inference_graph(
input_graph_def=orig_graph, input_graph_def=orig_graph,
outputs=["output"], outputs=["output"],
max_batch_size=inp_dims[0], max_batch_size=inp_dims[0],
max_workspace_size_bytes=1 << 25, max_workspace_size_bytes=1 << 25,
precision_mode="INT8", # TRT Engine precision "FP32","FP16" or "INT8" precision_mode="INT8", # TRT Engine precision "FP32","FP16" or "INT8"
minimum_segment_size=2 # minimum number of nodes in an engine minimum_segment_size=2, # minimum number of nodes in an engine
) is_dynamic_op=False,
maximum_cached_engines=1,
cached_engine_batches=[])
o4 = run_graph(fp16_graph, dummy_input) o4 = run_graph(fp16_graph, dummy_input)
_ = run_calibration(int8_calib_gdef, dummy_input) _ = run_calibration(int8_calib_gdef, dummy_input)
int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef) int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef)
o5 = run_graph(int8_graph, dummy_input) o5 = run_graph(int8_graph, dummy_input)
assert np.allclose(o1, o4) print("Is FP32 == FP16? %s (False is possible)" % np.allclose(o1, o4))
assert np.allclose(o1, o5) print("Is FP32 == INT8? %s (False is possible)" % np.allclose(o1, o5))
print("Pass") print("Pass")
def auto(): def auto(multi_engine):
"""Run the conversion as an optimization pass.""" """Run the conversion as an optimization pass."""
inp_dims = (100, 24, 24, 2) if multi_engine:
inp_dims = (2, 3, 7, 5)
orig_graph = get_multi_engine_graph_def()
else:
inp_dims = (100, 24, 24, 2)
orig_graph = get_simple_graph_def() # use a frozen graph for inference
dummy_input = np.random.random_sample(inp_dims) dummy_input = np.random.random_sample(inp_dims)
orig_graph = get_simple_graph_def()
opt_config = rwpb2.RewriterConfig() opt_config = rwpb2.RewriterConfig()
opt_config.meta_optimizer_iterations = opt_config.ONE
opt_config.optimizers.extend(["constfold", "layout"]) opt_config.optimizers.extend(["constfold", "layout"])
custom_op = opt_config.custom_optimizers.add() custom_op = opt_config.custom_optimizers.add()
custom_op.name = "TensorRTOptimizer" custom_op.name = "TensorRTOptimizer"
custom_op.parameter_map["minimum_segment_size"].i = 3 custom_op.parameter_map["minimum_segment_size"].i = 3
custom_op.parameter_map["precision_mode"].s = "FP32" custom_op.parameter_map["precision_mode"].s = to_bytes("FP32")
custom_op.parameter_map["max_batch_size"].i = inp_dims[0] custom_op.parameter_map["max_batch_size"].i = inp_dims[0]
custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25 custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25
print(custom_op) print(custom_op)
gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) gpu_options = None
if trt.trt_convert.get_linked_tensorrt_version()[0] == 3:
gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
graph_options = cpb2.GraphOptions(rewrite_options=opt_config) graph_options = cpb2.GraphOptions(rewrite_options=opt_config)
sessconfig = cpb2.ConfigProto( sessconfig = cpb2.ConfigProto(
gpu_options=gpu_options, graph_options=graph_options) gpu_options=gpu_options, graph_options=graph_options)
@ -168,7 +256,7 @@ def auto():
ops.reset_default_graph() ops.reset_default_graph()
with g.as_default(): with g.as_default():
inp, out = importer.import_graph_def( inp, out = importer.import_graph_def(
graph_def=orig_graph, return_elements=["input", "output"]) graph_def=orig_graph, return_elements=["input", "output"], name="")
inp = inp.outputs[0] inp = inp.outputs[0]
out = out.outputs[0] out = out.outputs[0]
with csess.Session(config=sessconfig, graph=g) as sess: with csess.Session(config=sessconfig, graph=g) as sess:
@ -186,8 +274,14 @@ if "__main__" in __name__:
action="store_true", action="store_true",
help="Do TRT conversion automatically", help="Do TRT conversion automatically",
default=False) default=False)
P.add_argument(
"--multi-engine",
"-m",
action="store_true",
help="Use a graph that will result in 2 engines",
default=False)
flags, unparsed = P.parse_known_args() flags, unparsed = P.parse_known_args()
if flags.automatic: if flags.automatic:
auto() auto(flags.multi_engine)
else: else:
user() user(flags.multi_engine)

View File

@ -48,12 +48,53 @@ PyObject* pair_helper(std::pair<string, string>* in) {
} }
return tuple; return tuple;
} }
struct version_struct{
int vmajor;
int vminor;
int vpatch;
};
PyObject* version_helper(version_struct* in) {
PyObject *tuple(nullptr);
tuple = Py_BuildValue("(iii)", in->vmajor, in->vminor, in->vpatch);
if (!tuple) {
if (!PyErr_Occurred()) {
PyErr_SetString(PyExc_TypeError,
"Tuple creation from version structure failed!");
}
return NULL;
}
return tuple;
}
/* Define converters for vector<int> */
template<>
bool _PyObjAs(PyObject *pyobj, int* dest) {
*dest = PyLong_AsLong(pyobj);
return true;
}
template<>
PyObject *_PyObjFrom(const int& src) {
return PyLong_FromLong(src);
}
%} %}
_LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong);
%typemap(out) std::pair<string, string> { %typemap(out) std::pair<string, string> {
PyObject *tuple = pair_helper(&$1); PyObject *tuple = pair_helper(&$1);
if (!tuple) SWIG_fail; if (!tuple) SWIG_fail;
$result = tuple; $result = tuple;
} }
%typemap(out) version_struct {
PyObject *tuple = version_helper(&$1);
if (!tuple) SWIG_fail;
$result = tuple;
}
%{ %{
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
@ -65,6 +106,8 @@ PyObject* pair_helper(std::pair<string, string>* in) {
%unignore tensorflow; %unignore tensorflow;
%unignore trt_convert; %unignore trt_convert;
%unignore calib_convert; %unignore calib_convert;
%unignore get_linked_tensorrt_version;
%unignore get_loaded_tensorrt_version;
%{ %{
@ -74,7 +117,10 @@ std::pair<string, string> trt_convert(
size_t max_batch_size, size_t max_batch_size,
size_t max_workspace_size_bytes, size_t max_workspace_size_bytes,
int precision_mode, int precision_mode,
int minimum_segment_size int minimum_segment_size,
bool is_dyn_op,
int max_cached_engines,
std::vector<int> cached_engine_batches
// Unfortunately we can't use TF_Status here since it // Unfortunately we can't use TF_Status here since it
// is in c/c_api and brings in a lot of other libraries // is in c/c_api and brings in a lot of other libraries
// which in turn declare ops. These ops are included // which in turn declare ops. These ops are included
@ -102,11 +148,12 @@ std::pair<string, string> trt_convert(
out_status = "InvalidArgument;Size of the output_names vector is 0"; out_status = "InvalidArgument;Size of the output_names vector is 0";
return std::pair<string, string>{out_status, ""}; return std::pair<string, string>{out_status, ""};
} }
tensorflow::GraphDef outGraph; tensorflow::GraphDef out_graph;
tensorflow::Status conversion_status = tensorflow::Status conversion_status =
tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT( tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT(
graph_def, output_names, max_batch_size, max_workspace_size_bytes, graph_def, output_names, max_batch_size, max_workspace_size_bytes,
&outGraph, precision_mode, minimum_segment_size); &out_graph, precision_mode, minimum_segment_size,
is_dyn_op, max_cached_engines, cached_engine_batches);
if (!conversion_status.ok()) { if (!conversion_status.ok()) {
auto retCode = (int)conversion_status.code(); auto retCode = (int)conversion_status.code();
char buff[2000]; char buff[2000];
@ -116,7 +163,7 @@ std::pair<string, string> trt_convert(
return std::pair<string, string>{out_status, ""}; return std::pair<string, string>{out_status, ""};
} }
string result; string result;
if (!outGraph.SerializeToString(&result)) { if (!out_graph.SerializeToString(&result)) {
out_status = "InvalidArgument;Couldn't serialize output as a GraphDef"; out_status = "InvalidArgument;Couldn't serialize output as a GraphDef";
return std::pair<string, string>{out_status, ""}; return std::pair<string, string>{out_status, ""};
} }
@ -128,7 +175,8 @@ std::pair<string, string> trt_convert(
#endif // GOOGLE_CUDA && GOOGLE_TENSORRT #endif // GOOGLE_CUDA && GOOGLE_TENSORRT
} }
std::pair<string, string> calib_convert(string graph_def_string // const tensorflow::GraphDef& std::pair<string, string> calib_convert(
string graph_def_string, bool is_dyn_op
// unfortunately we can't use TF_Status here since it // unfortunately we can't use TF_Status here since it
// is in c/c_api and brings in a lot of other libraries // is in c/c_api and brings in a lot of other libraries
// which in turn declare ops. These ops are included // which in turn declare ops. These ops are included
@ -147,11 +195,11 @@ std::pair<string, string> calib_convert(string graph_def_string // const tenso
out_status = "InvalidArgument;Couldn't interpret input as a GraphDef"; out_status = "InvalidArgument;Couldn't interpret input as a GraphDef";
return std::pair<string, string>{out_status, ""}; return std::pair<string, string>{out_status, ""};
} }
graph_def_string.resize(0);
tensorflow::GraphDef outGraph; tensorflow::GraphDef out_graph;
tensorflow::Status conversion_status = tensorflow::Status conversion_status =
tensorflow::tensorrt::convert::ConvertCalibGraphToInferGraph(graph_def, tensorflow::tensorrt::convert::ConvertCalibGraphToInferGraph(
&outGraph); graph_def, &out_graph, is_dyn_op);
if (!conversion_status.ok()) { if (!conversion_status.ok()) {
auto retCode = (int)conversion_status.code(); auto retCode = (int)conversion_status.code();
char buff[2000]; char buff[2000];
@ -161,7 +209,7 @@ std::pair<string, string> calib_convert(string graph_def_string // const tenso
return std::pair<string, string>{out_status, ""}; return std::pair<string, string>{out_status, ""};
} }
string result; string result;
if (!outGraph.SerializeToString(&result)) { if (!out_graph.SerializeToString(&result)) {
out_status = "InvalidArgument;Couldn't serialize output as a GraphDef"; out_status = "InvalidArgument;Couldn't serialize output as a GraphDef";
return std::pair<string, string>{out_status, ""}; return std::pair<string, string>{out_status, ""};
} }
@ -172,15 +220,39 @@ std::pair<string, string> calib_convert(string graph_def_string // const tenso
return std::pair<string, string>{"9;TensorRT is not enabled!", ""}; return std::pair<string, string>{"9;TensorRT is not enabled!", ""};
#endif // GOOGLE_CUDA && GOOGLE_TENSORRT #endif // GOOGLE_CUDA && GOOGLE_TENSORRT
} }
version_struct get_linked_tensorrt_version(){
// Return the version at the link time.
const auto &lv = tensorflow::tensorrt::convert::GetLinkedTensorRTVersion();
version_struct s;
s.vmajor = lv[0];
s.vminor = lv[1];
s.vpatch = lv[2];
return s;
}
version_struct get_loaded_tensorrt_version(){
// Return the version from the loaded library.
const auto &lv = tensorflow::tensorrt::convert::GetLoadedTensorRTVersion();
version_struct s;
s.vmajor = lv[0];
s.vminor = lv[1];
s.vpatch = lv[2];
return s;
}
%} %}
std::pair<string, string> calib_convert(string graph_def_string); std::pair<string, string> calib_convert(string graph_def_string, bool is_dyn_op);
std::pair<string, string> trt_convert(string graph_def_string, std::pair<string, string> trt_convert(string graph_def_string,
std::vector<string> output_names, std::vector<string> output_names,
size_t max_batch_size, size_t max_batch_size,
size_t max_workspace_size_bytes, size_t max_workspace_size_bytes,
int precision_mode, int minimum_segment_size); int precision_mode, int minimum_segment_size,
bool is_dyn_op,
int max_cached_engines,
std::vector<int> cached_engine_batches);
version_struct get_linked_tensorrt_version();
version_struct get_loaded_tensorrt_version();
%unignoreall %unignoreall
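On the Python side these two wrappers surface as plain (major, minor, patch) tuples, which is exactly what the version check in trt_convert.py consumes. A quick illustrative use (imports as exported above; the printed values depend on the local build and TensorRT installation):

from tensorflow.contrib.tensorrt.wrap_conversion import get_linked_tensorrt_version
from tensorflow.contrib.tensorrt.wrap_conversion import get_loaded_tensorrt_version

linked = get_linked_tensorrt_version()  # e.g. (4, 0, 1); depends on the build
loaded = get_loaded_tensorrt_version()
if loaded[0] != linked[0]:
  raise RuntimeError("TensorRT major version mismatch: linked %s, loaded %s" %
                     (linked, loaded))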

View File

@ -49,11 +49,11 @@ tf_cc_binary(
":tpu_profiler_analysis_proto_cc", ":tpu_profiler_analysis_proto_cc",
":tpu_profiler_proto_cc", ":tpu_profiler_proto_cc",
":version", ":version",
"//tensorflow:grpc++",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/distributed_runtime/rpc:grpc_util",
"//tensorflow/core/platform/cloud:gcs_file_system", "//tensorflow/core/platform/cloud:gcs_file_system",
"@grpc//:grpc++",
], ],
) )

View File

@ -53,12 +53,12 @@ cc_library(
":grpc_verbs_service_impl", ":grpc_verbs_service_impl",
":rdma_mgr", ":rdma_mgr",
":verbs_service_proto_cc", ":verbs_service_proto_cc",
"//tensorflow:grpc++",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core/distributed_runtime:session_mgr", "//tensorflow/core/distributed_runtime:session_mgr",
"//tensorflow/core/distributed_runtime/rpc:async_service_interface", "//tensorflow/core/distributed_runtime/rpc:async_service_interface",
"//tensorflow/core/distributed_runtime/rpc:grpc_call", "//tensorflow/core/distributed_runtime/rpc:grpc_call",
"//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/distributed_runtime/rpc:grpc_util",
"@grpc//:grpc++",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -69,7 +69,7 @@ cc_library(
hdrs = ["grpc_verbs_service_impl.h"], hdrs = ["grpc_verbs_service_impl.h"],
deps = [ deps = [
":verbs_service_proto_cc", ":verbs_service_proto_cc",
"@grpc//:grpc++", "//tensorflow:grpc++",
], ],
) )

View File

@ -4,6 +4,7 @@
# The following targets can be used to access ApiDefs: # The following targets can be used to access ApiDefs:
# :base_api_def # :base_api_def
# :python_api_def # :python_api_def
# :java_api_def
package( package(
default_visibility = ["//visibility:private"], default_visibility = ["//visibility:private"],
@ -29,6 +30,12 @@ filegroup(
visibility = ["//tensorflow:internal"], visibility = ["//tensorflow:internal"],
) )
filegroup(
name = "java_api_def",
srcs = glob(["java_api/*"]),
visibility = ["//tensorflow:internal"],
)
cc_library( cc_library(
name = "excluded_ops_lib", name = "excluded_ops_lib",
srcs = ["excluded_ops.cc"], srcs = ["excluded_ops.cc"],

View File

@ -68,7 +68,7 @@ END
name: "area_range" name: "area_range"
description: <<END description: <<END
The cropped area of the image must contain a fraction of the The cropped area of the image must contain a fraction of the
supplied image within in this range. supplied image within this range.
END END
} }
attr { attr {

View File

@ -68,7 +68,7 @@ END
name: "area_range" name: "area_range"
description: <<END description: <<END
The cropped area of the image must contain a fraction of the The cropped area of the image must contain a fraction of the
supplied image within in this range. supplied image within this range.
END END
} }
attr { attr {

View File

@ -11,7 +11,7 @@ END
name: "stride" name: "stride"
description: <<END description: <<END
A scalar representing the steps moving the sliding window A scalar representing the steps moving the sliding window
forward in one iteration. It must be in `[1, window_size)`. forward in one iteration. It must be positive.
END END
} }
summary: "Creates a dataset that passes a sliding window over `input_dataset`." summary: "Creates a dataset that passes a sliding window over `input_dataset`."

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "Assert" #TODO(karllessard) escape that reserved name
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "Const" #TODO(karllessard) escape that reserved name
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "Switch" #TODO(karllessard) escape that reserved name
visibility: HIDDEN
}

View File

@ -106,24 +106,24 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) {
EXPECT_EQ(1, shape.dim(1).size()); EXPECT_EQ(1, shape.dim(1).size());
if (node->name() == y->name()) { if (node->name() == y->name()) {
#ifdef INTEL_MKL #ifdef INTEL_MKL
// if MKL is used, it goes through various additional // if MKL is used, it goes through various additional
// graph rewrite pass. In TF, everytime a graph pass // graph rewrite pass. In TF, everytime a graph pass
// happens, "constant" nodes are allocated // happens, "constant" nodes are allocated
// and deallocated. Each allocation calls the // and deallocated. Each allocation calls the
// (FindChunkPtr of BFCAllocator), // (FindChunkPtr of BFCAllocator),
// which increments the value of AllocationId. // which increments the value of AllocationId.
// Thus AllocationId becomes more than 3 and 4 if // Thus AllocationId becomes more than TF if MKL
// MKL is used. Now they are 9 and 10 for MKL. // is used. Now IDs for MKL are 8 more than TF.
EXPECT_EQ(19, cm->AllocationId(node, 0)); EXPECT_EQ(29, cm->AllocationId(node, 0));
#else #else
EXPECT_EQ(21, cm->AllocationId(node, 0)); EXPECT_EQ(21, cm->AllocationId(node, 0));
#endif #endif
} else { } else {
#ifdef INTEL_MKL #ifdef INTEL_MKL
EXPECT_EQ(20, cm->AllocationId(node, 0)); EXPECT_EQ(30, cm->AllocationId(node, 0));
#else #else
EXPECT_EQ(22, cm->AllocationId(node, 0)); EXPECT_EQ(22, cm->AllocationId(node, 0));
#endif #endif
} }
} }
EXPECT_LE(0, cm->MaxExecutionTime(node)); EXPECT_LE(0, cm->MaxExecutionTime(node));

View File

@ -17,6 +17,13 @@ limitations under the License.
#include "tensorflow/core/common_runtime/mkl_cpu_allocator.h" #include "tensorflow/core/common_runtime/mkl_cpu_allocator.h"
#ifdef _WIN32
// Declare function to avoid unresolved symbol in VS
i_malloc_t i_malloc;
i_calloc_t i_calloc;
i_realloc_t i_realloc;
i_free_t i_free;
#endif
namespace tensorflow { namespace tensorflow {
constexpr const char* MklCPUAllocator::kMaxLimitStr; constexpr const char* MklCPUAllocator::kMaxLimitStr;

View File

@ -143,6 +143,7 @@ tf_cuda_library(
":debug_node_key", ":debug_node_key",
":debug_service_proto_cc", ":debug_service_proto_cc",
":debugger_event_metadata_proto_cc", ":debugger_event_metadata_proto_cc",
"//tensorflow:grpc++",
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:graph", "//tensorflow/core:graph",
@ -150,7 +151,6 @@ tf_cuda_library(
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core:proto_text", "//tensorflow/core:proto_text",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"@grpc//:grpc++",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -166,11 +166,11 @@ tf_cuda_library(
":debug_io_utils", ":debug_io_utils",
":debug_service_proto_cc", ":debug_service_proto_cc",
":debugger_event_metadata_proto_cc", ":debugger_event_metadata_proto_cc",
"//tensorflow:grpc++",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"@grpc//:grpc++",
], ],
alwayslink = 1, alwayslink = 1,
) )

View File

@ -628,6 +628,7 @@ tf_cuda_cc_test(
":master", ":master",
":remote_device", ":remote_device",
":worker_interface", ":worker_interface",
"//tensorflow:grpc++",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework", "//tensorflow/core:framework",
@ -649,7 +650,6 @@ tf_cuda_cc_test(
"//tensorflow/core/kernels:dense_update_ops", "//tensorflow/core/kernels:dense_update_ops",
"//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:identity_op",
"//tensorflow/core/kernels:variable_ops", "//tensorflow/core/kernels:variable_ops",
"@grpc//:grpc++",
], ],
) )
@ -667,6 +667,7 @@ tf_cuda_cc_test(
":master", ":master",
":remote_device", ":remote_device",
":worker_interface", ":worker_interface",
"//tensorflow:grpc++",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework", "//tensorflow/core:framework",
@ -682,7 +683,6 @@ tf_cuda_cc_test(
"//tensorflow/core/distributed_runtime/rpc:grpc_testlib", "//tensorflow/core/distributed_runtime/rpc:grpc_testlib",
"//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/distributed_runtime/rpc:grpc_util",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
"@grpc//:grpc++",
], ],
) )

View File

@ -48,6 +48,8 @@ cc_library(
"eager_service_impl.h", "eager_service_impl.h",
], ],
deps = [ deps = [
"//tensorflow:grpc",
"//tensorflow:grpc++",
"//tensorflow/c:c_api_internal", "//tensorflow/c:c_api_internal",
"//tensorflow/c:tf_status_helper", "//tensorflow/c:tf_status_helper",
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
@ -67,8 +69,6 @@ cc_library(
"//tensorflow/core/distributed_runtime:worker_env", "//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime/eager:remote_tensor_handle", "//tensorflow/core/distributed_runtime/eager:remote_tensor_handle",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
"@grpc",
"@grpc//:grpc++",
], ],
) )

View File

@ -41,8 +41,8 @@ cc_library(
srcs = ["grpc_util.cc"], srcs = ["grpc_util.cc"],
hdrs = ["grpc_util.h"], hdrs = ["grpc_util.h"],
deps = [ deps = [
"@grpc", "//tensorflow:grpc",
"@grpc//:grpc++", "//tensorflow:grpc++",
"//tensorflow/core:lib", "//tensorflow/core:lib",
# Required to be able to overload TensorResponse parsing. # Required to be able to overload TensorResponse parsing.
"//tensorflow/core/distributed_runtime:tensor_coding", "//tensorflow/core/distributed_runtime:tensor_coding",
@ -55,8 +55,8 @@ cc_library(
hdrs = ["grpc_client_cq_tag.h"], hdrs = ["grpc_client_cq_tag.h"],
deps = [ deps = [
":grpc_util", ":grpc_util",
"//tensorflow:grpc++",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"@grpc//:grpc++",
], ],
) )
@ -67,10 +67,10 @@ cc_library(
deps = [ deps = [
":grpc_client_cq_tag", ":grpc_client_cq_tag",
":grpc_util", ":grpc_util",
"//tensorflow:grpc++",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core/distributed_runtime:call_options", "//tensorflow/core/distributed_runtime:call_options",
"//tensorflow/core/distributed_runtime:tensor_coding", "//tensorflow/core/distributed_runtime:tensor_coding",
"@grpc//:grpc++",
], ],
) )
@ -83,6 +83,7 @@ cc_library(
":grpc_state", ":grpc_state",
":grpc_util", ":grpc_util",
":grpc_worker_service_impl", ":grpc_worker_service_impl",
"//tensorflow:grpc++",
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
@ -90,7 +91,6 @@ cc_library(
"//tensorflow/core/distributed_runtime:tensor_coding", "//tensorflow/core/distributed_runtime:tensor_coding",
"//tensorflow/core/distributed_runtime:worker_cache_logger", "//tensorflow/core/distributed_runtime:worker_cache_logger",
"//tensorflow/core/distributed_runtime:worker_interface", "//tensorflow/core/distributed_runtime:worker_interface",
"@grpc//:grpc++",
], ],
) )
@ -100,10 +100,10 @@ cc_library(
hdrs = ["grpc_channel.h"], hdrs = ["grpc_channel.h"],
deps = [ deps = [
":grpc_util", ":grpc_util",
"//tensorflow:grpc++",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"@grpc//:grpc++",
], ],
) )
@ -112,13 +112,13 @@ cc_library(
srcs = ["grpc_tensor_coding.cc"], srcs = ["grpc_tensor_coding.cc"],
hdrs = ["grpc_tensor_coding.h"], hdrs = ["grpc_tensor_coding.h"],
deps = [ deps = [
"//tensorflow:grpc++",
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core:worker_proto_cc", "//tensorflow/core:worker_proto_cc",
"@grpc//:grpc++",
], ],
) )
@ -127,9 +127,9 @@ cc_library(
srcs = [], srcs = [],
hdrs = ["grpc_call.h"], hdrs = ["grpc_call.h"],
deps = [ deps = [
"//tensorflow:grpc++",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"@grpc//:grpc++",
], ],
) )
@ -167,6 +167,7 @@ tf_cuda_library(
":grpc_tensor_coding", ":grpc_tensor_coding",
":grpc_util", ":grpc_util",
":grpc_worker_service_impl", ":grpc_worker_service_impl",
"//tensorflow:grpc++",
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
@ -180,7 +181,6 @@ tf_cuda_library(
"//tensorflow/core/distributed_runtime:worker_cache", "//tensorflow/core/distributed_runtime:worker_cache",
"//tensorflow/core/distributed_runtime:worker_env", "//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime:worker_session", "//tensorflow/core/distributed_runtime:worker_session",
"@grpc//:grpc++",
], ],
) )
@ -190,9 +190,9 @@ cc_library(
hdrs = ["grpc_worker_service_impl.h"], hdrs = ["grpc_worker_service_impl.h"],
deps = [ deps = [
":grpc_util", ":grpc_util",
"//tensorflow:grpc++",
"//tensorflow/core:worker_proto_cc", "//tensorflow/core:worker_proto_cc",
"//tensorflow/core/distributed_runtime:tensor_coding", "//tensorflow/core/distributed_runtime:tensor_coding",
"@grpc//:grpc++",
], ],
) )
@ -220,12 +220,12 @@ cc_library(
":async_service_interface", ":async_service_interface",
":grpc_call", ":grpc_call",
":grpc_util", ":grpc_util",
"//tensorflow:grpc++",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core:master_proto_cc", "//tensorflow/core:master_proto_cc",
"//tensorflow/core:master_service_proto_cc", "//tensorflow/core:master_service_proto_cc",
"//tensorflow/core/distributed_runtime:master", "//tensorflow/core/distributed_runtime:master",
"@grpc//:grpc++",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -259,6 +259,8 @@ cc_library(
":grpc_worker_cache", ":grpc_worker_cache",
":grpc_worker_service", ":grpc_worker_service",
":rpc_rendezvous_mgr", ":rpc_rendezvous_mgr",
"//tensorflow:grpc",
"//tensorflow:grpc++",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework", "//tensorflow/core:framework",
@ -277,8 +279,6 @@ cc_library(
"//tensorflow/core/distributed_runtime:worker_cache_wrapper", "//tensorflow/core/distributed_runtime:worker_cache_wrapper",
"//tensorflow/core/distributed_runtime:worker_env", "//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_service_impl", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_service_impl",
"@grpc",
"@grpc//:grpc++",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -299,13 +299,13 @@ tf_cc_binary(
], ],
deps = [ deps = [
":grpc_server_lib", ":grpc_server_lib",
"//tensorflow:grpc++",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/kernels:data_flow", "//tensorflow/core/kernels:data_flow",
"@grpc//:grpc++",
], ],
) )
@ -317,6 +317,7 @@ tf_cc_binary(
], ],
deps = [ deps = [
":grpc_server_lib", ":grpc_server_lib",
"//tensorflow:grpc++",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
@ -330,7 +331,6 @@ tf_cc_binary(
"//tensorflow/core/kernels:matmul_op", "//tensorflow/core/kernels:matmul_op",
"//tensorflow/core/kernels:reduction_ops", "//tensorflow/core/kernels:reduction_ops",
"//tensorflow/core/kernels:variable_ops", "//tensorflow/core/kernels:variable_ops",
"@grpc//:grpc++",
], ],
) )
@ -415,6 +415,7 @@ tf_cc_test(
deps = [ deps = [
":grpc_tensor_coding", ":grpc_tensor_coding",
":grpc_testlib", ":grpc_testlib",
"//tensorflow:grpc++",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework", "//tensorflow/core:framework",
@ -424,7 +425,6 @@ tf_cc_test(
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core:testlib", "//tensorflow/core:testlib",
"//tensorflow/core:worker_proto_cc", "//tensorflow/core:worker_proto_cc",
"@grpc//:grpc++",
], ],
) )
@ -434,11 +434,11 @@ tf_cc_test(
srcs = ["grpc_util_test.cc"], srcs = ["grpc_util_test.cc"],
deps = [ deps = [
":grpc_util", ":grpc_util",
"//tensorflow:grpc",
"//tensorflow:grpc++",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core:worker_proto_cc", "//tensorflow/core:worker_proto_cc",
"@grpc",
"@grpc//:grpc++",
], ],
) )

View File

@ -11,8 +11,8 @@ cc_library(
srcs = ["grpc_eager_service.cc"], srcs = ["grpc_eager_service.cc"],
hdrs = ["grpc_eager_service.h"], hdrs = ["grpc_eager_service.h"],
deps = [ deps = [
"//tensorflow:grpc++",
"//tensorflow/core:eager_service_proto_cc", "//tensorflow/core:eager_service_proto_cc",
"@grpc//:grpc++",
], ],
) )
@ -21,6 +21,7 @@ cc_library(
srcs = ["grpc_eager_client.cc"], srcs = ["grpc_eager_client.cc"],
hdrs = ["grpc_eager_client.h"], hdrs = ["grpc_eager_client.h"],
deps = [ deps = [
"//tensorflow:grpc++",
"//tensorflow/core:eager_service_proto_cc", "//tensorflow/core:eager_service_proto_cc",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core/distributed_runtime/eager:eager_client", "//tensorflow/core/distributed_runtime/eager:eager_client",
@ -29,7 +30,6 @@ cc_library(
"//tensorflow/core/distributed_runtime/rpc:grpc_state", "//tensorflow/core/distributed_runtime/rpc:grpc_state",
"//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/distributed_runtime/rpc:grpc_util",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_service", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_service",
"@grpc//:grpc++",
], ],
) )
@ -39,6 +39,7 @@ cc_library(
hdrs = ["grpc_eager_service_impl.h"], hdrs = ["grpc_eager_service_impl.h"],
deps = [ deps = [
":grpc_eager_service", ":grpc_eager_service",
"//tensorflow:grpc++",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:ptr_util", "//tensorflow/core:ptr_util",
"//tensorflow/core/distributed_runtime/eager:eager_service_impl", "//tensorflow/core/distributed_runtime/eager:eager_service_impl",
@ -47,6 +48,6 @@ cc_library(
"//tensorflow/core/distributed_runtime/rpc:grpc_channel", "//tensorflow/core/distributed_runtime/rpc:grpc_channel",
"//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/distributed_runtime/rpc:grpc_util",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
"@grpc//:grpc++", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
], ],
) )

View File

@ -289,6 +289,12 @@ Status GrpcServer::Init(
nullptr); nullptr);
} }
Status GrpcServer::Init(
ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func) {
return Init(std::move(service_func), rendezvous_mgr_func, nullptr, nullptr);
}
Status GrpcServer::Init() { return Init(nullptr, nullptr, nullptr, nullptr); } Status GrpcServer::Init() { return Init(nullptr, nullptr, nullptr, nullptr); }
Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options, Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,

View File

@ -97,6 +97,9 @@ class GrpcServer : public ServerInterface {
const RendezvousMgrCreationFunction& rendezvous_mgr_func, const RendezvousMgrCreationFunction& rendezvous_mgr_func,
const CollectiveMgrCreationFunction& collective_mgr_func); const CollectiveMgrCreationFunction& collective_mgr_func);
Status Init(ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func);
Status Init(); Status Init();
// A subclass can override this method to support secure credentials. // A subclass can override this method to support secure credentials.

View File

@ -1901,6 +1901,11 @@ BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000);
#else // INTEL_MKL_ML #else // INTEL_MKL_ML
// NOTE: Unit tests in this file rely on a topologically sorted graph for
// printing. But since sibling nodes of a node in the topologically sorted graph
// can be printed in different orders, tests may fail if the order in which
// sibling nodes are visited is changed.
namespace { namespace {
const char kCPUDevice[] = "/job:a/replica:0/task:0/device:CPU:0"; const char kCPUDevice[] = "/job:a/replica:0/task:0/device:CPU:0";
@ -2572,9 +2577,9 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_Mkl) {
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);" "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
"F(_MklConv2D);G(Const);H(_MklConcat);I(Zeta)|A->E;A->I;" "F(_MklConv2D);G(Const);H(_MklConcat);I(Zeta)|A->E;A->I;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;" "A:control->DMT/_0:control;A:control->DMT/_1:control;"
"B->E:1;C->F;C:control->DMT/_0:control;C:control->DMT/_1:control;" "B->E:1;C->F;C:control->DMT/_2:control;C:control->DMT/_3:control;"
"D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;" "D->F:1;DMT/_0->E:2;DMT/_1->E:3;DMT/_2->F:2;DMT/_3->F:3;"
"DMT/_4->H:3;E->H:1;E:2->H:4;F->H:2;F:2->H:5;G->H;" "DMT/_4->H:3;E->H:1;E:2->H:4;F->H:2;F:2->H:5;G->H;"
"G:control->DMT/_4:control;H->I:1"); "G:control->DMT/_4:control;H->I:1");
} }
@ -2681,9 +2686,9 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) {
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);" "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
"F(_MklConv2D);G(Const);H(_MklConcatV2);I(Zeta)|A->E;A->I;" "F(_MklConv2D);G(Const);H(_MklConcatV2);I(Zeta)|A->E;A->I;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;B->E:1;C->F;" "A:control->DMT/_0:control;A:control->DMT/_1:control;B->E:1;C->F;"
"C:control->DMT/_0:control;C:control->DMT/_1:control;" "C:control->DMT/_2:control;C:control->DMT/_3:control;"
"D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;" "D->F:1;DMT/_0->E:2;DMT/_1->E:3;DMT/_2->F:2;DMT/_3->F:3;"
"DMT/_4->H:5;E->H;E:2->H:3;E:control->DMT/_4:control;F->H:1;" "DMT/_4->H:5;E->H;E:2->H:3;E:control->DMT/_4:control;F->H:1;"
"F:2->H:4;G->H:2;H->I:1"); "F:2->H:4;G->H:2;H->I:1");
} }
@ -3060,8 +3065,8 @@ TEST_F(MklLayoutPassTest, LRN_Negative3) {
"C:control->DMT/_1:control;C:control->DMT/_2:control;" "C:control->DMT/_1:control;C:control->DMT/_2:control;"
"C:control->DMT/_3:control;C:control->DMT/_4:control;" "C:control->DMT/_3:control;C:control->DMT/_4:control;"
"C:control->DMT/_5:control;C:control->DMT/_6:control;" "C:control->DMT/_5:control;C:control->DMT/_6:control;"
"D->E:1;D->F:2;DMT/_0->B:1;DMT/_1->F:3;DMT/_2->F:7;DMT/_3->F:4;" "D->E:1;D->F:2;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;DMT/_3->F:3;"
"DMT/_4->F:6;DMT/_5->E:4;DMT/_6->E:5;E->G;F->G:1"); "DMT/_4->F:7;DMT/_5->F:4;DMT/_6->F:6;E->G;F->G:1");
} }
/* Test MaxPool->MaxPoolGrad replacement by workspace+rewrite nodes. */ /* Test MaxPool->MaxPoolGrad replacement by workspace+rewrite nodes. */

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/dataset.h" #include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/batch_util.h" #include "tensorflow/core/util/batch_util.h"
namespace tensorflow { namespace tensorflow {
@ -32,16 +33,24 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input, void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override { DatasetBase** output) override {
int64 window_size = 0; int64 window_size = 0;
int64 stride = 1; int64 stride = 0;
OP_REQUIRES_OK( OP_REQUIRES_OK(
ctx, ParseScalarArgument<int64>(ctx, "window_size", &window_size)); ctx, ParseScalarArgument<int64>(ctx, "window_size", &window_size));
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "stride", &stride)); OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "stride", &stride));
OP_REQUIRES( OP_REQUIRES(
ctx, window_size > 0, ctx, window_size > 0,
errors::InvalidArgument("Window size must be greater than zero.")); errors::InvalidArgument("Window size must be greater than zero."));
OP_REQUIRES( OP_REQUIRES(ctx, stride > 0,
ctx, stride > 0 && stride < window_size, errors::InvalidArgument("Stride must be greater than zero."));
errors::InvalidArgument("Stride must be in [1, window_size).")); if (stride == window_size) {
LOG(WARNING) << "stride: " << stride
<< " is equal to window_size: " << window_size
<< ", to use `batch` instead.";
} else if (stride > window_size) {
LOG(WARNING) << "stride: " << stride
<< " is greater than window_size: " << window_size
<< ", you will lose some data.";
}
*output = new Dataset(ctx, window_size, stride, input); *output = new Dataset(ctx, window_size, stride, input);
} }
@ -124,12 +133,15 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
return Status::OK(); return Status::OK();
} }
batch_elements.reserve(window_size); batch_elements.reserve(window_size);
const bool first_call = cache_.empty(); // Use cache if stride < window_size.
if (first_call) { if (stride < window_size) {
cache_.reserve(window_size); const bool first_call = cache_.empty();
} else { if (first_call) {
// Reuse cache in the previous iteration. cache_.reserve(window_size);
cache_.swap(batch_elements); } else {
// Reuse cache in the previous iteration.
cache_.swap(batch_elements);
}
} }
// Fill up with new elements. // Fill up with new elements.
*end_of_sequence = false; *end_of_sequence = false;
@ -149,9 +161,22 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
DCHECK(*end_of_sequence); DCHECK(*end_of_sequence);
return Status::OK(); return Status::OK();
} }
// Cache the data used for the next iteration.
for (size_t i = stride; i < window_size; ++i) { if (stride < window_size) {
cache_.emplace_back(batch_elements[i]); // Cache the data used for the next iteration.
for (size_t i = stride; i < window_size; ++i) {
cache_.emplace_back(batch_elements[i]);
}
} else if (stride > window_size) {
// Drop the data before the next iteration.
std::vector<Tensor> batch_element_tuple;
for (size_t i = window_size; i < stride && !*end_of_sequence; ++i) {
TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &batch_element_tuple,
end_of_sequence));
if (*end_of_sequence) {
input_impl_.reset();
}
}
} }
} }
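The warning and caching branches above distinguish three regimes of `stride` relative to `window_size`: overlapping windows (stride < window_size, where the element cache is reused), back-to-back windows that behave like `batch` (stride == window_size), and windows that skip input elements (stride > window_size). Below is a minimal standalone C++ sketch of those semantics; it is illustrative only (the helper name and sample data are made up), not the kernel code.

```cpp
// Standalone illustration of sliding-window semantics for different strides.
// Not TensorFlow code; SlideWindows and the sample data are hypothetical.
#include <cstddef>
#include <iostream>
#include <vector>

std::vector<std::vector<int>> SlideWindows(const std::vector<int>& data,
                                           size_t window_size, size_t stride) {
  std::vector<std::vector<int>> windows;
  // A window is emitted only if it can be filled completely, matching the
  // end-of-sequence handling in the kernel above.
  for (size_t start = 0; start + window_size <= data.size(); start += stride) {
    windows.emplace_back(data.begin() + start,
                         data.begin() + start + window_size);
  }
  return windows;
}

int main() {
  const std::vector<int> data = {0, 1, 2, 3, 4, 5, 6, 7};
  for (size_t stride : {1, 3, 4}) {  // stride <, ==, > window_size of 3
    std::cout << "stride " << stride << ":";
    for (const auto& w : SlideWindows(data, /*window_size=*/3, stride)) {
      std::cout << " [";
      for (int v : w) std::cout << v << " ";
      std::cout << "]";
    }
    std::cout << "\n";
  }
  return 0;
}
```

For this data and a window size of 3, stride 1 prints overlapping windows, stride 3 prints disjoint `batch`-like windows, and stride 4 skips element 3 entirely, which is the data loss the second warning refers to.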

View File

@ -59,7 +59,8 @@ namespace tensorflow {
#ifndef INTEL_MKL_ML #ifndef INTEL_MKL_ML
struct ConvFwdDimensions { // This structure aggregates multiple inputs to Conv2DFwd* methods.
struct MklConvFwdParams {
memory::dims src_dims; memory::dims src_dims;
memory::dims filter_dims; memory::dims filter_dims;
memory::dims bias_dims; memory::dims bias_dims;
@ -69,48 +70,56 @@ struct ConvFwdDimensions {
memory::dims padding_left; memory::dims padding_left;
memory::dims padding_right; memory::dims padding_right;
ConvFwdDimensions(memory::dims src_dims, MklConvFwdParams(memory::dims src_dims, memory::dims filter_dims,
memory::dims filter_dims, memory::dims bias_dims, memory::dims bias_dims, memory::dims dst_dims,
memory::dims dst_dims, memory::dims strides, memory::dims strides, memory::dims dilations,
memory::dims dilations, memory::dims padding_left, memory::dims padding_left, memory::dims padding_right)
memory::dims padding_right) : : src_dims(src_dims),
src_dims(src_dims), filter_dims(filter_dims), filter_dims(filter_dims),
bias_dims(bias_dims), dst_dims(dst_dims), bias_dims(bias_dims),
strides(strides), dilations(dilations), dst_dims(dst_dims),
padding_left(padding_left), padding_right(padding_right) { strides(strides),
} dilations(dilations),
padding_left(padding_left),
padding_right(padding_right) {}
}; };
template <typename T> template <typename T>
class Conv2DFwd : public DnnOp { class MklConv2DFwdPrimitive : public MklPrimitive {
public: public:
explicit Conv2DFwd(const ConvFwdDimensions& convFwdDims) { explicit MklConv2DFwdPrimitive(const MklConvFwdParams& convFwdDims)
fwd_stream_.reset(new stream(stream::kind::eager)); : cpu_engine_(engine::cpu, 0) {
context_.fwd_stream.reset(new stream(stream::kind::eager));
// create conv primitive // create conv primitive
if (conv_fwd_ == nullptr) { if (context_.conv_fwd == nullptr) {
Setup(convFwdDims); Setup(convFwdDims);
} }
} }
~Conv2DFwd() {} ~MklConv2DFwdPrimitive() {}
// Convolution forward execute with bias // Convolution forward execute with bias
// src_data: input data buffer of src // src_data: input data buffer of src
// filter_data: input data buffer of filter (weights) // filter_data: input data buffer of filter (weights)
// bias_data: input data buffer of bias // bias_data: input data buffer of bias
// dst_data: output data buffer of dst // dst_data: output data buffer of dst
void Execute(T* src_data, T* filter_data, T* bias_data, T* dst_data) { void Execute(const T* src_data, const T* filter_data, const T* bias_data,
src_mem_->set_data_handle(static_cast<void*>(src_data)); const T* dst_data) {
filter_mem_->set_data_handle(static_cast<void*>(filter_data)); context_.src_mem->set_data_handle(
bias_mem_->set_data_handle(static_cast<void*>(bias_data)); static_cast<void*>(const_cast<T*>(src_data)));
dst_mem_->set_data_handle(static_cast<void*>(dst_data)); context_.filter_mem->set_data_handle(
fwd_stream_->submit(fwd_primitives_); static_cast<void*>(const_cast<T*>(filter_data)));
context_.bias_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(bias_data)));
context_.dst_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(dst_data)));
context_.fwd_stream->submit(context_.fwd_primitives);
// after exec, set data handle back // after exec, set data handle back
src_mem_->set_data_handle(DummyData); context_.src_mem->set_data_handle(DummyData);
filter_mem_->set_data_handle(DummyData); context_.filter_mem->set_data_handle(DummyData);
bias_mem_->set_data_handle(DummyData); context_.bias_mem->set_data_handle(DummyData);
dst_mem_->set_data_handle(DummyData); context_.dst_mem->set_data_handle(DummyData);
return; return;
} }
@ -119,139 +128,177 @@ class Conv2DFwd : public DnnOp {
// src_data: input data buffer of src // src_data: input data buffer of src
// filter_data: input data buffer of filter (weights) // filter_data: input data buffer of filter (weights)
// dst_data: output data buffer of dst // dst_data: output data buffer of dst
void Execute(T* src_data, T* filter_data, T* dst_data) { void Execute(const T* src_data, const T* filter_data, const T* dst_data) {
src_mem_->set_data_handle(static_cast<void*>(src_data)); context_.src_mem->set_data_handle(
filter_mem_->set_data_handle(static_cast<void*>(filter_data)); static_cast<void*>(const_cast<T*>(src_data)));
dst_mem_->set_data_handle(static_cast<void*>(dst_data)); context_.filter_mem->set_data_handle(
fwd_stream_->submit(fwd_primitives_); static_cast<void*>(const_cast<T*>(filter_data)));
context_.dst_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(dst_data)));
context_.fwd_stream->submit(context_.fwd_primitives);
// after exec, set data handle back // after execution, set data handle back
src_mem_->set_data_handle(DummyData); context_.src_mem->set_data_handle(DummyData);
filter_mem_->set_data_handle(DummyData); context_.filter_mem->set_data_handle(DummyData);
dst_mem_->set_data_handle(DummyData); context_.dst_mem->set_data_handle(DummyData);
return;
} }
// expected memory format for this primitive instance memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }
memory::format src_fmt_;
memory::format filter_fmt_;
// convolution primitive memory::format GetFilterMemoryFormat() const { return context_.filter_fmt; }
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> fwd_pd_;
std::shared_ptr<mkldnn::primitive> conv_fwd_; std::shared_ptr<mkldnn::convolution_forward::primitive_desc>
GetPrimitiveDesc() const {
return context_.fwd_pd;
}
private: private:
void Setup(const ConvFwdDimensions& convFwdDims) { // Primitive reuse context for Conv2D Fwd op
struct ConvFwdContext {
// expected memory format for this primitive instance
memory::format src_fmt;
memory::format filter_fmt;
// MKLDNN memory
std::shared_ptr<mkldnn::memory> src_mem;
std::shared_ptr<mkldnn::memory> filter_mem;
std::shared_ptr<mkldnn::memory> bias_mem;
std::shared_ptr<mkldnn::memory> dst_mem;
// desc & primitive desc
std::shared_ptr<mkldnn::convolution_forward::desc> fwd_desc;
// memory desc
std::shared_ptr<mkldnn::memory::desc> src_md;
std::shared_ptr<mkldnn::memory::desc> filter_md;
std::shared_ptr<mkldnn::memory::desc> bias_md;
std::shared_ptr<mkldnn::memory::desc> dst_md;
// convolution primitive
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> fwd_pd;
std::shared_ptr<mkldnn::primitive> conv_fwd;
std::shared_ptr<mkldnn::stream> fwd_stream;
std::vector<mkldnn::primitive> fwd_primitives;
ConvFwdContext()
: src_fmt(memory::format::any),
filter_fmt(memory::format::any),
src_mem(nullptr),
filter_mem(nullptr),
bias_mem(nullptr),
dst_mem(nullptr),
fwd_desc(nullptr),
src_md(nullptr),
filter_md(nullptr),
bias_md(nullptr),
fwd_pd(nullptr),
conv_fwd(nullptr),
fwd_stream(nullptr) {}
};
void Setup(const MklConvFwdParams& convFwdDims) {
// create memory descriptors for convolution data w/ no specified format // create memory descriptors for convolution data w/ no specified format
src_md_.reset(new memory::desc({convFwdDims.src_dims}, context_.src_md.reset(new memory::desc(
MklDnnType<T>(), memory::format::any)); {convFwdDims.src_dims}, MklDnnType<T>(), memory::format::any));
filter_md_.reset(new memory::desc({convFwdDims.filter_dims}, context_.filter_md.reset(new memory::desc(
MklDnnType<T>(), memory::format::any)); {convFwdDims.filter_dims}, MklDnnType<T>(), memory::format::any));
dst_md_.reset(new memory::desc({convFwdDims.dst_dims}, context_.dst_md.reset(new memory::desc(
MklDnnType<T>(), memory::format::any)); {convFwdDims.dst_dims}, MklDnnType<T>(), memory::format::any));
if (!convFwdDims.bias_dims.empty()) if (!convFwdDims.bias_dims.empty())
bias_md_.reset(new memory::desc({convFwdDims.bias_dims}, context_.bias_md.reset(new memory::desc(
MklDnnType<T>(), memory::format::any)); {convFwdDims.bias_dims}, MklDnnType<T>(), memory::format::any));
// create a convolution // create a convolution
if (!convFwdDims.bias_dims.empty()) { if (!convFwdDims.bias_dims.empty()) {
fwd_desc_.reset(new convolution_forward::desc(prop_kind::forward, context_.fwd_desc.reset(new convolution_forward::desc(
convolution_direct, *src_md_, *filter_md_, *bias_md_, *dst_md_, prop_kind::forward, convolution_direct, *context_.src_md,
*context_.filter_md, *context_.bias_md, *context_.dst_md,
convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left, convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
convFwdDims.padding_right, padding_kind::zero)); convFwdDims.padding_right, padding_kind::zero));
} else { } else {
fwd_desc_.reset(new convolution_forward::desc(prop_kind::forward, context_.fwd_desc.reset(new convolution_forward::desc(
convolution_direct, *src_md_, *filter_md_, *dst_md_, prop_kind::forward, convolution_direct, *context_.src_md,
convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left, *context_.filter_md, *context_.dst_md, convFwdDims.strides,
convFwdDims.dilations, convFwdDims.padding_left,
convFwdDims.padding_right, padding_kind::zero)); convFwdDims.padding_right, padding_kind::zero));
} }
fwd_pd_.reset(new convolution_forward::primitive_desc( context_.fwd_pd.reset(new convolution_forward::primitive_desc(
*fwd_desc_, cpu_engine_)); *context_.fwd_desc, cpu_engine_));
// store the expected memory format // store the expected memory format
src_fmt_ = static_cast<mkldnn::memory::format>( context_.src_fmt = static_cast<mkldnn::memory::format>(
fwd_pd_.get()->src_primitive_desc().desc().data.format); context_.fwd_pd.get()->src_primitive_desc().desc().data.format);
filter_fmt_ = static_cast<mkldnn::memory::format>( context_.filter_fmt = static_cast<mkldnn::memory::format>(
fwd_pd_.get()->weights_primitive_desc().desc().data.format); context_.fwd_pd.get()->weights_primitive_desc().desc().data.format);
// create memory primitive based on dummy data // create memory primitive based on dummy data
src_mem_.reset(new memory(fwd_pd_.get()->src_primitive_desc(), DummyData)); context_.src_mem.reset(
filter_mem_.reset(new memory(fwd_pd_.get()->weights_primitive_desc(), new memory(context_.fwd_pd.get()->src_primitive_desc(), DummyData));
DummyData)); context_.filter_mem.reset(
dst_mem_.reset(new memory(fwd_pd_.get()->dst_primitive_desc(), DummyData)); new memory(context_.fwd_pd.get()->weights_primitive_desc(), DummyData));
context_.dst_mem.reset(
new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));
// create convolution primitive and add it to net // create convolution primitive and add it to net
if (!convFwdDims.bias_dims.empty()) { if (!convFwdDims.bias_dims.empty()) {
bias_mem_.reset(new memory({{{convFwdDims.bias_dims}, MklDnnType<T>(), context_.bias_mem.reset(new memory(
memory::format::x}, cpu_engine_}, DummyData)); {{{convFwdDims.bias_dims}, MklDnnType<T>(), memory::format::x},
conv_fwd_.reset(new convolution_forward(*fwd_pd_, *src_mem_, cpu_engine_},
*filter_mem_, *bias_mem_, *dst_mem_)); DummyData));
context_.conv_fwd.reset(new convolution_forward(
*context_.fwd_pd, *context_.src_mem, *context_.filter_mem,
*context_.bias_mem, *context_.dst_mem));
} else { } else {
conv_fwd_.reset(new convolution_forward(*fwd_pd_, *src_mem_, context_.conv_fwd.reset(
*filter_mem_, *dst_mem_)); new convolution_forward(*context_.fwd_pd, *context_.src_mem,
*context_.filter_mem, *context_.dst_mem));
} }
fwd_primitives_.push_back(*conv_fwd_); context_.fwd_primitives.push_back(*context_.conv_fwd);
return; return;
} }
// MKLDNN memory struct ConvFwdContext context_;
std::shared_ptr<mkldnn::memory> src_mem_; engine cpu_engine_;
std::shared_ptr<mkldnn::memory> filter_mem_;
std::shared_ptr<mkldnn::memory> bias_mem_;
std::shared_ptr<mkldnn::memory> dst_mem_;
std::shared_ptr<mkldnn::stream> fwd_stream_;
std::vector<mkldnn::primitive> fwd_primitives_;
// desc & prmitive desc
std::shared_ptr<mkldnn::convolution_forward::desc> fwd_desc_;
// memory desc
std::shared_ptr<mkldnn::memory::desc> src_md_;
std::shared_ptr<mkldnn::memory::desc> filter_md_;
std::shared_ptr<mkldnn::memory::desc> bias_md_;
std::shared_ptr<mkldnn::memory::desc> dst_md_;
engine cpu_engine_ = engine(engine::cpu, 0);
}; };
template <typename T> template <typename T>
class Conv2DFwdFactory : public DnnOpFactory<T> { class MklConv2DFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
public: public:
static Conv2DFwd<T>* Get(const ConvFwdDimensions& convFwdDims) { static MklConv2DFwdPrimitive<T>* Get(const MklConvFwdParams& convFwdDims) {
Conv2DFwd<T>* conv2d_fwd = nullptr; MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr;
// try to find a suitable one in pool // try to find a suitable one in pool
conv2d_fwd = dynamic_cast<Conv2DFwd<T>*> ( conv2d_fwd = dynamic_cast<MklConv2DFwdPrimitive<T>*>(
Conv2DFwdFactory<T>::GetInstance().GetConv2DFwd(convFwdDims)); MklConv2DFwdPrimitiveFactory<T>::GetInstance().GetConv2DFwd(
convFwdDims));
if (conv2d_fwd == nullptr) { if (conv2d_fwd == nullptr) {
conv2d_fwd = new Conv2DFwd<T>(convFwdDims); conv2d_fwd = new MklConv2DFwdPrimitive<T>(convFwdDims);
Conv2DFwdFactory<T>::GetInstance().SetConv2DFwd( MklConv2DFwdPrimitiveFactory<T>::GetInstance().SetConv2DFwd(convFwdDims,
convFwdDims, conv2d_fwd); conv2d_fwd);
} }
return conv2d_fwd; return conv2d_fwd;
} }
private: private:
Conv2DFwdFactory() {} MklConv2DFwdPrimitiveFactory() {}
~Conv2DFwdFactory() {} ~MklConv2DFwdPrimitiveFactory() {}
static const int kDilationH = 0, kDilationW = 1; static const int kDilationH = 0, kDilationW = 1;
static Conv2DFwdFactory& GetInstance() { static MklConv2DFwdPrimitiveFactory& GetInstance() {
static Conv2DFwdFactory instance_; static MklConv2DFwdPrimitiveFactory instance_;
return instance_; return instance_;
} }
static std::string CreateKey(const ConvFwdDimensions& convFwdDims) { static std::string CreateKey(const MklConvFwdParams& convFwdDims) {
std::string prefix = "conv2d_fwd_"; std::string prefix = "conv2d_fwd_";
FactoryKeyCreator key_creator; FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix); key_creator.AddAsKey(prefix);
@ -266,12 +313,12 @@ class Conv2DFwdFactory : public DnnOpFactory<T> {
return key_creator.GetKey(); return key_creator.GetKey();
} }
DnnOp* GetConv2DFwd(const ConvFwdDimensions& convFwdDims) { MklPrimitive* GetConv2DFwd(const MklConvFwdParams& convFwdDims) {
std::string key = CreateKey(convFwdDims); std::string key = CreateKey(convFwdDims);
return this->GetOp(key); return this->GetOp(key);
} }
void SetConv2DFwd(const ConvFwdDimensions& convFwdDims, DnnOp *op) { void SetConv2DFwd(const MklConvFwdParams& convFwdDims, MklPrimitive* op) {
std::string key = CreateKey(convFwdDims); std::string key = CreateKey(convFwdDims);
this->SetOp(key, op); this->SetOp(key, op);
} }
@ -762,7 +809,6 @@ class MklConv2DOp : public OpKernel {
MklDnnData<T> src(&cpu_engine); MklDnnData<T> src(&cpu_engine);
MklDnnData<T> filter(&cpu_engine); MklDnnData<T> filter(&cpu_engine);
MklDnnData<T> dst(&cpu_engine); // output
memory::dims src_dims, filter_dims, padding_left, padding_right, memory::dims src_dims, filter_dims, padding_left, padding_right,
dilations, strides; dilations, strides;
@ -812,7 +858,6 @@ class MklConv2DOp : public OpKernel {
auto src_md = src_mkl_shape.IsMklTensor() auto src_md = src_mkl_shape.IsMklTensor()
? src_mkl_shape.GetMklLayout() ? src_mkl_shape.GetMklLayout()
: memory::desc(src_dims, MklDnnType<T>(), tf_fmt); : memory::desc(src_dims, MklDnnType<T>(), tf_fmt);
src.SetUsrMem(src_md, &src_tensor);
// Although filter shape (filter_dims) required is in MKL-DNN order, // Although filter shape (filter_dims) required is in MKL-DNN order,
// the layout is Tensorflow's layout (HWIO). // the layout is Tensorflow's layout (HWIO).
@ -820,29 +865,30 @@ class MklConv2DOp : public OpKernel {
? filter_mkl_shape.GetMklLayout() ? filter_mkl_shape.GetMklLayout()
: memory::desc(filter_dims, MklDnnType<T>(), : memory::desc(filter_dims, MklDnnType<T>(),
memory::format::hwio); memory::format::hwio);
filter.SetUsrMem(filter_md, &filter_tensor);
// MKLDNN dilation starts from 0. // MKLDNN dilation starts from 0.
dilations[kDilationH] -= 1; dilations[kDilationH] -= 1;
dilations[kDilationW] -= 1; dilations[kDilationW] -= 1;
// get a conv2d fwd from primitive pool // get a conv2d fwd from primitive pool
Conv2DFwd<T> *conv2d_fwd = nullptr; MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr;
if (biasEnabled) { if (biasEnabled) {
memory::dims bias_dims = {}; memory::dims bias_dims = {};
conv_utl.GetBiasSizeInMklOrder(kInputIndex_Bias, &bias_dims); conv_utl.GetBiasSizeInMklOrder(kInputIndex_Bias, &bias_dims);
ConvFwdDimensions convFwdDims(src_dims, filter_dims, bias_dims, MklConvFwdParams convFwdDims(src_dims, filter_dims, bias_dims,
dst_dims_mkl_order, strides, dilations, padding_left, padding_right); dst_dims_mkl_order, strides, dilations,
conv2d_fwd = Conv2DFwdFactory<T>::Get(convFwdDims); padding_left, padding_right);
conv2d_fwd = MklConv2DFwdPrimitiveFactory<T>::Get(convFwdDims);
} else { } else {
ConvFwdDimensions convFwdDims(src_dims, filter_dims, NONE_DIMS, MklConvFwdParams convFwdDims(src_dims, filter_dims, NONE_DIMS,
dst_dims_mkl_order, strides, dilations, padding_left, padding_right); dst_dims_mkl_order, strides, dilations,
conv2d_fwd = Conv2DFwdFactory<T>::Get(convFwdDims); padding_left, padding_right);
conv2d_fwd = MklConv2DFwdPrimitiveFactory<T>::Get(convFwdDims);
} }
// allocate output tensors output_tensor and filter_out_tensor // allocate output tensors output_tensor and filter_out_tensor
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_fwd_pd =
conv_fwd_pd = conv2d_fwd->fwd_pd_; conv2d_fwd->GetPrimitiveDesc();
AllocateOutputTensor(context, *conv_fwd_pd, AllocateOutputTensor(context, *conv_fwd_pd,
dst_dims_mkl_order, tf_fmt, &dst_tensor); dst_dims_mkl_order, tf_fmt, &dst_tensor);
Tensor* filter_out_tensor = nullptr; Tensor* filter_out_tensor = nullptr;
@ -854,20 +900,28 @@ class MklConv2DOp : public OpKernel {
// check whether src/filter need reorder // check whether src/filter need reorder
std::vector<primitive> net; std::vector<primitive> net;
if (src_md.data.format != conv2d_fwd->src_fmt_) T* src_data = nullptr;
src.CheckReorderToOpMem( if (src_md.data.format != conv2d_fwd->GetSrcMemoryFormat()) {
conv_fwd_pd.get()->src_primitive_desc(), &net); src.SetUsrMem(src_md, &src_tensor);
src.CheckReorderToOpMem(conv_fwd_pd.get()->src_primitive_desc(), &net);
src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
} else {
src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
}
T* filter_data = nullptr;
if (filter_md.data.format != conv2d_fwd->GetFilterMemoryFormat()) {
filter.SetUsrMem(filter_md, &filter_tensor);
filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_primitive_desc(),
filter.GetTensorBuffer(filter_out_tensor),
&net);
filter_data = static_cast<T*>(filter.GetOpMem().get_data_handle());
} else {
filter_data =
static_cast<T*>(const_cast<T*>(filter_tensor.flat<T>().data()));
}
if (filter_md.data.format != conv2d_fwd->filter_fmt_)
filter.CheckReorderToOpMem(
conv_fwd_pd.get()->weights_primitive_desc(),
filter.GetTensorBuffer(filter_out_tensor), &net);
stream(stream::kind::eager).submit(net).wait(); stream(stream::kind::eager).submit(net).wait();
T* src_data = static_cast<T*>(
src.GetOpMem().get_data_handle());
T* filter_data = static_cast<T*>(
filter.GetOpMem().get_data_handle());
// execute convolution // execute convolution
if (biasEnabled) { if (biasEnabled) {

View File

@ -295,7 +295,7 @@ __global__ void ColumnReduceMax16ColumnsKernel(
// 1D array necessary due to bug in CUDA 9 compiler. // 1D array necessary due to bug in CUDA 9 compiler.
// TODO(nluehr) revert to 2D array when compiler is ready. // TODO(nluehr) revert to 2D array when compiler is ready.
// This is the mimic the following, but without any constructors: // This is to mimic the following, but without any constructors:
// __shared__ storage_type<value_type> partial_sums[32 * 33]; // __shared__ storage_type<value_type> partial_sums[32 * 33];
__shared__ __align__( __shared__ __align__(
alignof(value_type)) char partial_sums_raw[32 * 33 * sizeof(value_type)]; alignof(value_type)) char partial_sums_raw[32 * 33 * sizeof(value_type)];

View File

@ -16,6 +16,12 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ #ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
#define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ #define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
// This file requires the following include because it uses CudaAtomicMax:
// #include "tensorflow/core/util/cuda_kernel_helper.h"
// Unfortunately we can't add the #include, since it breaks compilation for
// non-GPU targets. This only breaks in clang, because it's more strict for
// template code and CudaAtomicMax is used in template context.
// This file requires the following include because it uses CudaAtomicMax: // This file requires the following include because it uses CudaAtomicMax:
// #include "tensorflow/core/util/cuda_kernel_helper.h" // #include "tensorflow/core/util/cuda_kernel_helper.h"

View File

@ -614,7 +614,13 @@ REGISTER_OP("ApproximateEqual")
.SetIsCommutative() .SetIsCommutative()
.Attr("T: numbertype") .Attr("T: numbertype")
.Attr("tolerance: float = 0.00001") .Attr("tolerance: float = 0.00001")
.SetShapeFn(shape_inference::UnchangedShape); .SetShapeFn([](InferenceContext* c) {
// The inputs 'x' and 'y' must have the same shape.
ShapeHandle data_x = c->input(0);
ShapeHandle data_y = c->input(1);
TF_RETURN_IF_ERROR(c->Merge(data_x, data_y, &data_x));
return shape_inference::UnchangedShape(c);
});
// -------------------------------------------------------------------------- // --------------------------------------------------------------------------
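The new shape function requires the static shapes of `x` and `y` to merge before delegating to `UnchangedShape`. A simplified, self-contained illustration of the compatibility rule that `InferenceContext::Merge` applies (equal rank, and every pair of known dimensions equal); this is a sketch of the rule, not TensorFlow's implementation:

```cpp
// Simplified stand-in for shape merging: -1 marks an unknown dimension.
// Illustrative only; TensorFlow's InferenceContext::Merge is richer.
#include <iostream>
#include <vector>

constexpr long long kUnknownDim = -1;

bool ShapesMerge(const std::vector<long long>& x,
                 const std::vector<long long>& y) {
  if (x.size() != y.size()) return false;  // rank mismatch
  for (size_t i = 0; i < x.size(); ++i) {
    // Two known dimensions must be equal; an unknown dimension merges with
    // anything.
    if (x[i] != kUnknownDim && y[i] != kUnknownDim && x[i] != y[i]) {
      return false;
    }
  }
  return true;
}

int main() {
  std::cout << ShapesMerge({2, 3}, {2, 3}) << "\n";            // 1: same shape
  std::cout << ShapesMerge({2, kUnknownDim}, {2, 4}) << "\n";  // 1: unknown dim merges
  std::cout << ShapesMerge({2, 3}, {2, 4}) << "\n";            // 0: now rejected
  return 0;
}
```

With this check, a `[2, 3]` input paired with a `[2, 4]` input fails `ApproximateEqual` at graph-construction time instead of leaving the mismatch to be caught at run time.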

View File

@ -137,8 +137,8 @@ Status EncodeJwtClaim(StringPiece client_email, StringPiece scope,
const auto expiration_timestamp_sec = const auto expiration_timestamp_sec =
request_timestamp_sec + kRequestedTokenLifetimeSec; request_timestamp_sec + kRequestedTokenLifetimeSec;
root["iat"] = request_timestamp_sec; root["iat"] = Json::Value::UInt64(request_timestamp_sec);
root["exp"] = expiration_timestamp_sec; root["exp"] = Json::Value::UInt64(expiration_timestamp_sec);
// Step 2: represent the JSON as a string. // Step 2: represent the JSON as a string.
string claim = root.toStyledString(); string claim = root.toStyledString();

View File

@ -202,7 +202,10 @@ def cc_proto_library(
) )
if use_grpc_plugin: if use_grpc_plugin:
cc_libs += ["//external:grpc_lib"] cc_libs += select({
"//tensorflow:linux_s390x": ["//external:grpc_lib_unsecure"],
"//conditions:default": ["//external:grpc_lib"],
})
if default_header: if default_header:
header_only_name = name header_only_name = name

View File

@ -171,5 +171,10 @@ int64 AvailableRam() {
return INT64_MAX; return INT64_MAX;
} }
int NumHyperthreadsPerCore() {
static const int ht_per_core = tensorflow::port::CPUIDNumSMT();
return (ht_per_core > 0) ? ht_per_core : 1;
}
} // namespace port } // namespace port
} // namespace tensorflow } // namespace tensorflow

View File

@ -47,9 +47,9 @@ Json::Value ChromeTraceFormatter::CreateEvent(const string& ph,
event["ph"] = Json::Value(ph); event["ph"] = Json::Value(ph);
event["cat"] = Json::Value(category); event["cat"] = Json::Value(category);
event["name"] = Json::Value(name); event["name"] = Json::Value(name);
event["pid"] = Json::Value(pid); event["pid"] = Json::Int64(pid);
event["tid"] = Json::Value(tid); event["tid"] = Json::Int64(tid);
event["ts"] = Json::Value(ts); event["ts"] = Json::Int64(ts);
return event; return event;
} }
@ -57,7 +57,7 @@ void ChromeTraceFormatter::EmitPID(const string& name, int64 pid) {
Json::Value event(Json::objectValue); Json::Value event(Json::objectValue);
event["name"] = Json::Value("process_name"); event["name"] = Json::Value("process_name");
event["ph"] = Json::Value("M"); event["ph"] = Json::Value("M");
event["pid"] = Json::Value(pid); event["pid"] = Json::Int64(pid);
Json::Value args(Json::objectValue); Json::Value args(Json::objectValue);
args["name"] = Json::Value(name); args["name"] = Json::Value(name);
event["args"] = args; event["args"] = args;
@ -68,7 +68,7 @@ void ChromeTraceFormatter::EmitRegion(int64 ts, int64 duration, int64 pid,
int64 tid, const string& category, int64 tid, const string& category,
const string& name, Json::Value args) { const string& name, Json::Value args) {
Json::Value event = CreateEvent("X", category, name, pid, tid, ts); Json::Value event = CreateEvent("X", category, name, pid, tid, ts);
event["dur"] = Json::Value(duration); event["dur"] = Json::Int64(duration);
event["args"] = std::move(args); event["args"] = std::move(args);
metadata_.push_back(event); metadata_.push_back(event);
} }
@ -76,14 +76,14 @@ void ChromeTraceFormatter::EmitRegion(int64 ts, int64 duration, int64 pid,
void ChromeTraceFormatter::EmitFlowStart(const string& name, int64 ts, void ChromeTraceFormatter::EmitFlowStart(const string& name, int64 ts,
int64 pid, int64 tid, int64 flow_id) { int64 pid, int64 tid, int64 flow_id) {
Json::Value event = CreateEvent("s", "DataFlow", name, pid, tid, ts); Json::Value event = CreateEvent("s", "DataFlow", name, pid, tid, ts);
event["id"] = flow_id; event["id"] = Json::Int64(flow_id);
events_.push_back(event); events_.push_back(event);
} }
void ChromeTraceFormatter::EmitFlowEnd(const string& name, int64 ts, int64 pid, void ChromeTraceFormatter::EmitFlowEnd(const string& name, int64 ts, int64 pid,
int64 tid, int64 flow_id) { int64 tid, int64 flow_id) {
Json::Value event = CreateEvent("t", "DataFlow", name, pid, tid, ts); Json::Value event = CreateEvent("t", "DataFlow", name, pid, tid, ts);
event["id"] = flow_id; event["id"] = Json::Int64(flow_id);
events_.push_back(event); events_.push_back(event);
} }
@ -93,7 +93,7 @@ void ChromeTraceFormatter::EmitCounter(
const std::map<int64, std::vector<string>>& tensor_mem) { const std::map<int64, std::vector<string>>& tensor_mem) {
Json::Value event = CreateEvent("C", category, "Allocated Bytes", pid, 0, ts); Json::Value event = CreateEvent("C", category, "Allocated Bytes", pid, 0, ts);
Json::Value args(Json::objectValue); Json::Value args(Json::objectValue);
args["Allocator Bytes in Use"] = Json::Value(bytes); args["Allocator Bytes in Use"] = Json::Int64(bytes);
event["args"] = args; event["args"] = args;
events_.push_back(event); events_.push_back(event);
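The hunks above (like the `EncodeJwtClaim` change earlier) wrap 64-bit values in jsoncpp's explicit `Json::Int64` / `Json::UInt64` typedefs instead of relying on implicit `Json::Value` construction, presumably because `tensorflow::int64` (a `long long`) need not be the same type as jsoncpp's `Int64`, which can make the plain constructor call ambiguous or lossy. A minimal sketch of the idiom, assuming the open-source jsoncpp library (the values are made up):

```cpp
// Minimal jsoncpp usage sketch; the pid/iat values are hypothetical.
// Wrapping in Json::Int64 / Json::UInt64 selects the 64-bit Value
// constructors explicitly, so large ids and timestamps are neither
// truncated nor rejected by ambiguous overload resolution.
#include <iostream>
#include <json/json.h>

int main() {
  const long long pid = (1LL << 40);          // does not fit in 32 bits
  const unsigned long long iat = 1530000000;  // Unix timestamp in seconds

  Json::Value event(Json::objectValue);
  event["name"] = Json::Value("process_name");
  event["pid"] = Json::Value(Json::Int64(pid));
  event["iat"] = Json::Value(Json::UInt64(iat));
  std::cout << event.toStyledString();
  return 0;
}
```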

View File

@ -1814,11 +1814,11 @@ class MklDnnData {
} }
}; };
/// Base class for operations with reuse of DNN primitives /// Base class for operations with reuse of primitives
/// ///
class DnnOp { class MklPrimitive {
public: public:
virtual ~DnnOp() {} virtual ~MklPrimitive() {}
// Dummy data. Its size, hard-coded as 256 here, does // Dummy data. Its size, hard-coded as 256 here, does
// not matter since MKL should never operate on this buffer. // not matter since MKL should never operate on this buffer.
@ -1826,33 +1826,33 @@ class DnnOp {
}; };
const mkldnn::memory::dims NONE_DIMS = {}; const mkldnn::memory::dims NONE_DIMS = {};
// This constant is used to declare dummy buffer (size), for MKL primitives
template <typename T>
class DnnOpFactory {
public:
DnnOpFactory() {}
~DnnOpFactory() {}
DnnOp* GetOp(const std::string& key) { template <typename T>
auto stream_iter = DnnOpFactory<T>::GetHashMap().find(key); class MklPrimitiveFactory {
if (stream_iter == DnnOpFactory<T>::GetHashMap().end()) { public:
MklPrimitiveFactory() {}
~MklPrimitiveFactory() {}
MklPrimitive* GetOp(const std::string& key) {
auto stream_iter = MklPrimitiveFactory<T>::GetHashMap().find(key);
if (stream_iter == MklPrimitiveFactory<T>::GetHashMap().end()) {
return nullptr; return nullptr;
} else { } else {
return stream_iter->second; return stream_iter->second;
} }
} }
void SetOp(const std::string& key, DnnOp* op) { void SetOp(const std::string& key, MklPrimitive* op) {
auto stream_iter = DnnOpFactory<T>::GetHashMap().find(key); auto stream_iter = MklPrimitiveFactory<T>::GetHashMap().find(key);
CHECK(stream_iter == DnnOpFactory<T>::GetHashMap().end()); CHECK(stream_iter == MklPrimitiveFactory<T>::GetHashMap().end());
DnnOpFactory<T>::GetHashMap()[key] = op; MklPrimitiveFactory<T>::GetHashMap()[key] = op;
} }
private: private:
static inline std::unordered_map<std::string, DnnOp*> &GetHashMap() { static inline std::unordered_map<std::string, MklPrimitive*>& GetHashMap() {
static thread_local std::unordered_map<std::string, DnnOp*> map_; static thread_local std::unordered_map<std::string, MklPrimitive*> map_;
return map_; return map_;
} }
}; };
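The renamed `MklPrimitive` / `MklPrimitiveFactory` pair above implements a per-thread cache of constructed primitives keyed by a string built from the op parameters, so the expensive setup done in, e.g., `MklConv2DFwdPrimitive::Setup` runs once per distinct parameter combination. A stripped-down, self-contained sketch of that reuse pattern follows; the class names and the `filter_size` parameter are illustrative, and there is no MKL-DNN dependency.

```cpp
// Illustrative sketch of the keyed primitive-reuse pattern (hypothetical
// ToyConvPrimitive/PrimitiveFactory names; not the TensorFlow classes).
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>

class Primitive {
 public:
  virtual ~Primitive() {}
  virtual void Execute() = 0;
};

class ToyConvPrimitive : public Primitive {
 public:
  explicit ToyConvPrimitive(int filter_size) : filter_size_(filter_size) {
    std::cout << "expensive setup for filter_size=" << filter_size_ << "\n";
  }
  void Execute() override {
    std::cout << "execute with filter_size=" << filter_size_ << "\n";
  }

 private:
  int filter_size_;
};

class PrimitiveFactory {
 public:
  // Returns a cached primitive for the key, creating it on first use.
  static ToyConvPrimitive* Get(int filter_size) {
    std::ostringstream key;
    key << "conv_fwd_" << filter_size;
    auto& map = GetHashMap();
    auto it = map.find(key.str());
    if (it == map.end()) {
      it = map.emplace(key.str(), std::unique_ptr<Primitive>(
                                      new ToyConvPrimitive(filter_size)))
               .first;
    }
    return static_cast<ToyConvPrimitive*>(it->second.get());
  }

 private:
  // One cache per thread, mirroring the thread_local map above.
  static std::unordered_map<std::string, std::unique_ptr<Primitive>>&
  GetHashMap() {
    static thread_local std::unordered_map<std::string,
                                           std::unique_ptr<Primitive>>
        map;
    return map;
  }
};

int main() {
  PrimitiveFactory::Get(3)->Execute();  // constructs, then executes
  PrimitiveFactory::Get(3)->Execute();  // cache hit: no second setup
  PrimitiveFactory::Get(5)->Execute();  // different key: new primitive
  return 0;
}
```

The second `Get(3)` call returns the object created by the first, which mirrors how `MklConv2DFwdPrimitiveFactory<T>::Get` hands a previously built convolution primitive back to the kernel.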

View File

@ -0,0 +1,29 @@
# Get Started
If you are new to machine learning, we recommend taking the following online
course prior to diving into TensorFlow documentation:
* [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/),
which introduces machine learning concepts and encourages experimentation
with existing TensorFlow code.
TensorFlow is a tool for machine learning. While it contains a wide range of
functionality, TensorFlow is mainly designed for deep neural network models.
The easiest way to get started with TensorFlow is by using Eager Execution.
* @{$get_started/eager} is for anyone new to machine learning or TensorFlow.
TensorFlow provides many APIs. The remainder of this section focuses on the
Estimator API, which provides scalable, high-performance models. See the
@{$estimators} guide.
For more advanced users:
* The @{$low_level_intro$Low Level Introduction} demonstrates how to use
TensorFlow outside of the Estimator framework, for debugging and
experimentation.
* The @{$guide$Programmer's Guide} details major
TensorFlow components.
* The @{$tutorials$Tutorials} provide walkthroughs of a variety of
TensorFlow models.

View File

@ -17,7 +17,7 @@ how to use the graphical user interface (GUI) of tfdbg, i.e., the
Note: The TensorFlow debugger uses a Note: The TensorFlow debugger uses a
[curses](https://en.wikipedia.org/wiki/Curses_\(programming_library\))-based text [curses](https://en.wikipedia.org/wiki/Curses_\(programming_library\))-based text
user interface. On Mac OS X, the `ncurses` library is required and can be user interface. On Mac OS X, the `ncurses` library is required and can be
installed with `brew install homebrew/dupes/ncurses`. On Windows, curses isn't as installed with `brew install ncurses`. On Windows, curses isn't as
well supported, so a [readline](https://en.wikipedia.org/wiki/GNU_Readline)-based well supported, so a [readline](https://en.wikipedia.org/wiki/GNU_Readline)-based
interface can be used with tfdbg by installing `pyreadline` with `pip`. If you interface can be used with tfdbg by installing `pyreadline` with `pip`. If you
use Anaconda3, you can install it with a command such as use Anaconda3, you can install it with a command such as

245
tensorflow/go/attrs.go Normal file
View File

@ -0,0 +1,245 @@
/*
Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package tensorflow
// #include <stdlib.h>
// #include "tensorflow/c/c_api.h"
import "C"
import (
"fmt"
"unsafe"
)
// makeCShape converts a shape specified in C.int64_t into a Shape.
func makeCShape(shape []C.int64_t) Shape {
s := Shape{dims: make([]int64, len(shape))}
for i, n := range shape {
s.dims[i] = int64(n)
}
return s
}
// Attr returns the value of an attribute on op. It returns an error if the
// attribute does not exist.
func (op *Operation) Attr(name string) (interface{}, error) {
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
status := newStatus()
meta := C.TF_OperationGetAttrMetadata(op.c, cname, status.c)
if err := status.Err(); err != nil {
return nil, err
}
if meta.is_list == 1 {
return listAttribute(op, cname, meta)
}
return scalarAttribute(op, cname, meta)
}
func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interface{}, error) {
status := newStatus()
switch meta._type {
case C.TF_ATTR_STRING:
if meta.list_size == 0 {
return []string(nil), nil
}
values := make([]unsafe.Pointer, meta.list_size)
lengths := make([]C.size_t, meta.list_size)
// Add one element in case total_size is zero.
storage := make([]C.char, meta.total_size+1)
C.TF_OperationGetAttrStringList(op.c, cname, &values[0], &lengths[0], C.int(meta.list_size), unsafe.Pointer(&storage[0]), C.size_t(meta.total_size), status.c)
if err := status.Err(); err != nil {
return nil, err
}
list := make([]string, meta.list_size)
for i, val := range values {
length := lengths[i]
list[i] = C.GoStringN((*C.char)(val), C.int(length))
}
return list, nil
case C.TF_ATTR_INT:
if meta.list_size == 0 {
return []int64(nil), nil
}
list := make([]C.int64_t, meta.list_size)
C.TF_OperationGetAttrIntList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
if err := status.Err(); err != nil {
return nil, err
}
vals := make([]int64, meta.list_size)
for i, val := range list {
vals[i] = int64(val)
}
return vals, nil
case C.TF_ATTR_FLOAT:
if meta.list_size == 0 {
return []float32(nil), nil
}
list := make([]C.float, meta.list_size)
C.TF_OperationGetAttrFloatList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
if err := status.Err(); err != nil {
return nil, err
}
vals := make([]float32, meta.list_size)
for i, val := range list {
vals[i] = float32(val)
}
return vals, nil
case C.TF_ATTR_BOOL:
if meta.list_size == 0 {
return []bool(nil), nil
}
list := make([]C.uchar, meta.list_size)
C.TF_OperationGetAttrBoolList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
if err := status.Err(); err != nil {
return nil, err
}
vals := make([]bool, meta.list_size)
for i, val := range list {
vals[i] = val == 1
}
return vals, nil
case C.TF_ATTR_TYPE:
if meta.list_size == 0 {
return []DataType(nil), nil
}
list := make([]C.TF_DataType, meta.list_size)
C.TF_OperationGetAttrTypeList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
if err := status.Err(); err != nil {
return nil, err
}
vals := make([]DataType, meta.list_size)
for i, val := range list {
vals[i] = DataType(val)
}
return vals, nil
case C.TF_ATTR_TENSOR:
if meta.list_size == 0 {
return []*Tensor(nil), nil
}
list := make([]*C.TF_Tensor, meta.list_size)
C.TF_OperationGetAttrTensorList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
if err := status.Err(); err != nil {
return nil, err
}
vals := make([]*Tensor, meta.list_size)
for i, t := range list {
vals[i] = newTensorFromC(t)
}
return vals, nil
case C.TF_ATTR_SHAPE:
if meta.list_size == 0 {
return []Shape(nil), nil
}
dims := make([]*C.int64_t, meta.list_size)
numDims := make([]C.int, meta.list_size)
// Add one element in case total_size is zero.
storage := make([]C.int64_t, meta.total_size+1)
C.TF_OperationGetAttrShapeList(op.c, cname, &dims[0], &numDims[0], C.int(meta.list_size), &storage[0], C.int(meta.total_size), status.c)
if err := status.Err(); err != nil {
return nil, err
}
list := make([]Shape, meta.list_size)
for i, dim := range dims {
numDim := numDims[i]
// If the number of dimensions is unknown, default to empty shape.
if numDim < 0 {
continue
}
// A []C.int64_t slice backed by C memory.
// See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
slice := (*[1 << 30]C.int64_t)(unsafe.Pointer(dim))[:numDim:numDim]
list[i] = makeCShape(slice)
}
return list, nil
default:
return nil, fmt.Errorf("list type %v not supported", meta._type)
}
}
func scalarAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interface{}, error) {
status := newStatus()
switch meta._type {
case C.TF_ATTR_STRING:
if meta.total_size == 0 {
return "", nil
}
v := make([]C.char, meta.total_size)
C.TF_OperationGetAttrString(op.c, cname, unsafe.Pointer(&v[0]), C.size_t(meta.total_size), status.c)
if err := status.Err(); err != nil {
return nil, err
}
return C.GoStringN(&v[0], C.int(meta.total_size)), nil
case C.TF_ATTR_INT:
var v C.int64_t
C.TF_OperationGetAttrInt(op.c, cname, &v, status.c)
return int64(v), status.Err()
case C.TF_ATTR_FLOAT:
var v C.float
C.TF_OperationGetAttrFloat(op.c, cname, &v, status.c)
return float32(v), status.Err()
case C.TF_ATTR_BOOL:
var v C.uchar
C.TF_OperationGetAttrBool(op.c, cname, &v, status.c)
return v == 1, status.Err()
case C.TF_ATTR_TYPE:
var v C.TF_DataType
C.TF_OperationGetAttrType(op.c, cname, &v, status.c)
return DataType(v), status.Err()
case C.TF_ATTR_TENSOR:
var v *C.TF_Tensor
C.TF_OperationGetAttrTensor(op.c, cname, &v, status.c)
if err := status.Err(); err != nil {
return nil, err
}
return newTensorFromC(v), nil
case C.TF_ATTR_SHAPE:
numDims := meta.total_size
// If the number of dims is unknown, return an empty shape to indicate that.
if numDims < 0 {
return Shape{}, nil
}
if numDims == 0 {
return ScalarShape(), nil
}
dims := make([]C.int64_t, numDims)
C.TF_OperationGetAttrShape(op.c, cname, (*C.int64_t)(unsafe.Pointer(&dims[0])), C.int(numDims), status.c)
if err := status.Err(); err != nil {
return nil, err
}
return makeCShape(dims), nil
default:
return nil, fmt.Errorf("type %v not supported", meta._type)
}
}

193
tensorflow/go/attrs_test.go Normal file
View File

@ -0,0 +1,193 @@
/*
Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package tensorflow
import (
"fmt"
"reflect"
"testing"
)
func TestOperationAttrs(t *testing.T) {
g := NewGraph()
i := 0
makeConst := func(v interface{}) Output {
op, err := Const(g, fmt.Sprintf("const/%d/%+v", i, v), v)
i++
if err != nil {
t.Fatal(err)
}
return op
}
makeTensor := func(v interface{}) *Tensor {
tensor, err := NewTensor(v)
if err != nil {
t.Fatal(err)
}
return tensor
}
cases := []OpSpec{
{
Name: "type",
Type: "Placeholder",
Attrs: map[string]interface{}{
"dtype": Float,
},
},
{
Name: "list(float)",
Type: "Bucketize",
Input: []Input{
makeConst([]float32{1, 2, 3, 4}),
},
Attrs: map[string]interface{}{
"boundaries": []float32{0, 1, 2, 3, 4, 5},
},
},
{
Name: "list(float) empty",
Type: "Bucketize",
Input: []Input{
makeConst([]float32{}),
},
Attrs: map[string]interface{}{
"boundaries": []float32(nil),
},
},
/* TODO(ashankar): debug this issue and add it back later.
{
Name: "list(type),list(shape)",
Type: "InfeedEnqueueTuple",
Input: []Input{
OutputList([]Output{
makeConst(float32(1)),
makeConst([][]int32{{2}}),
}),
},
Attrs: map[string]interface{}{
"dtypes": []DataType{Float, Int32},
"shapes": []Shape{ScalarShape(), MakeShape(1, 1)},
},
},
{
Name: "list(type),list(shape) empty",
Type: "InfeedEnqueueTuple",
Input: []Input{
OutputList([]Output{
makeConst([][]int32{{2}}),
}),
},
Attrs: map[string]interface{}{
"dtypes": []DataType{Int32},
"shapes": []Shape(nil),
},
},
{
Name: "list(type) empty,string empty,int",
Type: "_XlaSendFromHost",
Input: []Input{
OutputList([]Output{}),
makeConst(""),
},
Attrs: map[string]interface{}{
"Tinputs": []DataType(nil),
"key": "",
"device_ordinal": int64(0),
},
},
*/
{
Name: "list(int),int",
Type: "StringToHashBucketStrong",
Input: []Input{
makeConst(""),
},
Attrs: map[string]interface{}{
"num_buckets": int64(2),
"key": []int64{1, 2},
},
},
{
Name: "list(int) empty,int",
Type: "StringToHashBucketStrong",
Input: []Input{
makeConst(""),
},
Attrs: map[string]interface{}{
"num_buckets": int64(2),
"key": ([]int64)(nil),
},
},
{
Name: "list(string),type",
Type: "TensorSummary",
Input: []Input{
makeConst(""),
},
Attrs: map[string]interface{}{
"T": String,
"labels": []string{"foo", "bar"},
},
},
{
Name: "list(string) empty,type",
Type: "TensorSummary",
Input: []Input{
makeConst(""),
},
Attrs: map[string]interface{}{
"T": String,
"labels": ([]string)(nil),
},
},
{
Name: "tensor",
Type: "Const",
Attrs: map[string]interface{}{
"dtype": String,
"value": makeTensor("foo"),
},
},
}
for i, spec := range cases {
op, err := g.AddOperation(spec)
if err != nil {
t.Fatal(err)
}
for key, want := range spec.Attrs {
out, err := op.Attr(key)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(out, want) {
t.Fatalf("%d. %q: Got %#v, wanted %#v", i, key, out, want)
}
wantT, ok := want.(*Tensor)
if ok {
wantVal := wantT.Value()
outVal := out.(*Tensor).Value()
if !reflect.DeepEqual(outVal, wantVal) {
t.Fatalf("%d. %q: Got %#v, wanted %#v", i, key, outVal, wantVal)
}
}
}
}
}


@ -11210,7 +11210,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted
// SampleDistortedBoundingBoxAreaRange sets the optional area_range attribute to value.
//
// value: The cropped area of the image must contain a fraction of the
// supplied image within this range.
// If not specified, defaults to <f:0.05 f:1 >
func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr {
return func(m optionalAttr) {
@ -17969,9 +17969,10 @@ func SparseFillEmptyRowsGrad(scope *Scope, reverse_index_map tf.Output, grad_val
}
// Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)`
//
// if < 0, `scale * features` otherwise.
//
// Assumes weights to have zero mean and variance 1.0 / fan_in.
//
// See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
func Selu(scope *Scope, features tf.Output) (activations tf.Output) {
if scope.Err() != nil {
@ -21655,7 +21656,7 @@ func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr {
// generated sequentially as '*tag*/image/0', '*tag*/image/1', etc.
//
// The `bad_color` argument is the color to use in the generated images for
// non-finite input values. It is a `uint8` 1-D tensor of length `channels`.
// Each element must be in the range `[0, 255]` (It represents the value of a
// pixel in the output image). Non-finite values in the input tensor are
// replaced by this tensor in the output image. The default value is the color
@ -24048,7 +24049,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort
// SampleDistortedBoundingBoxV2AreaRange sets the optional area_range attribute to value.
//
// value: The cropped area of the image must contain a fraction of the
// supplied image within this range.
// If not specified, defaults to <f:0.05 f:1 >
func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr {
return func(m optionalAttr) {


@ -65,6 +65,11 @@ func (op *Operation) Output(i int) Output {
return Output{op, i}
}
// NumInputs returns the number of inputs of op.
func (op *Operation) NumInputs() int {
return int(C.TF_OperationNumInputs(op.c))
}
// Output represents one of the outputs of an operation in the graph. Has a
// DataType (and eventually a Shape). May be passed as an input argument to a
// function for adding operations to a graph, or to a Session's Run() method to
@ -123,6 +128,67 @@ func (p Output) c() C.TF_Output {
func (p Output) canBeAnInput() {}
// Consumers returns the inputs that consume this output.
func (p Output) Consumers() []Consumer {
max := int(C.TF_OperationOutputNumConsumers(p.c()))
if max == 0 {
return nil
}
inputs := make([]C.TF_Input, max)
n := C.TF_OperationOutputConsumers(p.c(), (*C.TF_Input)(unsafe.Pointer(&inputs[0])), C.int(max))
inputs = inputs[:int(n)]
var consumers []Consumer
for _, consumer := range inputs {
consumers = append(consumers, Consumer{
Index: int(consumer.index),
Op: &Operation{
c: consumer.oper,
g: p.Op.g,
},
})
}
return consumers
}
// Consumer identifies a specific input of an operation that consumes the output
// of another operation.
type Consumer struct {
// Op is the Operation that is consuming the output of another operation.
Op *Operation
// Index is the index of the input within Op that the output of another
// operation is connected to.
Index int
}
func (p Consumer) c() C.TF_Input {
if p.Op == nil {
// Attempt to provide a more useful panic message than "nil
// pointer dereference".
panic("nil-Operation. Consumer objects should only be created by a call to Output.Consumers")
}
return C.TF_Input{oper: p.Op.c, index: C.int(p.Index)}
}
// DataType returns the type of the input.
func (p Consumer) DataType() DataType {
return DataType(C.TF_OperationInputType(p.c()))
}
// Producer returns the Output that is connected to this Consumer.
func (p Consumer) Producer() Output {
output := C.TF_OperationInput(p.c())
return Output{
Op: &Operation{
c: output.oper,
g: p.Op.g,
},
Index: int(output.index),
}
}
// Input is the interface for specifying inputs to an operation being added to
// a Graph.
//
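The new NumInputs, Output.Consumers and Consumer.Producer methods let callers walk graph edges in both directions. A minimal sketch, assuming the generated op wrapper package github.com/tensorflow/tensorflow/tensorflow/go/op is available (the Placeholder/Neg graph here is purely illustrative):

package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	x := op.Placeholder(s, tf.Float)
	y := op.Neg(s, x)
	if _, err := s.Finalize(); err != nil {
		panic(err)
	}
	// Neg consumes exactly one input.
	fmt.Println(y.Op.NumInputs()) // 1
	// Each consumer of x records which of its inputs x feeds, and can be
	// traced back to the producing output via Producer.
	for _, c := range x.Consumers() {
		fmt.Println(c.Op.Name(), c.Index, c.Producer().Op.Name() == x.Op.Name())
	}
}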


@ -166,6 +166,68 @@ func TestOutputDataTypeAndShape(t *testing.T) {
}
}
func TestOperationInputs(t *testing.T) {
g := NewGraph()
x, err := Placeholder(g, "x", Float)
if err != nil {
t.Fatal(err)
}
y, err := Placeholder(g, "y", Float)
if err != nil {
t.Fatal(err)
}
add, err := Add(g, "add", x, y)
if err != nil {
t.Fatal(err)
}
addOp := add.Op
if out := addOp.NumInputs(); out != 2 {
t.Fatalf("Got %d inputs, wanted 2", out)
}
}
func TestOperationConsumers(t *testing.T) {
g := NewGraph()
x, err := Placeholder(g, "x", Float)
if err != nil {
t.Fatal(err)
}
a, err := Neg(g, "a", x)
if err != nil {
t.Fatal(err)
}
b, err := Neg(g, "b", x)
if err != nil {
t.Fatal(err)
}
consumers := []*Operation{a.Op, b.Op}
xConsumers := x.Consumers()
if out := len(xConsumers); out != 2 {
t.Fatalf("Got %d consumers, wanted 2", out)
}
for i, consumer := range xConsumers {
got := consumer.Op.Name()
want := consumers[i].Name()
if got != want {
t.Fatalf("%d. Got op name %q, wanted %q", i, got, want)
}
got = consumer.Producer().Op.Name()
want = x.Op.Name()
if got != want {
t.Fatalf("%d. Got op name %q, wanted %q", i, got, want)
}
}
if len(b.Consumers()) != 0 {
t.Fatalf("expected %+v to have no consumers", b)
}
}
func forceGC() {
var mem runtime.MemStats
runtime.ReadMemStats(&mem)


@ -56,6 +56,10 @@ java_library(
srcs = glob(["src/gen/java/org/tensorflow/processor/**/*.java"]),
javacopts = JAVACOPTS,
resources = glob(["src/gen/resources/META-INF/services/javax.annotation.processing.Processor"]),
deps = [
"@com_google_guava",
"@com_squareup_javapoet",
],
)
filegroup(
@ -70,6 +74,7 @@ tf_java_op_gen_srcjar(
name = "java_op_gen_sources",
api_def_srcs = [
"//tensorflow/core/api_def:base_api_def",
"//tensorflow/core/api_def:java_api_def",
],
base_package = "org.tensorflow.op",
gen_tool = ":java_op_gen_tool",


@ -11,4 +11,10 @@ tensorflow/src
tensorflow/target
proto/src
proto/target
hadoop/src
hadoop/target
spark-connector/src
spark-connector/target
spark-connector/dependency-reduced-pom.xml
spark-connector/spark-warehouse
pom.xml.versionsBackup


@ -53,6 +53,12 @@ There are seven artifacts and thus `pom.xml`s involved in this release:
7. [`parentpom`](https://maven.apache.org/pom/index.html): Common settings
shared by all of the above.
8. `hadoop`: The TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop.
The source code for this package is available in the [TensorFlow Ecosystem](https://github.com/tensorflow/ecosystem/tree/master/hadoop)
9. `spark-connector`: A Scala library for loading and storing TensorFlow TFRecord
using Apache Spark DataFrames. The source code for this package is available
in the [TensorFlow Ecosystem](https://github.com/tensorflow/ecosystem/tree/master/spark/spark-tensorflow-connector)
## Updating the release


@ -0,0 +1,24 @@
<project
xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<!-- Placeholder pom which is replaced by TensorFlow ecosystem Hadoop pom during build -->
<modelVersion>4.0.0</modelVersion>
<description>TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop</description>
<artifactId>hadoop</artifactId>
<packaging>jar</packaging>
<scm>
<url>https://github.com/tensorflow/ecosystem.git</url>
<connection>git@github.com:tensorflow/ecosystem.git</connection>
<developerConnection>scm:git:https://github.com/tensorflow/ecosystem.git</developerConnection>
</scm>
<url>https://github.com/tensorflow/ecosystem/</url>
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
<version>1.9.0-rc0</version>
<relativePath>../</relativePath>
</parent>
</project>


@ -32,6 +32,8 @@
<module>libtensorflow_jni_gpu</module>
<module>tensorflow</module>
<module>proto</module>
<module>hadoop</module>
<module>spark-connector</module>
</modules>
<!-- Two profiles are used:


@ -19,6 +19,7 @@
RELEASE_URL_PREFIX="https://storage.googleapis.com/tensorflow/libtensorflow"
TF_ECOSYSTEM_URL="https://github.com/tensorflow/ecosystem.git"
# By default we deploy to both ossrh and bintray. These two
# environment variables can be set to skip either repository.
@ -44,7 +45,9 @@ clean() {
# (though if run inside a clean docker container, there won't be any dirty
# artifacts lying around)
mvn -q clean
rm -rf libtensorflow_jni/src libtensorflow_jni/target libtensorflow_jni_gpu/src libtensorflow_jni_gpu/target \
libtensorflow/src libtensorflow/target tensorflow-android/target proto/src proto/target \
hadoop/src hadoop/target spark-connector/src spark-connector/target
}
update_version_in_pom() {
@ -183,6 +186,43 @@ generate_java_protos() {
rm -rf "${DIR}/proto/tmp"
}
# Download the TensorFlow ecosystem source from git.
# The pom files from this repo do not inherit from the parent pom so the maven version
# is updated for each module.
download_tf_ecosystem() {
ECOSYSTEM_DIR="/tmp/tensorflow-ecosystem"
HADOOP_DIR="${DIR}/hadoop"
SPARK_DIR="${DIR}/spark-connector"
# Clean any previous attempts
rm -rf "${ECOSYSTEM_DIR}"
# Clone the TensorFlow ecosystem project
mkdir -p "${ECOSYSTEM_DIR}"
cd "${ECOSYSTEM_DIR}"
git clone "${TF_ECOSYSTEM_URL}"
cd ecosystem
git checkout r${TF_VERSION}
# Copy the TensorFlow Hadoop source
cp -r "${ECOSYSTEM_DIR}/ecosystem/hadoop/src" "${HADOOP_DIR}"
cp "${ECOSYSTEM_DIR}/ecosystem/hadoop/pom.xml" "${HADOOP_DIR}"
cd "${HADOOP_DIR}"
update_version_in_pom
# Copy the TensorFlow Spark connector source
cp -r "${ECOSYSTEM_DIR}/ecosystem/spark/spark-tensorflow-connector/src" "${SPARK_DIR}"
cp "${ECOSYSTEM_DIR}/ecosystem/spark/spark-tensorflow-connector/pom.xml" "${SPARK_DIR}"
cd "${SPARK_DIR}"
update_version_in_pom
# Cleanup
rm -rf "${ECOSYSTEM_DIR}"
cd "${DIR}"
}
# Deploy artifacts using a specific profile.
# Arguments:
# profile - name of selected profile.
@ -240,7 +280,8 @@ cd "${DIR}"
# Comment lines out appropriately if debugging/tinkering with the release
# process.
# gnupg2 is required for signing
apt-get -qq update && apt-get -qqq install -y gnupg2 git
clean
update_version_in_pom
download_libtensorflow
@ -248,6 +289,8 @@ download_libtensorflow_jni
download_libtensorflow_jni_gpu
update_tensorflow_android
generate_java_protos
download_tf_ecosystem
# Build the release artifacts
mvn verify
# Push artifacts to repository


@ -0,0 +1,24 @@
<project
xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<!-- Placeholder pom which is replaced by TensorFlow ecosystem Spark pom during build -->
<modelVersion>4.0.0</modelVersion>
<description>TensorFlow TFRecord connector for Apache Spark DataFrames</description>
<artifactId>spark-connector</artifactId>
<packaging>jar</packaging>
<scm>
<url>https://github.com/tensorflow/ecosystem.git</url>
<connection>git@github.com:tensorflow/ecosystem.git</connection>
<developerConnection>scm:git:https://github.com/tensorflow/ecosystem.git</developerConnection>
</scm>
<url>https://github.com/tensorflow/ecosystem/</url>
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
<version>1.9.0-rc0</version>
<relativePath>../</relativePath>
</parent>
</project>


@ -35,7 +35,7 @@ namespace tensorflow {
namespace java {
namespace {
constexpr const char kLicense[] =
"/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n"
"\n"
"Licensed under the Apache License, Version 2.0 (the \"License\");\n"
@ -391,9 +391,12 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint,
}
if (!op.hidden()) {
// expose the op in the Ops Graph API only if it is visible
Annotation oper_annot =
Annotation::Create("Operator", "org.tensorflow.op.annotation");
if (endpoint.package() != kDefaultEndpointPackage) {
oper_annot.attributes("group = \"" + endpoint.package() + "\"");
}
op_class.add_annotation(oper_annot);
}
// create op class file
const string op_dir_name = io::JoinPath(


@ -27,6 +27,8 @@ limitations under the License.
namespace tensorflow {
namespace java {
constexpr const char kDefaultEndpointPackage[] = "core";
class EndpointSpec {
public:
// A specification for an operation endpoint

Some files were not shown because too many files have changed in this diff.