Merge changes from github.
PiperOrigin-RevId: 202585094

parent 3cee10e61c · commit 1e7b0e4ad6

.gitignore (vendored, 1 change)
@@ -16,6 +16,7 @@ __pycache__
cmake_build/
.idea/**
/build/
[Bb]uild/
/tensorflow/core/util/version_info.cc
/tensorflow/python/framework/fast_tensor_util.cpp
Pods
configure.py (72 changes)
@@ -943,6 +943,35 @@ def set_tf_cudnn_version(environ_cp):
  write_action_env_to_bazelrc('TF_CUDNN_VERSION', tf_cudnn_version)


def is_cuda_compatible(lib, cuda_ver, cudnn_ver):
  """Check compatibility between given library and cudnn/cudart libraries."""
  ldd_bin = which('ldd') or '/usr/bin/ldd'
  ldd_out = run_shell([ldd_bin, lib], True)
  ldd_out = ldd_out.split(os.linesep)
  cudnn_pattern = re.compile('.*libcudnn.so\\.?(.*) =>.*$')
  cuda_pattern = re.compile('.*libcudart.so\\.?(.*) =>.*$')
  cudnn = None
  cudart = None
  cudnn_ok = True  # assume no cudnn dependency by default
  cuda_ok = True  # assume no cuda dependency by default
  for line in ldd_out:
    if 'libcudnn.so' in line:
      cudnn = cudnn_pattern.search(line)
      cudnn_ok = False
    elif 'libcudart.so' in line:
      cudart = cuda_pattern.search(line)
      cuda_ok = False
  if cudnn and len(cudnn.group(1)):
    cudnn = convert_version_to_int(cudnn.group(1))
  if cudart and len(cudart.group(1)):
    cudart = convert_version_to_int(cudart.group(1))
  if cudnn is not None:
    cudnn_ok = (cudnn == cudnn_ver)
  if cudart is not None:
    cuda_ok = (cudart == cuda_ver)
  return cudnn_ok and cuda_ok
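For orientation, a minimal sketch of how this helper is exercised further down in the same file (the library path and version strings here are illustrative assumptions, not values from the commit):

# Illustrative only: mirrors the call sites later in set_tf_tensorrt_install_path().
cuda_ver = convert_version_to_int('9.0')   # e.g. environ_cp['TF_CUDA_VERSION']
cudnn_ver = convert_version_to_int('7')    # e.g. environ_cp['TF_CUDNN_VERSION']
lib_file = '/usr/lib/x86_64-linux-gnu/libnvinfer.so.4'  # hypothetical candidate
if is_cuda_compatible(lib_file, cuda_ver, cudnn_ver):
  print('%s links against the configured CUDA/cuDNN versions' % lib_file)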
|
||||
|
||||
|
||||
def set_tf_tensorrt_install_path(environ_cp):
|
||||
"""Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION.
|
||||
|
||||
@ -959,8 +988,8 @@ def set_tf_tensorrt_install_path(environ_cp):
|
||||
raise ValueError('Currently TensorRT is only supported on Linux platform.')
|
||||
|
||||
# Ask user whether to add TensorRT support.
|
||||
if str(int(get_var(
|
||||
environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False))) != '1':
|
||||
if str(int(get_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT',
|
||||
False))) != '1':
|
||||
return
|
||||
|
||||
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
|
||||
@ -973,47 +1002,29 @@ def set_tf_tensorrt_install_path(environ_cp):
|
||||
|
||||
# Result returned from "read" will be used unexpanded. That make "~"
|
||||
# unusable. Going through one more level of expansion to handle that.
|
||||
trt_install_path = os.path.realpath(
|
||||
os.path.expanduser(trt_install_path))
|
||||
trt_install_path = os.path.realpath(os.path.expanduser(trt_install_path))
|
||||
|
||||
def find_libs(search_path):
|
||||
"""Search for libnvinfer.so in "search_path"."""
|
||||
fl = set()
|
||||
if os.path.exists(search_path) and os.path.isdir(search_path):
|
||||
fl.update([os.path.realpath(os.path.join(search_path, x))
|
||||
for x in os.listdir(search_path) if 'libnvinfer.so' in x])
|
||||
fl.update([
|
||||
os.path.realpath(os.path.join(search_path, x))
|
||||
for x in os.listdir(search_path)
|
||||
if 'libnvinfer.so' in x
|
||||
])
|
||||
return fl
|
||||
|
||||
possible_files = find_libs(trt_install_path)
|
||||
possible_files.update(find_libs(os.path.join(trt_install_path, 'lib')))
|
||||
possible_files.update(find_libs(os.path.join(trt_install_path, 'lib64')))
|
||||
|
||||
def is_compatible(tensorrt_lib, cuda_ver, cudnn_ver):
|
||||
"""Check the compatibility between tensorrt and cudnn/cudart libraries."""
|
||||
ldd_bin = which('ldd') or '/usr/bin/ldd'
|
||||
ldd_out = run_shell([ldd_bin, tensorrt_lib]).split(os.linesep)
|
||||
cudnn_pattern = re.compile('.*libcudnn.so\\.?(.*) =>.*$')
|
||||
cuda_pattern = re.compile('.*libcudart.so\\.?(.*) =>.*$')
|
||||
cudnn = None
|
||||
cudart = None
|
||||
for line in ldd_out:
|
||||
if 'libcudnn.so' in line:
|
||||
cudnn = cudnn_pattern.search(line)
|
||||
elif 'libcudart.so' in line:
|
||||
cudart = cuda_pattern.search(line)
|
||||
if cudnn and len(cudnn.group(1)):
|
||||
cudnn = convert_version_to_int(cudnn.group(1))
|
||||
if cudart and len(cudart.group(1)):
|
||||
cudart = convert_version_to_int(cudart.group(1))
|
||||
return (cudnn == cudnn_ver) and (cudart == cuda_ver)
|
||||
|
||||
cuda_ver = convert_version_to_int(environ_cp['TF_CUDA_VERSION'])
|
||||
cudnn_ver = convert_version_to_int(environ_cp['TF_CUDNN_VERSION'])
|
||||
nvinfer_pattern = re.compile('.*libnvinfer.so.?(.*)$')
|
||||
highest_ver = [0, None, None]
|
||||
|
||||
for lib_file in possible_files:
|
||||
if is_compatible(lib_file, cuda_ver, cudnn_ver):
|
||||
if is_cuda_compatible(lib_file, cuda_ver, cudnn_ver):
|
||||
matches = nvinfer_pattern.search(lib_file)
|
||||
if len(matches.groups()) == 0:
|
||||
continue
|
||||
@ -1029,12 +1040,13 @@ def set_tf_tensorrt_install_path(environ_cp):
|
||||
# Try another alternative from ldconfig.
|
||||
ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
|
||||
ldconfig_output = run_shell([ldconfig_bin, '-p'])
|
||||
search_result = re.search(
|
||||
'.*libnvinfer.so\\.?([0-9.]*).* => (.*)', ldconfig_output)
|
||||
search_result = re.search('.*libnvinfer.so\\.?([0-9.]*).* => (.*)',
|
||||
ldconfig_output)
|
||||
if search_result:
|
||||
libnvinfer_path_from_ldconfig = search_result.group(2)
|
||||
if os.path.exists(libnvinfer_path_from_ldconfig):
|
||||
if is_compatible(libnvinfer_path_from_ldconfig, cuda_ver, cudnn_ver):
|
||||
if is_cuda_compatible(libnvinfer_path_from_ldconfig, cuda_ver,
|
||||
cudnn_ver):
|
||||
trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig)
|
||||
tf_tensorrt_version = search_result.group(1)
|
||||
break
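As a side note, the `libnvinfer` pattern above is matched against `ldconfig -p` output; a small self-contained sketch of what it extracts (the sample line and path are assumptions for illustration, not output from the commit):

import re

# Hypothetical line from `ldconfig -p`; real output varies by system.
sample = 'libnvinfer.so.4 (libc6,x86-64) => /usr/lib/x86_64-linux-gnu/libnvinfer.so.4'
match = re.search('.*libnvinfer.so\\.?([0-9.]*).* => (.*)', sample)
if match:
  print(match.group(1))  # '4' -> candidate tf_tensorrt_version
  print(match.group(2))  # '/usr/lib/x86_64-linux-gnu/libnvinfer.so.4' -> library path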
|
||||
|
@@ -154,6 +154,12 @@ config_setting(
    visibility = ["//visibility:public"],
)

config_setting(
    name = "linux_s390x",
    values = {"cpu": "s390x"},
    visibility = ["//visibility:public"],
)

config_setting(
    name = "debug",
    values = {
|
||||
@ -459,6 +465,15 @@ filegroup(
|
||||
tf_cc_shared_object(
|
||||
name = "libtensorflow_framework.so",
|
||||
framework_so = [],
|
||||
linkopts = select({
|
||||
"//tensorflow:darwin": [],
|
||||
"//tensorflow:windows": [],
|
||||
"//tensorflow:windows_msvc": [],
|
||||
"//conditions:default": [
|
||||
"-Wl,--version-script", # This line must be directly followed by the version_script.lds file
|
||||
"$(location //tensorflow:tf_framework_version_script.lds)",
|
||||
],
|
||||
}),
|
||||
linkstatic = 1,
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
@ -468,6 +483,7 @@ tf_cc_shared_object(
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
|
||||
"//tensorflow/core:lib_internal_impl",
|
||||
"//tensorflow/stream_executor:stream_executor_impl",
|
||||
"//tensorflow:tf_framework_version_script.lds",
|
||||
] + tf_additional_binary_deps(),
|
||||
)
|
||||
|
||||
@@ -571,3 +587,13 @@ py_library(
    visibility = ["//visibility:public"],
    deps = ["//tensorflow/python:no_contrib"],
)

cc_library(
    name = "grpc",
    deps = ["@grpc"],
)

cc_library(
    name = "grpc++",
    deps = ["@grpc//:grpc++"],
)
|
||||
|
@@ -2068,7 +2068,8 @@ TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
    TF_Graph* graph, const TF_Buffer* graph_def,
    const TF_ImportGraphDefOptions* options, TF_Status* status) {
  GraphDef def;
-  if (!def.ParseFromArray(graph_def->data, graph_def->length)) {
+  if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
+                                       graph_def->length)) {
    status->status = InvalidArgument("Invalid GraphDef");
    return nullptr;
  }
@@ -2098,7 +2099,8 @@ void TF_GraphImportGraphDefWithReturnOutputs(
    return;
  }
  GraphDef def;
-  if (!def.ParseFromArray(graph_def->data, graph_def->length)) {
+  if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
+                                       graph_def->length)) {
    status->status = InvalidArgument("Invalid GraphDef");
    return;
  }
|
||||
|
@@ -287,7 +287,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
  TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config));
  const int64 result_index = compile_result.aot->result_buffer_index();
  const xla::BufferSizes& temp_sizes = compile_result.aot->buffer_sizes();
-  if (result_index < 0 || result_index > temp_sizes.size()) {
+  if (result_index < 0 || result_index >= temp_sizes.size()) {
    return errors::InvalidArgument("result index: ", result_index,
                                   " is outside the range of temp sizes: [0,",
                                   temp_sizes.size(), ")");
|
||||
|
@ -39,10 +39,10 @@ tf_cc_binary(
|
||||
srcs = ["grpc_service_main.cc"],
|
||||
deps = [
|
||||
":grpc_service",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/compiler/xla/service:cpu_plugin",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"@grpc//:grpc++",
|
||||
],
|
||||
)
|
||||
|
||||
@ -54,6 +54,7 @@ tf_cc_test(
|
||||
],
|
||||
deps = [
|
||||
":grpc_stub",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/compiler/xla/client",
|
||||
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
@ -61,7 +62,6 @@ tf_cc_test(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"@grpc//:grpc++",
|
||||
],
|
||||
)
|
||||
|
||||
@ -71,9 +71,9 @@ cc_library(
|
||||
hdrs = ["grpc_service.h"],
|
||||
deps = [
|
||||
":xla_service_proto",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/compiler/xla/service",
|
||||
"//tensorflow/compiler/xla/service:platform_util",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
||||
"@grpc//:grpc++",
|
||||
],
|
||||
)
|
||||
|
@ -2550,7 +2550,6 @@ cc_library(
|
||||
name = "hlo_tfgraph_builder",
|
||||
srcs = ["hlo_tfgraph_builder.cc"],
|
||||
hdrs = ["hlo_tfgraph_builder.h"],
|
||||
visibility = ["//tensorflow/compiler/xla/tools:__pkg__"],
|
||||
deps = [
|
||||
":hlo",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
|
@ -1515,6 +1515,7 @@ bool HloInstruction::IdenticalSlowPath(
|
||||
|
||||
// Remaining instructions with special values.
|
||||
case HloOpcode::kCall:
|
||||
return eq_computations(to_apply(), other.to_apply());
|
||||
case HloOpcode::kConditional:
|
||||
return eq_computations(true_computation(), other.true_computation()) &&
|
||||
eq_computations(false_computation(), other.false_computation());
|
||||
|
@ -924,6 +924,40 @@ TEST_F(HloInstructionTest, IdenticalInstructions) {
|
||||
*HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op1, op2)));
|
||||
}
|
||||
|
||||
TEST_F(HloInstructionTest, IdenticalCallInstructions) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
subcomp1 (x: f32[]) -> f32[] {
|
||||
x = f32[] parameter(0)
|
||||
ROOT n = f32[] sine(x)
|
||||
}
|
||||
|
||||
subcomp2 (x: f32[]) -> f32[] {
|
||||
x = f32[] parameter(0)
|
||||
ROOT n = f32[] cosine(x)
|
||||
}
|
||||
|
||||
ENTRY entry (param: f32[]) -> (f32[], f32[], f32[]) {
|
||||
p = f32[] parameter(0)
|
||||
t1 = f32[] call(p), to_apply=subcomp1
|
||||
t2 = f32[] call(p), to_apply=subcomp1
|
||||
t3 = f32[] call(p), to_apply=subcomp2
|
||||
ROOT t = (f32[], f32[], f32[]) tuple(t1, t2, t3)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseHloString(hlo_string));
|
||||
|
||||
auto* root = module->entry_computation()->root_instruction();
|
||||
auto* t1 = root->operand(0);
|
||||
auto* t2 = root->operand(1);
|
||||
auto* t3 = root->operand(2);
|
||||
|
||||
EXPECT_TRUE(StructuralEqual(*t1, *t2));
|
||||
EXPECT_FALSE(StructuralEqual(*t1, *t3));
|
||||
}
|
||||
|
||||
TEST_F(HloInstructionTest, FunctionVisitor) {
|
||||
// Verify the function visitor HloInstruction::Accept visits all instructions
|
||||
// from a root properly given the following graph:
|
||||
|
@ -120,7 +120,10 @@ py_test(
|
||||
name = "decorators_test",
|
||||
srcs = ["decorators_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["no_windows"],
|
||||
tags = [
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
],
|
||||
deps = [
|
||||
":converters",
|
||||
"//tensorflow/contrib/autograph/core:test_lib",
|
||||
|
@@ -51,7 +51,7 @@ def for_stmt(iter_, extra_test, body, init_state):
  Args:
    iter_: The entity being iterated over.
    extra_test: Callable with the state as arguments, and boolean return type.
-      An additionnal loop condition.
+      An additional loop condition.
    body: Callable with the iterate and the state as arguments, and
      state as return type. The actual loop body.
    init_state: Tuple containing the initial state.
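To make that contract concrete, here is a small standalone sketch of the behavior the docstring describes; it is an illustration of the interface only, not code from the autograph module, and the assumption that the final state tuple is returned is mine:

# Hypothetical, simplified stand-in for the operator documented above.
def _for_stmt_sketch(iter_, extra_test, body, init_state):
  state = init_state
  for target in iter_:
    if not extra_test(*state):
      break
    state = body(target, *state)
  return state

# Summing 0..4 expressed through the loop operator:
final_state = _for_stmt_sketch(
    range(5),
    extra_test=lambda total: True,       # no additional stopping condition
    body=lambda i, total: (total + i,),  # returns the new state tuple
    init_state=(0,))
assert final_state == (10,)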
|
||||
|
@@ -286,7 +286,7 @@ class Forward(object):

  # TODO(alexbw): see if we can simplify by visiting breadth-first
  def visit(self, node):
-    """Depth-first walking the CFG, applying dataflow information propagtion."""
+    """Depth-first walking the CFG, applying dataflow info propagation."""
    # node.value is None only for the exit CfgNode.
    if not node.value:
      return
|
||||
|
@@ -218,7 +218,7 @@ class Base(gast.NodeTransformer):

  # TODO(mdan): Once we have error tracing, we may be able to just go to SSA.
  def apply_to_single_assignments(self, targets, values, apply_fn):
-    """Applies a fuction to each individual assignment.
+    """Applies a function to each individual assignment.

    This function can process a possibly-unpacked (e.g. a, b = c, d) assignment.
    It tries to break down the unpacking if possible. In effect, it has the same
@@ -246,7 +246,7 @@ class Base(gast.NodeTransformer):
        targets field of an ast.Assign node.
      values: an AST node.
      apply_fn: a function of a single argument, which will be called with the
-        respective nodes of each single assignment. The signaure is
+        respective nodes of each single assignment. The signature is
        apply_fn(target, value), no return value.
    """
    if not isinstance(targets, (list, tuple)):
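A tiny standalone illustration of what breaking down the `a, b = c, d` unpacking means; plain strings stand in for AST nodes, and this helper is a hypothetical sketch rather than the transformer method itself:

def _apply_to_single_assignments_sketch(targets, values, apply_fn):
  # Pair up unpacked targets and values, recursing into nested unpacking.
  if isinstance(targets, (list, tuple)) and isinstance(values, (list, tuple)):
    for target, value in zip(targets, values):
      _apply_to_single_assignments_sketch(target, value, apply_fn)
  else:
    apply_fn(targets, values)

pairs = []
_apply_to_single_assignments_sketch(
    ('a', 'b'), ('c', 'd'), lambda target, value: pairs.append((target, value)))
assert pairs == [('a', 'c'), ('b', 'd')]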
|
||||
|
@ -336,40 +336,14 @@ endif()
|
||||
# MKL Support
|
||||
if (tensorflow_ENABLE_MKL_SUPPORT)
|
||||
add_definitions(-DINTEL_MKL -DEIGEN_USE_VML)
|
||||
if (WIN32)
|
||||
find_path(MKL_HOME_PLATFORM mkl
|
||||
PATHS ${MKL_HOME} ${MKL_HOME}/../ ${MKL_HOME}/../../
|
||||
$ENV{MKLROOT} $ENV{MKLROOT}/../ $ENV{MKLROOT}/../../
|
||||
PATH_SUFFIXES windows)
|
||||
set(MKL_INCLUDE_DIRS ${MKL_HOME_PLATFORM}/mkl/include)
|
||||
set(MKL_LINK_DIRS
|
||||
${MKL_HOME_PLATFORM}/mkl/lib/intel64
|
||||
${MKL_HOME_PLATFORM}/tbb/lib/intel64/vc_mt
|
||||
${MKL_HOME_PLATFORM}/compiler/lib/intel64
|
||||
${MKL_HOME_PLATFORM}/mkl/tools/builder/lib)
|
||||
set(MKL_REDIST_DLL_DIRS
|
||||
${MKL_HOME_PLATFORM}/redist/intel64/mkl
|
||||
${MKL_HOME_PLATFORM}/redist/intel64/tbb/vc_mt
|
||||
${MKL_HOME_PLATFORM}/redist/intel64/compiler)
|
||||
list(APPEND tensorflow_EXTERNAL_LIBRARIES
|
||||
mkl_intel_lp64_dll mkl_sequential_dll mkl_core_dll mkl_rt mkl_cdll_intel64)
|
||||
endif()
|
||||
if (UNIX)
|
||||
# Fix me: complete the path on linux
|
||||
find_path(MKL_HOME_PLATFORM mkl
|
||||
HINTS ${MKL_HOME} ${MKL_HOME}/../ ${MKL_HOME}/../../
|
||||
$ENV{MKLROOT} $ENV{MKLROOT}/../ $ENV{MKLROOT}/../../
|
||||
PATH_SUFFIXES linux)
|
||||
set(MKL_INCLUDE_DIRS ${MKL_HOME_PLATFORM}/mkl/include)
|
||||
set(MKL_LINK_DIRS) # incompleted
|
||||
set(MKL_REDIST_SO_DIRS) # incompleted
|
||||
endif()
|
||||
include_directories(${MKL_INCLUDE_DIRS})
|
||||
link_directories(${MKL_LINK_DIRS})
|
||||
include(mkl)
|
||||
list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkl_STATIC_LIBRARIES})
|
||||
list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkl_copy_shared_to_destination)
|
||||
include_directories(${mkl_INCLUDE_DIRS})
|
||||
if (tensorflow_ENABLE_MKLDNN_SUPPORT)
|
||||
include(mkldnn)
|
||||
list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkldnn_STATIC_LIBRARIES})
|
||||
list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkldnn)
|
||||
list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkldnn_copy_shared_to_destination)
|
||||
include_directories(${mkldnn_INCLUDE_DIRS})
|
||||
else (tensorflow_ENABLE_MKLDNN_SUPPORT)
|
||||
add_definitions(-DINTEL_MKL_ML)
|
||||
|
@ -16,15 +16,15 @@ include (ExternalProject)
|
||||
|
||||
set(double_conversion_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/double_conversion/src/double_conversion)
|
||||
set(double_conversion_URL https://github.com/google/double-conversion.git)
|
||||
set(double_conversion_TAG 5664746)
|
||||
set(double_conversion_TAG 3992066a95b823efc8ccc1baf82a1cfc73f6e9b8)
|
||||
set(double_conversion_BUILD ${double_conversion_INCLUDE_DIR})
|
||||
set(double_conversion_LIBRARIES ${double_conversion_BUILD}/double-conversion/libdouble-conversion.so)
|
||||
set(double_conversion_INCLUDES ${double_conversion_BUILD})
|
||||
|
||||
if(WIN32)
|
||||
set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/double-conversion/$(Configuration)/double-conversion.lib)
|
||||
set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/$(Configuration)/double-conversion.lib)
|
||||
else()
|
||||
set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/double-conversion/libdouble-conversion.a)
|
||||
set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/libdouble-conversion.a)
|
||||
endif()
|
||||
|
||||
set(double_conversion_HEADERS
|
||||
|
tensorflow/contrib/cmake/external/mkl.cmake (new file, 68 lines, vendored)
@ -0,0 +1,68 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
include (ExternalProject)
|
||||
|
||||
# NOTE: Different from mkldnn.cmake, this file is meant to download mkl libraries
|
||||
set(mkl_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/include)
|
||||
set(mkl_BIN_DIRS ${CMAKE_CURRENT_BINARY_DIR}/mkl/bin)
|
||||
set(mkl_WIN mklml_win_2018.0.3.20180406.zip) # match for v0.14
|
||||
set(mkl_MAC mklml_mac_2018.0.3.20180406.tgz)
|
||||
set(mkl_LNX mklml_lnx_2018.0.3.20180406.tgz)
|
||||
set(mkl_TAG v0.14)
|
||||
set(mkl_URL https://github.com/intel/mkl-dnn/releases)
|
||||
|
||||
if (WIN32)
|
||||
set(mkl_DOWNLOAD_URL ${mkl_URL}/download/${mkl_TAG}/${mkl_WIN})
|
||||
list(APPEND mkl_STATIC_LIBRARIES
|
||||
${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/mklml.lib)
|
||||
list(APPEND mkl_STATIC_LIBRARIES
|
||||
${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libiomp5md.lib)
|
||||
list(APPEND mkl_SHARED_LIBRARIES
|
||||
${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/mklml.dll)
|
||||
list(APPEND mkl_SHARED_LIBRARIES
|
||||
${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libiomp5md.dll)
|
||||
elseif (UNIX)
|
||||
set(mkl_DOWNLOAD_URL ${mkl_URL}/download/${mkl_TAG}/${mkl_LNX})
|
||||
list(APPEND mkl_SHARED_LIBRARIES
|
||||
${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libiomp5.so)
|
||||
list(APPEND mkl_SHARED_LIBRARIES
|
||||
${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libmklml_gnu.so)
|
||||
list(APPEND mkl_SHARED_LIBRARIES
|
||||
${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libmklml_intel.so)
|
||||
elseif (APPLE)
|
||||
set(mkl_DOWNLOAD_URL ${mkl_URL}/download/${mkl_TAG}/${mkl_MAC})
|
||||
#TODO need more information
|
||||
endif ()
|
||||
|
||||
ExternalProject_Add(mkl
|
||||
PREFIX mkl
|
||||
URL ${mkl_DOWNLOAD_URL}
|
||||
DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
|
||||
UPDATE_COMMAND ""
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
INSTALL_COMMAND "")
|
||||
|
||||
# put mkl dynamic libraries in one bin directory
|
||||
add_custom_target(mkl_create_destination_dir
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory ${mkl_BIN_DIRS}
|
||||
DEPENDS mkl)
|
||||
|
||||
add_custom_target(mkl_copy_shared_to_destination DEPENDS mkl_create_destination_dir)
|
||||
|
||||
foreach(dll_file ${mkl_SHARED_LIBRARIES})
|
||||
add_custom_command(TARGET mkl_copy_shared_to_destination PRE_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${dll_file} ${mkl_BIN_DIRS})
|
||||
endforeach()
|
tensorflow/contrib/cmake/external/mkldnn.cmake (vendored, 12 changes)
@ -22,8 +22,11 @@ set(mkldnn_TAG 3063b2e4c943983f6bf5f2fb9a490d4a998cd291)
|
||||
if(WIN32)
|
||||
if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*")
|
||||
set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/Release/mkldnn.lib)
|
||||
set(mkldnn_SHARED_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/Release/mkldnn.dll)
|
||||
set(mkldnn_BUILD ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/Release)
|
||||
else()
|
||||
set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/mkldnn.lib)
|
||||
set(mkldnn_SHARED_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/mkldnn.dll)
|
||||
endif()
|
||||
else()
|
||||
set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/libmkldnn.a)
|
||||
@ -31,6 +34,7 @@ endif()
|
||||
|
||||
ExternalProject_Add(mkldnn
|
||||
PREFIX mkldnn
|
||||
DEPENDS mkl
|
||||
GIT_REPOSITORY ${mkldnn_URL}
|
||||
GIT_TAG ${mkldnn_TAG}
|
||||
DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
|
||||
@ -40,5 +44,11 @@ ExternalProject_Add(mkldnn
|
||||
CMAKE_CACHE_ARGS
|
||||
-DCMAKE_BUILD_TYPE:STRING=Release
|
||||
-DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF
|
||||
-DMKLINC:STRING=${MKL_INCLUDE_DIRS}
|
||||
-DMKLINC:STRING=${mkl_INCLUDE_DIRS}
|
||||
)
|
||||
|
||||
# since mkldnn depends on mkl, copy the mkldnn.dll together with mklml.dll to mkl_bin_dirs
|
||||
add_custom_target(mkldnn_copy_shared_to_destination DEPENDS mkldnn)
|
||||
|
||||
add_custom_command(TARGET mkldnn_copy_shared_to_destination PRE_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${mkldnn_SHARED_LIBRARIES} ${mkl_BIN_DIRS})
|
||||
|
@ -755,26 +755,65 @@ set(api_init_list_file "${tensorflow_source_dir}/api_init_files_list.txt")
|
||||
file(WRITE "${api_init_list_file}" "${api_init_files}")
|
||||
|
||||
# Run create_python_api.py to generate __init__.py files.
|
||||
add_custom_command(
|
||||
OUTPUT ${api_init_files}
|
||||
DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops
|
||||
|
||||
# tensorflow/__init__.py depends on files generated in this step. So, remove it while
|
||||
# this step is running since the files aren't there yet.
|
||||
COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py
|
||||
### TODO
|
||||
# In order to download and compile MKL/MKL-DNN automatically in cmake script, mkl-built libraries should be added to system path
|
||||
# to be loaded by python executor. However `add_custom_command` has an issue with `COMMAND ${CMAKE_COMMAND} -E env PATH=`, where
|
||||
# arguments of multiple paths (such as D:/;D:/mkl) will be parsed in to seperate string without semicolon and that command fail to
|
||||
# recongnize paths. As CUDA isn't built with MKL, the MKL built directory is the only path to this command to work around that issue.
|
||||
# To not override the CUDA and system path in other circumstances, `if-else` branch used here to handle this problem,
|
||||
# and should be removed if the path issue can be resolved.
|
||||
###
|
||||
|
||||
# Run create_python_api.py to generate API init files.
|
||||
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE}
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py"
|
||||
"--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py"
|
||||
"--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow"
|
||||
"--package=tensorflow.python"
|
||||
"--apiname=tensorflow"
|
||||
"${api_init_list_file}"
|
||||
if (tensorflow_ENABLE_MKL_SUPPORT)
|
||||
# add mkl dist dlls to system path for python
|
||||
# TODO: In current cmake version, PY_RUNTIME_ENV behaves strange with multiple paths,
|
||||
# so we have to specify only one path in it to work around the issue. We need this if/else
|
||||
# to protect overwriting CUDA environments
|
||||
set(PY_RUNTIME_ENV ${mkl_BIN_DIRS})
|
||||
add_custom_command(
|
||||
OUTPUT ${api_init_files}
|
||||
DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops
|
||||
|
||||
COMMENT "Generating __init__.py files for Python API."
|
||||
WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python"
|
||||
)
|
||||
# tensorflow/__init__.py depends on files generated in this step. So, remove it while
|
||||
# this step is running since the files aren't there yet.
|
||||
COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py
|
||||
|
||||
# Run create_python_api.py to generate API init files.
|
||||
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python PATH=${PY_RUNTIME_ENV} ${PYTHON_EXECUTABLE}
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py"
|
||||
"--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py"
|
||||
"--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow"
|
||||
"--package=tensorflow.python"
|
||||
"--apiname=tensorflow"
|
||||
"${api_init_list_file}"
|
||||
|
||||
COMMENT "Generating __init__.py files for Python API."
|
||||
WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python"
|
||||
VERBATIM
|
||||
)
|
||||
else (tensorflow_ENABLE_MKL_SUPPORT)
|
||||
add_custom_command(
|
||||
OUTPUT ${api_init_files}
|
||||
DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops
|
||||
|
||||
# tensorflow/__init__.py depends on files generated in this step. So, remove it while
|
||||
# this step is running since the files aren't there yet.
|
||||
COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py
|
||||
|
||||
# Run create_python_api.py to generate API init files.
|
||||
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE}
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py"
|
||||
"--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py"
|
||||
"--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow"
|
||||
"--package=tensorflow.python"
|
||||
"--apiname=tensorflow"
|
||||
"${api_init_list_file}"
|
||||
|
||||
COMMENT "Generating __init__.py files for Python API."
|
||||
WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python"
|
||||
)
|
||||
endif (tensorflow_ENABLE_MKL_SUPPORT)
|
||||
|
||||
add_custom_target(tf_python_api SOURCES ${api_init_files})
|
||||
add_dependencies(tf_python_api tf_python_ops)
|
||||
|
@ -145,3 +145,8 @@ install(DIRECTORY ${tensorflow_source_dir}/third_party/eigen3/
|
||||
# unsupported Eigen directory
|
||||
install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/unsupported/Eigen/
|
||||
DESTINATION include/unsupported/Eigen)
|
||||
# mkl
|
||||
if (tensorflow_ENABLE_MKL_SUPPORT)
|
||||
install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/include/
|
||||
DESTINATION include/mkl)
|
||||
endif (tensorflow_ENABLE_MKL_SUPPORT)
|
||||
|
@@ -46,7 +46,7 @@ document.
Imagine that we want to constrain the recall of a binary classifier to be at
least 90%. Since the recall is proportional to the number of true positive
classifications, which itself is a sum of indicator functions, this constraint
-is non-differentible, and therefore cannot be used in a problem that will be
+is non-differentiable, and therefore cannot be used in a problem that will be
optimized using a (stochastic) gradient-based algorithm.
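For concreteness (standard definitions, not text taken from the TFCO documentation): with labels $y_i \in \{0, 1\}$ and thresholded predictions $\hat{y}_i = \mathbf{1}[f(x_i) > 0]$,

$$\mathrm{recall} = \frac{\sum_i \mathbf{1}[y_i = 1]\,\hat{y}_i}{\sum_i \mathbf{1}[y_i = 1]},$$

and the indicator functions in the numerator make the quantity piecewise constant, hence non-differentiable in the model parameters.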
|
||||
|
||||
For this and similar problems, TFCO supports so-called *proxy constraints*,
|
||||
|
@@ -169,8 +169,8 @@ def _project_stochastic_matrix_wrt_euclidean_norm(matrix):
    del old_inactive  # Needed by the condition, but not the body.
    iteration += 1
    scale = (1.0 - standard_ops.reduce_sum(
-        matrix, axis=0, keep_dims=True)) / standard_ops.maximum(
-            1.0, standard_ops.reduce_sum(inactive, axis=0, keep_dims=True))
+        matrix, axis=0, keepdims=True)) / standard_ops.maximum(
+            1.0, standard_ops.reduce_sum(inactive, axis=0, keepdims=True))
    matrix += scale * inactive
    new_inactive = standard_ops.to_float(matrix > 0)
    matrix *= new_inactive
@@ -206,10 +206,10 @@ def _project_log_stochastic_matrix_wrt_kl_divergence(log_matrix):

  # For numerical reasons, make sure that the largest matrix element is zero
  # before exponentiating.
-  log_matrix -= standard_ops.reduce_max(log_matrix, axis=0, keep_dims=True)
+  log_matrix -= standard_ops.reduce_max(log_matrix, axis=0, keepdims=True)
  log_matrix -= standard_ops.log(
      standard_ops.reduce_sum(
-          standard_ops.exp(log_matrix), axis=0, keep_dims=True))
+          standard_ops.exp(log_matrix), axis=0, keepdims=True))
  return log_matrix
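The max subtraction above is the usual log-sum-exp stabilization; a minimal NumPy illustration (not part of the commit) of why it matters:

import numpy as np

x = np.array([1000.0, 1000.1])
naive = np.log(np.sum(np.exp(x)))                     # exp(1000.) overflows to inf
shifted = x - np.max(x)                               # largest element becomes 0.0
stable = np.max(x) + np.log(np.sum(np.exp(shifted)))  # ~1000.744, finite and correct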
|
||||
|
||||
|
||||
|
@ -58,6 +58,7 @@ class SlideDatasetTest(test.TestCase):
|
||||
[t.shape.as_list() for t in get_next])
|
||||
|
||||
with self.test_session() as sess:
|
||||
# stride < window_size.
|
||||
# Slide over a finite input, where the window_size divides the
|
||||
# total number of elements.
|
||||
sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 7})
|
||||
@ -71,11 +72,9 @@ class SlideDatasetTest(test.TestCase):
|
||||
result_component[j])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Slide over a finite input, where the window_size does not
|
||||
# divide the total number of elements.
|
||||
sess.run(init_op, feed_dict={count: 20, window_size: 17, stride: 9})
|
||||
|
||||
num_batches = (20 * 7 - 17) // 9 + 1
|
||||
for i in range(num_batches):
|
||||
result = sess.run(get_next)
|
||||
@ -86,6 +85,41 @@ class SlideDatasetTest(test.TestCase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# stride == window_size.
|
||||
sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 14})
|
||||
num_batches = 20 * 7 // 14
|
||||
for i in range(num_batches):
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range(14):
|
||||
self.assertAllEqual(component[(i*14 + j) % 7]**2,
|
||||
result_component[j])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# stride > window_size.
|
||||
sess.run(init_op, feed_dict={count: 20, window_size: 10, stride: 14})
|
||||
num_batches = 20 * 7 // 14
|
||||
for i in range(num_batches):
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range(10):
|
||||
self.assertAllEqual(component[(i*14 + j) % 7]**2,
|
||||
result_component[j])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
# Drop the last batch which is smaller than window_size.
|
||||
sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 19})
|
||||
num_batches = (20 * 7 - 7) // 19 # = 19 * 7 // 19
|
||||
for i in range(num_batches):
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range(14):
|
||||
self.assertAllEqual(component[(i*19 + j) % 7]**2,
|
||||
result_component[j])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Slide over a finite input, which is less than window_size,
|
||||
# should fail straight away.
|
||||
sess.run(init_op, feed_dict={count: 1, window_size: 10, stride: 4})
|
||||
@ -108,10 +142,6 @@ class SlideDatasetTest(test.TestCase):
|
||||
# Invalid stride should be an initialization time error.
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 0})
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 3})
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 5})
|
||||
|
||||
def assertSparseValuesEqual(self, a, b):
|
||||
self.assertAllEqual(a.indices, b.indices)
|
||||
|
@@ -86,7 +86,7 @@ def sliding_window_batch(window_size, stride=1):
      elements in the sliding window.
    stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
      steps moving the sliding window forward for one iteration. The default
-      is `1`. It must be in `[1, window_size)`.
+      is `1`. It must be positive.

  Returns:
    A `Dataset` transformation function, which can be passed to
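A short usage sketch of the relaxed contract, assuming the contrib export path `tf.contrib.data.sliding_window_batch`; the values are illustrative, and a stride larger than the window size is exactly what the old `[1, window_size)` restriction ruled out:

import tensorflow as tf

# Windows of 3 elements, advancing 5 elements per step (stride > window_size).
dataset = tf.data.Dataset.range(20).apply(
    tf.contrib.data.sliding_window_batch(window_size=3, stride=5))
# Yields [0 1 2], [5 6 7], [10 11 12], [15 16 17]; the trailing partial window is dropped.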
|
||||
|
@ -0,0 +1,909 @@
|
||||
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"name": "nmt_with_attention.ipynb",
|
||||
"version": "0.3.2",
|
||||
"views": {},
|
||||
"default_view": {},
|
||||
"provenance": [
|
||||
{
|
||||
"file_id": "1C4fpM7_7IL8ZzF7Gc5abywqQjeQNS2-U",
|
||||
"timestamp": 1527858391290
|
||||
},
|
||||
{
|
||||
"file_id": "1pExo6aUuw0S6MISFWoinfJv0Ftm9V4qv",
|
||||
"timestamp": 1527776041613
|
||||
}
|
||||
],
|
||||
"private_outputs": true,
|
||||
"collapsed_sections": [],
|
||||
"toc_visible": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"accelerator": "GPU"
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"metadata": {
|
||||
"id": "AOpGoE2T-YXS",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"##### Copyright 2018 The TensorFlow Authors.\n",
|
||||
"\n",
|
||||
"Licensed under the Apache License, Version 2.0 (the \"License\").\n",
|
||||
"\n",
|
||||
"# Neural Machine Translation with Attention\n",
|
||||
"\n",
|
||||
"<table align=\"left\"><td>\n",
|
||||
"<a target=\"_blank\" href=\"https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\">\n",
|
||||
" <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a> \n",
|
||||
"</td><td>\n",
|
||||
"<a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on Github</a></td></table>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "CiwtNgENbx2g",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"This notebook trains a sequence to sequence (seq2seq) model for Spanish to English translation using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). This is an advanced example that assumes some knowledge of sequence to sequence models.\n",
|
||||
"\n",
|
||||
"After training the model in this notebook, you will be able to input a Spanish sentence, such as *\"¿todavia estan en casa?\"*, and return the English translation: *\"are you still at home?\"*\n",
|
||||
"\n",
|
||||
"The translation quality is reasonable for a toy example, but the generated attention plot is perhaps more interesting. This shows which parts of the input sentence has the model's attention while translating:\n",
|
||||
"\n",
|
||||
"<img src=\"https://tensorflow.org/images/spanish-english.png\" alt=\"spanish-english attention plot\">\n",
|
||||
"\n",
|
||||
"Note: This example takes approximately 10 mintues to run on a single P100 GPU."
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "tnxXKDjq3jEL",
|
||||
"colab_type": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"from __future__ import absolute_import, division, print_function\n",
|
||||
"\n",
|
||||
"# Import TensorFlow >= 1.9 and enable eager execution\n",
|
||||
"import tensorflow as tf\n",
|
||||
"\n",
|
||||
"tf.enable_eager_execution()\n",
|
||||
"\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"\n",
|
||||
"import unicodedata\n",
|
||||
"import re\n",
|
||||
"import numpy as np\n",
|
||||
"import os\n",
|
||||
"import time\n",
|
||||
"\n",
|
||||
"print(tf.__version__)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "wfodePkj3jEa",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Download and prepare the dataset\n",
|
||||
"\n",
|
||||
"We'll use a language dataset provided by http://www.manythings.org/anki/. This dataset contains language translation pairs in the format:\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"May I borrow this book?\t¿Puedo tomar prestado este libro?\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"There are a variety of languages available, but we'll use the English-Spanish dataset. For convenience, we've hosted a copy of this dataset on Google Cloud, but you can also download your own copy. After downloading the dataset, here are the steps we'll take to prepare the data:\n",
|
||||
"\n",
|
||||
"1. Add a *start* and *end* token to each sentence.\n",
|
||||
"2. Clean the sentences by removing special characters.\n",
|
||||
"3. Create a word index and reverse word index (dictionaries mapping from word → id and id → word).\n",
|
||||
"4. Pad each sentence to a maximum length."
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "kRVATYOgJs1b",
|
||||
"colab_type": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Download the file\n",
|
||||
"path_to_zip = tf.keras.utils.get_file(\n",
|
||||
" 'spa-eng.zip', origin='http://download.tensorflow.org/data/spa-eng.zip', \n",
|
||||
" extract=True)\n",
|
||||
"\n",
|
||||
"path_to_file = os.path.dirname(path_to_zip)+\"/spa-eng/spa.txt\""
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "rd0jw-eC3jEh",
|
||||
"colab_type": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Converts the unicode file to ascii\n",
|
||||
"def unicode_to_ascii(s):\n",
|
||||
" return ''.join(c for c in unicodedata.normalize('NFD', s)\n",
|
||||
" if unicodedata.category(c) != 'Mn')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def preprocess_sentence(w):\n",
|
||||
" w = unicode_to_ascii(w.lower().strip())\n",
|
||||
" \n",
|
||||
" # creating a space between a word and the punctuation following it\n",
|
||||
" # eg: \"he is a boy.\" => \"he is a boy .\" \n",
|
||||
" # Reference:- https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation\n",
|
||||
" w = re.sub(r\"([?.!,¿])\", r\" \\1 \", w)\n",
|
||||
" w = re.sub(r'[\" \"]+', \" \", w)\n",
|
||||
" \n",
|
||||
" # replacing everything with space except (a-z, A-Z, \".\", \"?\", \"!\", \",\")\n",
|
||||
" w = re.sub(r\"[^a-zA-Z?.!,¿]+\", \" \", w)\n",
|
||||
" \n",
|
||||
" w = w.rstrip().strip()\n",
|
||||
" \n",
|
||||
" # adding a start and an end token to the sentence\n",
|
||||
" # so that the model know when to start and stop predicting.\n",
|
||||
" w = '<start> ' + w + ' <end>'\n",
|
||||
" return w"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "OHn4Dct23jEm",
|
||||
"colab_type": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# 1. Remove the accents\n",
|
||||
"# 2. Clean the sentences\n",
|
||||
"# 3. Return word pairs in the format: [ENGLISH, SPANISH]\n",
|
||||
"def create_dataset(path, num_examples):\n",
|
||||
" lines = open(path, encoding='UTF-8').read().strip().split('\\n')\n",
|
||||
" \n",
|
||||
" word_pairs = [[preprocess_sentence(w) for w in l.split('\\t')] for l in lines[:num_examples]]\n",
|
||||
" \n",
|
||||
" return word_pairs"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "9xbqO7Iie9bb",
|
||||
"colab_type": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# This class creates a word -> index mapping (e.g,. \"dad\" -> 5) and vice-versa \n",
|
||||
"# (e.g., 5 -> \"dad\") for each language,\n",
|
||||
"class LanguageIndex():\n",
|
||||
" def __init__(self, lang):\n",
|
||||
" self.lang = lang\n",
|
||||
" self.word2idx = {}\n",
|
||||
" self.idx2word = {}\n",
|
||||
" self.vocab = set()\n",
|
||||
" \n",
|
||||
" self.create_index()\n",
|
||||
" \n",
|
||||
" def create_index(self):\n",
|
||||
" for phrase in self.lang:\n",
|
||||
" self.vocab.update(phrase.split(' '))\n",
|
||||
" \n",
|
||||
" self.vocab = sorted(self.vocab)\n",
|
||||
" \n",
|
||||
" self.word2idx['<pad>'] = 0\n",
|
||||
" for index, word in enumerate(self.vocab):\n",
|
||||
" self.word2idx[word] = index + 1\n",
|
||||
" \n",
|
||||
" for word, index in self.word2idx.items():\n",
|
||||
" self.idx2word[index] = word"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "eAY9k49G3jE_",
|
||||
"colab_type": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"def max_length(tensor):\n",
|
||||
" return max(len(t) for t in tensor)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def load_dataset(path, num_examples):\n",
|
||||
" # creating cleaned input, output pairs\n",
|
||||
" pairs = create_dataset(path, num_examples)\n",
|
||||
"\n",
|
||||
" # index language using the class defined above \n",
|
||||
" inp_lang = LanguageIndex(sp for en, sp in pairs)\n",
|
||||
" targ_lang = LanguageIndex(en for en, sp in pairs)\n",
|
||||
" \n",
|
||||
" # Vectorize the input and target languages\n",
|
||||
" \n",
|
||||
" # Spanish sentences\n",
|
||||
" input_tensor = [[inp_lang.word2idx[s] for s in sp.split(' ')] for en, sp in pairs]\n",
|
||||
" \n",
|
||||
" # English sentences\n",
|
||||
" target_tensor = [[targ_lang.word2idx[s] for s in en.split(' ')] for en, sp in pairs]\n",
|
||||
" \n",
|
||||
" # Calculate max_length of input and output tensor\n",
|
||||
" # Here, we'll set those to the longest sentence in the dataset\n",
|
||||
" max_length_inp, max_length_tar = max_length(input_tensor), max_length(target_tensor)\n",
|
||||
" \n",
|
||||
" # Padding the input and output tensor to the maximum length\n",
|
||||
" input_tensor = tf.keras.preprocessing.sequence.pad_sequences(input_tensor, \n",
|
||||
" maxlen=max_length_inp,\n",
|
||||
" padding='post')\n",
|
||||
" \n",
|
||||
" target_tensor = tf.keras.preprocessing.sequence.pad_sequences(target_tensor, \n",
|
||||
" maxlen=max_length_tar, \n",
|
||||
" padding='post')\n",
|
||||
" \n",
|
||||
" return input_tensor, target_tensor, inp_lang, targ_lang, max_length_inp, max_length_tar"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "GOi42V79Ydlr",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"### Limit the size of the dataset to experiment faster (optional)\n",
|
||||
"\n",
|
||||
"Training on the complete dataset of >100,000 sentences will take a long time. To train faster, we can limit the size of the dataset to 30,000 sentences (of course, translation quality degrades with less data):"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "cnxC7q-j3jFD",
|
||||
"colab_type": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Try experimenting with the size of that dataset\n",
|
||||
"num_examples = 30000\n",
|
||||
"input_tensor, target_tensor, inp_lang, targ_lang, max_length_inp, max_length_targ = load_dataset(path_to_file, num_examples)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "4QILQkOs3jFG",
|
||||
"colab_type": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Creating training and validation sets using an 80-20 split\n",
|
||||
"input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.2)\n",
|
||||
"\n",
|
||||
"# Show length\n",
|
||||
"len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "rgCLkfv5uO3d",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"### Create a tf.data dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "TqHsArVZ3jFS",
|
||||
"colab_type": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"BUFFER_SIZE = len(input_tensor_train)\n",
|
||||
"BATCH_SIZE = 64\n",
|
||||
"embedding_dim = 256\n",
|
||||
"units = 1024\n",
|
||||
"vocab_inp_size = len(inp_lang.word2idx)\n",
|
||||
"vocab_tar_size = len(targ_lang.word2idx)\n",
|
||||
"\n",
|
||||
"dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)\n",
|
||||
"dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(BATCH_SIZE))"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "TNfHIF71ulLu",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Write the encoder and decoder model\n",
|
||||
"\n",
|
||||
"Here, we'll implement an encoder-decoder model with attention which you can read about in the TensorFlow [Neural Machine Translation (seq2seq) tutorial](https://www.tensorflow.org/tutorials/seq2seq). This example uses a more recent set of APIs. This notebook implements the [attention equations](https://www.tensorflow.org/tutorials/seq2seq#background_on_the_attention_mechanism) from the seq2seq tutorial. The following diagram shows that each input words is assigned a weight by the attention mechanism which is then used by the decoder to predict the next word in the sentence.\n",
|
||||
"\n",
|
||||
"<img src=\"https://www.tensorflow.org/images/seq2seq/attention_mechanism.jpg\" width=\"500\" alt=\"attention mechanism\">\n",
|
||||
"\n",
|
||||
"The input is put through an encoder model which gives us the encoder output of shape *(batch_size, max_length, hidden_size)* and the encoder hidden state of shape *(batch_size, hidden_size)*. \n",
|
||||
"\n",
|
||||
"Here are the equations that are implemented:\n",
|
||||
"\n",
|
||||
"<img src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_0.jpg\" alt=\"attention equation 0\" width=\"800\">\n",
|
||||
"<img src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_1.jpg\" alt=\"attention equation 1\" width=\"800\">\n",
|
||||
"\n",
|
||||
"We're using *Bahdanau attention*. Lets decide on notation before writing the simplified form:\n",
|
||||
"\n",
|
||||
"* FC = Fully connected (dense) layer\n",
|
||||
"* EO = Encoder output\n",
|
||||
"* H = hidden state\n",
|
||||
"* X = input to the decoder\n",
|
||||
"\n",
|
||||
"And the pseudo-code:\n",
|
||||
"\n",
|
||||
"* `score = FC(tanh(FC(EO) + FC(H)))`\n",
|
||||
"* `attention weights = softmax(score, axis = 1)`. Softmax by default is applied on the last axis but here we want to apply it on the *1st axis*, since the shape of score is *(batch_size, max_length, hidden_size)*. `Max_length` is the length of our input. Since we are trying to assign a weight to each input, softmax should be applied on that axis.\n",
|
||||
"* `context vector = sum(attention weights * EO, axis = 1)`. Same reason as above for choosing axis as 1.\n",
|
||||
"* `embedding output` = The input to the decoder X is passed through an embedding layer.\n",
|
||||
"* `merged vector = concat(embedding output, context vector)`\n",
|
||||
"* This merged vector is then given to the GRU\n",
|
||||
" \n",
|
||||
"The shapes of all the vectors at each step have been specified in the comments in the code:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "avyJ_4VIUoHb",
|
||||
"colab_type": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"def gru(units):\n",
|
||||
" # If you have a GPU, we recommend using CuDNNGRU(provides a 3x speedup than GRU)\n",
|
||||
" # the code automatically does that.\n",
|
||||
" if tf.test.is_gpu_available():\n",
|
||||
" return tf.keras.layers.CuDNNGRU(units, \n",
|
||||
" return_sequences=True, \n",
|
||||
" return_state=True, \n",
|
||||
" recurrent_initializer='glorot_uniform')\n",
|
||||
" else:\n",
|
||||
" return tf.keras.layers.GRU(units, \n",
|
||||
" return_sequences=True, \n",
|
||||
" return_state=True, \n",
|
||||
" recurrent_activation='sigmoid', \n",
|
||||
" recurrent_initializer='glorot_uniform')"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "nZ2rI24i3jFg",
|
||||
"colab_type": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"class Encoder(tf.keras.Model):\n",
|
||||
" def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):\n",
|
||||
" super(Encoder, self).__init__()\n",
|
||||
" self.batch_sz = batch_sz\n",
|
||||
" self.enc_units = enc_units\n",
|
||||
" self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
|
||||
" self.gru = gru(self.enc_units)\n",
|
||||
" \n",
|
||||
" def call(self, x, hidden):\n",
|
||||
" x = self.embedding(x)\n",
|
||||
" output, state = self.gru(x, initial_state = hidden) \n",
|
||||
" return output, state\n",
|
||||
" \n",
|
||||
" def initialize_hidden_state(self):\n",
|
||||
" return tf.zeros((self.batch_sz, self.enc_units))"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "yJ_B3mhW3jFk",
|
||||
"colab_type": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"class Decoder(tf.keras.Model):\n",
|
||||
" def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):\n",
|
||||
" super(Decoder, self).__init__()\n",
|
||||
" self.batch_sz = batch_sz\n",
|
||||
" self.dec_units = dec_units\n",
|
||||
" self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
|
||||
" self.gru = gru(self.dec_units)\n",
|
||||
" self.fc = tf.keras.layers.Dense(vocab_size)\n",
|
||||
" \n",
|
||||
" # used for attention\n",
|
||||
" self.W1 = tf.keras.layers.Dense(self.dec_units)\n",
|
||||
" self.W2 = tf.keras.layers.Dense(self.dec_units)\n",
|
||||
" self.V = tf.keras.layers.Dense(1)\n",
|
||||
" \n",
|
||||
" def call(self, x, hidden, enc_output):\n",
|
||||
" # enc_output shape == (batch_size, max_length, hidden_size)\n",
|
||||
" \n",
|
||||
" # hidden shape == (batch_size, hidden size)\n",
|
||||
" # hidden_with_time_axis shape == (batch_size, 1, hidden size)\n",
|
||||
" # we are doing this to perform addition to calculate the score\n",
|
||||
" hidden_with_time_axis = tf.expand_dims(hidden, 1)\n",
|
||||
" \n",
|
||||
" # score shape == (batch_size, max_length, hidden_size)\n",
|
||||
" score = tf.nn.tanh(self.W1(enc_output) + self.W2(hidden_with_time_axis))\n",
|
||||
" \n",
|
||||
" # attention_weights shape == (batch_size, max_length, 1)\n",
|
||||
" # we get 1 at the last axis because we are applying score to self.V\n",
|
||||
" attention_weights = tf.nn.softmax(self.V(score), axis=1)\n",
|
||||
" \n",
|
||||
" # context_vector shape after sum == (batch_size, hidden_size)\n",
|
||||
" context_vector = attention_weights * enc_output\n",
|
||||
" context_vector = tf.reduce_sum(context_vector, axis=1)\n",
|
||||
" \n",
|
||||
" # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n",
|
||||
" x = self.embedding(x)\n",
|
||||
" \n",
|
||||
" # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n",
|
||||
" x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n",
|
||||
" \n",
|
||||
" # passing the concatenated vector to the GRU\n",
|
||||
" output, state = self.gru(x)\n",
|
||||
" \n",
|
||||
" # output shape == (batch_size * max_length, hidden_size)\n",
|
||||
" output = tf.reshape(output, (-1, output.shape[2]))\n",
|
||||
" \n",
|
||||
" # output shape == (batch_size * max_length, vocab)\n",
|
||||
" x = self.fc(output)\n",
|
||||
" \n",
|
||||
" return x, state, attention_weights\n",
|
||||
" \n",
|
||||
" def initialize_hidden_state(self):\n",
|
||||
" return tf.zeros((self.batch_sz, self.dec_units))"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "P5UY8wko3jFp",
|
||||
"colab_type": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)\n",
|
||||
"decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "_ch_71VbIRfK",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Define the optimizer and the loss function"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "WmTHr5iV3jFr",
|
||||
"colab_type": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"optimizer = tf.train.AdamOptimizer()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def loss_function(real, pred):\n",
|
||||
" mask = 1 - np.equal(real, 0)\n",
|
||||
" loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n",
|
||||
" return tf.reduce_mean(loss_)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "hpObfY22IddU",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Training\n",
|
||||
"\n",
|
||||
"1. Pass the *input* through the *encoder* which return *encoder output* and the *encoder hidden state*.\n",
|
||||
"2. The encoder output, encoder hidden state and the decoder input (which is the *start token*) is passed to the decoder.\n",
|
||||
"3. The decoder returns the *predictions* and the *decoder hidden state*.\n",
|
||||
"4. The decoder hidden state is then passed back into the model and the predictions are used to calculate the loss.\n",
|
||||
"5. Use *teacher forcing* to decide the next input to the decoder.\n",
|
||||
"6. *Teacher forcing* is the technique where the *target word* is passed as the *next input* to the decoder.\n",
|
||||
"7. The final step is to calculate the gradients and apply it to the optimizer and backpropagate."
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "ddefjBMa3jF0",
|
||||
"colab_type": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"EPOCHS = 10\n",
|
||||
"\n",
|
||||
"for epoch in range(EPOCHS):\n",
|
||||
" start = time.time()\n",
|
||||
" \n",
|
||||
" hidden = encoder.initialize_hidden_state()\n",
|
||||
" total_loss = 0\n",
|
||||
" \n",
|
||||
" for (batch, (inp, targ)) in enumerate(dataset):\n",
|
||||
" loss = 0\n",
|
||||
" \n",
|
||||
" with tf.GradientTape() as tape:\n",
|
||||
" enc_output, enc_hidden = encoder(inp, hidden)\n",
|
||||
" \n",
|
||||
" dec_hidden = enc_hidden\n",
|
||||
" \n",
|
||||
" dec_input = tf.expand_dims([targ_lang.word2idx['<start>']] * BATCH_SIZE, 1) \n",
|
||||
" \n",
|
||||
" # Teacher forcing - feeding the target as the next input\n",
|
||||
" for t in range(1, targ.shape[1]):\n",
|
||||
" # passing enc_output to the decoder\n",
|
||||
" predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)\n",
|
||||
" \n",
|
||||
" loss += loss_function(targ[:, t], predictions)\n",
|
||||
" \n",
|
||||
" # using teacher forcing\n",
|
||||
" dec_input = tf.expand_dims(targ[:, t], 1)\n",
|
||||
" \n",
|
||||
" total_loss += (loss / int(targ.shape[1]))\n",
|
||||
" \n",
|
||||
" variables = encoder.variables + decoder.variables\n",
|
||||
" \n",
|
||||
" gradients = tape.gradient(loss, variables)\n",
|
||||
" \n",
|
||||
" optimizer.apply_gradients(zip(gradients, variables), tf.train.get_or_create_global_step())\n",
|
||||
"\n",
|
||||
" if batch % 100 == 0:\n",
|
||||
" print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,\n",
|
||||
" batch,\n",
|
||||
" loss.numpy() / int(targ.shape[1])))\n",
|
||||
" \n",
|
||||
" print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n",
|
||||
" total_loss/len(input_tensor)))\n",
|
||||
" print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "mU3Ce8M6I3rz",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Translate\n",
|
||||
"\n",
|
||||
"* The evaluate function is similar to the training loop, except we don't use *teacher forcing* here. The input to the decoder at each time step is its previous predictions along with the hidden state and the encoder output.\n",
|
||||
"* Stop predicting when the model predicts the *end token*.\n",
|
||||
"* And store the *attention weights for every time step*.\n",
|
||||
"\n",
|
||||
"Note: The encoder output is calculated only once for one input."
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "EbQpyYs13jF_",
|
||||
"colab_type": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"def evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):\n",
|
||||
" attention_plot = np.zeros((max_length_targ, max_length_inp))\n",
|
||||
" \n",
|
||||
" sentence = preprocess_sentence(sentence)\n",
|
||||
"\n",
|
||||
" inputs = [inp_lang.word2idx[i] for i in sentence.split(' ')]\n",
|
||||
" inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs], maxlen=max_length_inp, padding='post')\n",
|
||||
" inputs = tf.convert_to_tensor(inputs)\n",
|
||||
" \n",
|
||||
" result = ''\n",
|
||||
"\n",
|
||||
" hidden = [tf.zeros((1, units))]\n",
|
||||
" enc_out, enc_hidden = encoder(inputs, hidden)\n",
|
||||
"\n",
|
||||
" dec_hidden = enc_hidden\n",
|
||||
" dec_input = tf.expand_dims([targ_lang.word2idx['<start>']], 0)\n",
|
||||
"\n",
|
||||
" for t in range(max_length_targ):\n",
|
||||
" predictions, dec_hidden, attention_weights = decoder(dec_input, dec_hidden, enc_out)\n",
|
||||
" \n",
|
||||
" # storing the attention weigths to plot later on\n",
|
||||
" attention_weights = tf.reshape(attention_weights, (-1, ))\n",
|
||||
" attention_plot[t] = attention_weights.numpy()\n",
|
||||
"\n",
|
||||
" predicted_id = tf.multinomial(tf.exp(predictions), num_samples=1)[0][0].numpy()\n",
|
||||
"\n",
|
||||
" result += targ_lang.idx2word[predicted_id] + ' '\n",
|
||||
"\n",
|
||||
" if targ_lang.idx2word[predicted_id] == '<end>':\n",
|
||||
" return result, sentence, attention_plot\n",
|
||||
" \n",
|
||||
" # the predicted ID is fed back into the model\n",
|
||||
" dec_input = tf.expand_dims([predicted_id], 0)\n",
|
||||
"\n",
|
||||
" return result, sentence, attention_plot"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "s5hQWlbN3jGF",
|
||||
"colab_type": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# function for plotting the attention weights\n",
|
||||
"def plot_attention(attention, sentence, predicted_sentence):\n",
|
||||
" fig = plt.figure(figsize=(10,10))\n",
|
||||
" ax = fig.add_subplot(1, 1, 1)\n",
|
||||
" ax.matshow(attention, cmap='viridis')\n",
|
||||
" \n",
|
||||
" fontdict = {'fontsize': 14}\n",
|
||||
" \n",
|
||||
" ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)\n",
|
||||
" ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)\n",
|
||||
"\n",
|
||||
" plt.show()"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "sl9zUHzg3jGI",
|
||||
"colab_type": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"def translate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):\n",
|
||||
" result, sentence, attention_plot = evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)\n",
|
||||
" \n",
|
||||
" print('Input: {}'.format(sentence))\n",
|
||||
" print('Predicted translation: {}'.format(result))\n",
|
||||
" \n",
|
||||
" attention_plot = attention_plot[:len(result.split(' ')), :len(sentence.split(' '))]\n",
|
||||
" plot_attention(attention_plot, sentence.split(' '), result.split(' '))"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "WrAM0FDomq3E",
|
||||
"colab_type": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"translate('hace mucho frio aqui.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "zSx2iM36EZQZ",
|
||||
"colab_type": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"translate('esta es mi vida.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "A3LLCx3ZE0Ls",
|
||||
"colab_type": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"translate('¿todavia estan en casa?', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "DUQVLVqUE1YW",
|
||||
"colab_type": "code",
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# wrong translation\n",
|
||||
"translate('trata de averiguarlo.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)"
|
||||
],
|
||||
"execution_count": 0,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"id": "RTe5P5ioMJwN",
|
||||
"colab_type": "text"
|
||||
},
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Next steps\n",
|
||||
"\n",
|
||||
"* [Download a different dataset](http://www.manythings.org/anki/) to experiment with translations, for example, English to German, or English to French.\n",
|
||||
"* Experiment with training on a larger dataset, or using more epochs\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
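The notebook cell above defines `loss_function` with a padding mask. As a small, self-contained sketch (not part of the notebook; the batch values and vocabulary size are made up), this is how the mask keeps padded target positions out of the loss:

```python
import numpy as np
import tensorflow as tf

tf.enable_eager_execution()  # the notebook above runs eagerly


def loss_function(real, pred):
  # mask is 1 for real tokens and 0 for the padding id 0, so padded
  # positions contribute nothing to the summed loss
  mask = 1 - np.equal(real, 0)
  loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=real, logits=pred) * mask
  return tf.reduce_mean(loss_)


# One decoder time step for a toy batch of 3 sequences; the last example is
# already past its end, so its target id is the padding id 0.
real = np.array([4, 2, 0], dtype=np.int64)   # target ids at this step
pred = tf.random_normal((3, 8))              # (batch_size, vocab_size) logits
print(loss_function(real, pred))             # the padded example is masked out
```

Note that `tf.reduce_mean` still divides by the full batch size, so masked entries dilute the average rather than being dropped; the notebook accepts this simplification.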
|
@ -24,6 +24,7 @@ from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples
|
||||
from tensorflow.contrib.gan.python import train as tfgan_train
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
from tensorflow.python.estimator.canned import head
|
||||
from tensorflow.python.estimator.export import export_output
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import metrics as metrics_lib
|
||||
|
||||
@ -182,7 +183,10 @@ class GANHead(head._Head): # pylint: disable=protected-access
|
||||
if mode == model_fn_lib.ModeKeys.PREDICT:
|
||||
return model_fn_lib.EstimatorSpec(
|
||||
mode=model_fn_lib.ModeKeys.PREDICT,
|
||||
predictions=gan_model.generated_data)
|
||||
predictions=gan_model.generated_data,
|
||||
export_outputs={
|
||||
'predict': export_output.PredictOutput(gan_model.generated_data)
|
||||
})
|
||||
elif mode == model_fn_lib.ModeKeys.EVAL:
|
||||
gan_loss = self.create_loss(
|
||||
features=None, mode=mode, logits=gan_model, labels=None)
|
||||
|
@ -26,8 +26,11 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import signature_constants
|
||||
from tensorflow.python.training import training
|
||||
|
||||
_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
|
||||
|
||||
|
||||
def dummy_loss(gan_model, add_summaries=True): # pylint:disable=unused-argument
|
||||
return math_ops.reduce_sum(gan_model.discriminator_real_outputs -
|
||||
@ -71,13 +74,15 @@ class GANHeadTest(test.TestCase):
|
||||
return {}
|
||||
|
||||
def _test_modes_helper(self, mode):
|
||||
self.gan_head.create_estimator_spec(
|
||||
return self.gan_head.create_estimator_spec(
|
||||
features=None,
|
||||
mode=mode,
|
||||
logits=get_gan_model())
|
||||
|
||||
def test_modes_predict(self):
|
||||
self._test_modes_helper(model_fn_lib.ModeKeys.PREDICT)
|
||||
spec = self._test_modes_helper(model_fn_lib.ModeKeys.PREDICT)
|
||||
self.assertItemsEqual((_DEFAULT_SERVING_KEY, 'predict'),
|
||||
spec.export_outputs.keys())
|
||||
|
||||
def test_modes_eval(self):
|
||||
self._test_modes_helper(model_fn_lib.ModeKeys.EVAL)
|
||||
|
@ -57,7 +57,7 @@ Status GdrServer::Init() {
|
||||
new GdrWorker(env, remote_memory_manager_.get()));
|
||||
};
|
||||
TF_RETURN_IF_ERROR(
|
||||
GrpcServer::Init(nullptr, rendezvous_mgr_func, worker_func));
|
||||
GrpcServer::Init(nullptr, rendezvous_mgr_func, nullptr, worker_func));
|
||||
|
||||
return remote_memory_manager_->Init();
|
||||
}
|
||||
|
@ -26,7 +26,7 @@ namespace optimized_ops {
|
||||
// Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on
|
||||
// Jetson TX-2. This compiler does not support the offsetof() macro.
|
||||
#if defined(__aarch64__) && !defined(GOOGLE_L4T)
|
||||
|
||||
#include <stddef.h>
|
||||
// clang-format gets confused with this file and ends up formatting lines to
|
||||
// be larger than 80 characters. Turn off here and back on at the end of the
|
||||
// file.
|
||||
|
@ -19,7 +19,9 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// Place `<locale>` before <Python.h> to avoid build failures in macOS.
|
||||
#include <Python.h>
|
||||
#include <locale>
|
||||
|
||||
// We forward declare TFLite classes here to avoid exposing them to SWIG.
|
||||
namespace tflite {
|
||||
|
@ -29,6 +29,7 @@ py_library(
|
||||
"python/training/reg_adagrad_optimizer.py",
|
||||
"python/training/sign_decay.py",
|
||||
"python/training/variable_clipping_optimizer.py",
|
||||
"python/training/weight_decay_optimizers.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
@ -198,6 +199,25 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "weight_decay_optimizers_test",
|
||||
srcs = ["python/training/weight_decay_optimizers_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":opt_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python:session",
|
||||
"//tensorflow/python:variables",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "drop_stale_gradient_optimizer_test",
|
||||
srcs = ["python/training/drop_stale_gradient_optimizer_test.py"],
|
||||
|
@ -22,16 +22,17 @@ from __future__ import print_function
|
||||
from tensorflow.contrib.opt.python.training.adamax import *
|
||||
from tensorflow.contrib.opt.python.training.addsign import *
|
||||
from tensorflow.contrib.opt.python.training.drop_stale_gradient_optimizer import *
|
||||
from tensorflow.contrib.opt.python.training.elastic_average_optimizer import *
|
||||
from tensorflow.contrib.opt.python.training.external_optimizer import *
|
||||
from tensorflow.contrib.opt.python.training.ggt import *
|
||||
from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import *
|
||||
from tensorflow.contrib.opt.python.training.model_average_optimizer import *
|
||||
from tensorflow.contrib.opt.python.training.moving_average_optimizer import *
|
||||
from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import *
|
||||
from tensorflow.contrib.opt.python.training.nadam_optimizer import *
|
||||
from tensorflow.contrib.opt.python.training.powersign import *
|
||||
from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import *
|
||||
from tensorflow.contrib.opt.python.training.elastic_average_optimizer import *
|
||||
from tensorflow.contrib.opt.python.training.model_average_optimizer import *
|
||||
from tensorflow.contrib.opt.python.training.ggt import *
|
||||
from tensorflow.contrib.opt.python.training.weight_decay_optimizers import *
|
||||
# pylint: enable=wildcard-import
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
@ -47,6 +48,10 @@ _allowed_symbols = [
|
||||
'LazyAdamOptimizer',
|
||||
'NadamOptimizer',
|
||||
'MovingAverageOptimizer',
|
||||
'MomentumWOptimizer',
|
||||
'AdamWOptimizer',
|
||||
'DecoupledWeightDecayExtension',
|
||||
'extend_with_decoupled_weight_decay',
|
||||
'ScipyOptimizerInterface',
|
||||
'VariableClippingOptimizer',
|
||||
'MultitaskOptimizerWrapper',
|
||||
|
@ -0,0 +1,362 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Base class to make optimizers weight decay ready."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.training import adam
|
||||
from tensorflow.python.training import momentum as momentum_opt
|
||||
from tensorflow.python.training import optimizer
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
class DecoupledWeightDecayExtension(object):
|
||||
"""This class allows to extend optimizers with decoupled weight decay.
|
||||
|
||||
It implements the decoupled weight decay described by Loshchilov & Hutter
|
||||
(https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is
|
||||
decoupled from the optimization steps w.r.t. the loss function.
|
||||
For SGD variants, this simplifies hyperparameter search since it decouples
|
||||
the settings of weight decay and learning rate.
|
||||
For adaptive gradient algorithms, it regularizes variables with large
|
||||
gradients more than L2 regularization would, which was shown to yield better
|
||||
training loss and generalization error in the paper above.
|
||||
|
||||
This class alone is not an optimizer but rather extends existing
|
||||
optimizers with decoupled weight decay. We explicitly define the two examples
|
||||
used in the above paper (SGDW and AdamW), but in general this can extend
|
||||
any OptimizerX by using
|
||||
`extend_with_weight_decay(OptimizerX, weight_decay=weight_decay)`.
|
||||
In order for it to work, it must be the first class the Optimizer with
|
||||
weight decay inherits from, e.g.
|
||||
|
||||
```python
|
||||
class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer):
|
||||
def __init__(self, weight_decay, *args, **kwargs):
|
||||
super(AdamWOptimizer, self).__init__(weight_decay, *args, **kwargs)
|
||||
```
|
||||
|
||||
Note that this extension decays weights BEFORE applying the update based
|
||||
on the gradient, i.e. this extension only has the desired behaviour for
|
||||
optimizers which do not depend on the value of 'var' in the update step!
|
||||
"""
|
||||
|
||||
def __init__(self, weight_decay, **kwargs):
|
||||
"""Construct the extension class that adds weight decay to an optimizer.
|
||||
|
||||
Args:
|
||||
weight_decay: A `Tensor` or a floating point value, the factor by which
|
||||
a variable is decayed in the update step.
|
||||
**kwargs: Optional keyword arguments. These are passed on to the
constructor of the base optimizer class.
|
||||
"""
|
||||
self._decay_var_list = None # is set in minimize or apply_gradients
|
||||
self._weight_decay = weight_decay
|
||||
# The tensors are initialized in call to _prepare
|
||||
self._weight_decay_tensor = None
|
||||
super(DecoupledWeightDecayExtension, self).__init__(**kwargs)
|
||||
|
||||
def minimize(self, loss, global_step=None, var_list=None,
|
||||
gate_gradients=optimizer.Optimizer.GATE_OP,
|
||||
aggregation_method=None, colocate_gradients_with_ops=False,
|
||||
name=None, grad_loss=None, decay_var_list=None):
|
||||
"""Add operations to minimize `loss` by updating `var_list` with decay.
|
||||
|
||||
This function is the same as Optimizer.minimize except that it allows
specifying the variables that should be decayed using decay_var_list.
|
||||
If decay_var_list is None, all variables in var_list are decayed.
|
||||
|
||||
For more information see the documentation of Optimizer.minimize.
|
||||
|
||||
Args:
|
||||
loss: A `Tensor` containing the value to minimize.
|
||||
global_step: Optional `Variable` to increment by one after the
|
||||
variables have been updated.
|
||||
var_list: Optional list or tuple of `Variable` objects to update to
|
||||
minimize `loss`. Defaults to the list of variables collected in
|
||||
the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
|
||||
gate_gradients: How to gate the computation of gradients. Can be
|
||||
`GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
|
||||
aggregation_method: Specifies the method used to combine gradient terms.
|
||||
Valid values are defined in the class `AggregationMethod`.
|
||||
colocate_gradients_with_ops: If True, try colocating gradients with
|
||||
the corresponding op.
|
||||
name: Optional name for the returned operation.
|
||||
grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
|
||||
decay_var_list: Optional list of decay variables.
|
||||
|
||||
Returns:
|
||||
An Operation that updates the variables in `var_list`. If `global_step`
|
||||
was not `None`, that operation also increments `global_step`.
|
||||
|
||||
"""
|
||||
self._decay_var_list = set(decay_var_list) if decay_var_list else False
|
||||
return super(DecoupledWeightDecayExtension, self).minimize(
|
||||
loss, global_step=global_step, var_list=var_list,
|
||||
gate_gradients=gate_gradients, aggregation_method=aggregation_method,
|
||||
colocate_gradients_with_ops=colocate_gradients_with_ops, name=name,
|
||||
grad_loss=grad_loss)
|
||||
|
||||
def apply_gradients(self, grads_and_vars, global_step=None, name=None,
|
||||
decay_var_list=None):
|
||||
"""Apply gradients to variables and decay the variables.
|
||||
|
||||
This function is the same as Optimizer.apply_gradients except that it
|
||||
allows specifying the variables that should be decayed using
|
||||
decay_var_list. If decay_var_list is None, all variables in var_list
|
||||
are decayed.
|
||||
|
||||
For more information see the documentation of Optimizer.apply_gradients.
|
||||
|
||||
Args:
|
||||
grads_and_vars: List of (gradient, variable) pairs as returned by
|
||||
`compute_gradients()`.
|
||||
global_step: Optional `Variable` to increment by one after the
|
||||
variables have been updated.
|
||||
name: Optional name for the returned operation. Defaults to the
|
||||
name passed to the `Optimizer` constructor.
|
||||
decay_var_list: Optional list of decay variables.
|
||||
|
||||
Returns:
|
||||
An `Operation` that applies the specified gradients. If `global_step`
|
||||
was not None, that operation also increments `global_step`.
|
||||
"""
|
||||
self._decay_var_list = set(decay_var_list) if decay_var_list else False
|
||||
return super(DecoupledWeightDecayExtension, self).apply_gradients(
|
||||
grads_and_vars, global_step=global_step, name=name)
|
||||
|
||||
def _prepare(self):
|
||||
weight_decay = self._weight_decay
|
||||
if callable(weight_decay):
|
||||
weight_decay = weight_decay()
|
||||
self._weight_decay_tensor = ops.convert_to_tensor(
|
||||
weight_decay, name="weight_decay")
|
||||
# Call the optimizer's _prepare function.
|
||||
super(DecoupledWeightDecayExtension, self)._prepare()
|
||||
|
||||
def _decay_weights_op(self, var):
|
||||
if not self._decay_var_list or var in self._decay_var_list:
|
||||
return var.assign_sub(self._weight_decay * var, self._use_locking)
|
||||
return control_flow_ops.no_op()
|
||||
|
||||
def _decay_weights_sparse_op(self, var, indices, scatter_add):
|
||||
if not self._decay_var_list or var in self._decay_var_list:
|
||||
return scatter_add(var, indices, -self._weight_decay * var,
|
||||
self._use_locking)
|
||||
return control_flow_ops.no_op()
|
||||
|
||||
# Here, we overwrite the apply functions that the base optimizer calls.
|
||||
# super().apply_x resolves to the apply_x function of the BaseOptimizer.
|
||||
def _apply_dense(self, grad, var):
|
||||
with ops.control_dependencies([self._decay_weights_op(var)]):
|
||||
return super(DecoupledWeightDecayExtension, self)._apply_dense(grad, var)
|
||||
|
||||
def _resource_apply_dense(self, grad, var):
|
||||
with ops.control_dependencies([self._decay_weights_op(var)]):
|
||||
return super(DecoupledWeightDecayExtension, self)._resource_apply_dense(
|
||||
grad, var)
|
||||
|
||||
def _apply_sparse(self, grad, var):
|
||||
scatter_add = state_ops.scatter_add
|
||||
decay_op = self._decay_weights_sparse_op(var, grad.indices, scatter_add)
|
||||
with ops.control_dependencies([decay_op]):
|
||||
return super(DecoupledWeightDecayExtension, self)._apply_sparse(
|
||||
grad, var)
|
||||
|
||||
def _resource_scatter_add(self, x, i, v, _=None):
|
||||
# last argument allows for one overflow argument, to have the same function
|
||||
# signature as state_ops.scatter_add
|
||||
with ops.control_dependencies(
|
||||
[resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
|
||||
return x.value()
|
||||
|
||||
def _resource_apply_sparse(self, grad, var, indices):
|
||||
scatter_add = self._resource_scatter_add
|
||||
decay_op = self._decay_weights_sparse_op(var, indices, scatter_add)
|
||||
with ops.control_dependencies([decay_op]):
|
||||
return super(DecoupledWeightDecayExtension, self)._resource_apply_sparse(
|
||||
grad, var, indices)
|
||||
|
||||
|
||||
def extend_with_decoupled_weight_decay(base_optimizer):
|
||||
"""Factory function returning an optimizer class with decoupled weight decay.
|
||||
|
||||
Returns an optimizer class. An instance of the returned class computes the
|
||||
update step of `base_optimizer` and additionally decays the weights.
|
||||
E.g., the class returned by
|
||||
`extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)` is equivalent to
|
||||
`tf.contrib.opt.AdamWOptimizer`.
|
||||
|
||||
The API of the new optimizer class slightly differs from the API of the
|
||||
base optimizer:
|
||||
- The first argument to the constructor is the weight decay rate.
|
||||
- `minimize` and `apply_gradients` accept the optional keyword argument
|
||||
`decay_var_list`, which specifies the variables that should be decayed.
|
||||
If `None`, all variables that are optimized are decayed.
|
||||
|
||||
Usage example:
|
||||
```python
|
||||
# MyAdamW is a new class
|
||||
MyAdamW = extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)
|
||||
# Create a MyAdamW object
|
||||
optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001)
|
||||
sess.run(optimizer.minimize(loss, decay_var_list=[var1, var2]))
|
||||
|
||||
Note that this extension decays weights BEFORE applying the update based
|
||||
on the gradient, i.e. this extension only has the desired behaviour for
|
||||
optimizers which do not depend on the value of 'var' in the update step!
|
||||
```
|
||||
|
||||
Args:
|
||||
base_optimizer: An optimizer class that inherits from tf.train.Optimizer.
|
||||
|
||||
Returns:
|
||||
A new optimizer class that inherits from DecoupledWeightDecayExtension
|
||||
and base_optimizer.
|
||||
"""
|
||||
|
||||
class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension,
|
||||
base_optimizer):
|
||||
"""Base_optimizer with decoupled weight decay.
|
||||
|
||||
This class computes the update step of `base_optimizer` and
|
||||
additionally decays the variable with the weight decay being decoupled from
|
||||
the optimization steps w.r.t. the loss function, as described by
|
||||
Loshchilov & Hutter (https://arxiv.org/pdf/1711.05101.pdf).
|
||||
For SGD variants, this simplifies hyperparameter search since
|
||||
it decouples the settings of weight decay and learning rate.
|
||||
For adaptive gradient algorithms, it regularizes variables with large
|
||||
gradients more than L2 regularization would, which was shown to yield
|
||||
better training loss and generalization error in the paper above.
|
||||
"""
|
||||
|
||||
def __init__(self, weight_decay, *args, **kwargs):
|
||||
# super delegation is necessary here
|
||||
# pylint: disable=useless-super-delegation
|
||||
super(OptimizerWithDecoupledWeightDecay, self).__init__(
|
||||
weight_decay, *args, **kwargs)
|
||||
# pylint: enable=useless-super-delegation
|
||||
|
||||
return OptimizerWithDecoupledWeightDecay
|
||||
|
||||
|
||||
@tf_export("contrib.opt.MomentumWOptimizer")
|
||||
class MomentumWOptimizer(DecoupledWeightDecayExtension,
|
||||
momentum_opt.MomentumOptimizer):
|
||||
"""Optimizer that implements the Momentum algorithm with weight_decay.
|
||||
|
||||
This is an implementation of the SGDW optimizer described in "Fixing
|
||||
Weight Decay Regularization in Adam" by Loshchilov & Hutter
|
||||
(https://arxiv.org/abs/1711.05101)
|
||||
([pdf](https://arxiv.org/pdf/1711.05101.pdf)).
|
||||
It computes the update step of `train.MomentumOptimizer` and additionally
|
||||
decays the variable. Note that this is different from adding
|
||||
L2 regularization on the variables to the loss. Decoupling the weight decay
|
||||
from other hyperparameters (in particular the learning rate) simplifies
|
||||
hyperparameter search.
|
||||
|
||||
For further information see the documentation of the Momentum Optimizer.
|
||||
|
||||
Note that this optimizer can also be instantiated as
|
||||
```python
|
||||
extend_with_decoupled_weight_decay(tf.train.MomentumOptimizer)(
weight_decay=weight_decay, learning_rate=learning_rate, momentum=momentum)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, weight_decay, learning_rate, momentum,
|
||||
use_locking=False, name="MomentumW", use_nesterov=False):
|
||||
"""Construct a new MomentumW optimizer.
|
||||
|
||||
For further information see the documentation of the Momentum Optimizer.
|
||||
|
||||
Args:
|
||||
weight_decay: A `Tensor` or a floating point value. The weight decay.
|
||||
learning_rate: A `Tensor` or a floating point value. The learning rate.
|
||||
momentum: A `Tensor` or a floating point value. The momentum.
|
||||
use_locking: If `True` use locks for update operations.
|
||||
name: Optional name prefix for the operations created when applying
|
||||
gradients. Defaults to "Momentum".
|
||||
use_nesterov: If `True` use Nesterov Momentum.
|
||||
See [Sutskever et al., 2013](
|
||||
http://jmlr.org/proceedings/papers/v28/sutskever13.pdf).
|
||||
This implementation always computes gradients at the value of the
|
||||
variable(s) passed to the optimizer. Using Nesterov Momentum makes the
|
||||
variable(s) track the values called `theta_t + mu*v_t` in the paper.
|
||||
|
||||
@compatibility(eager)
|
||||
When eager execution is enabled, learning_rate, weight_decay and momentum
|
||||
can each be a callable that takes no arguments and returns the actual value
|
||||
to use. This can be useful for changing these values across different
|
||||
invocations of optimizer functions.
|
||||
@end_compatibility
|
||||
"""
|
||||
super(MomentumWOptimizer, self).__init__(
|
||||
weight_decay, learning_rate=learning_rate, momentum=momentum,
|
||||
use_locking=use_locking, name=name, use_nesterov=use_nesterov)
|
||||
|
||||
|
||||
@tf_export("contrib.opt.AdamWOptimizer")
|
||||
class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer):
|
||||
"""Optimizer that implements the Adam algorithm with weight decay.
|
||||
|
||||
This is an implementation of the AdamW optimizer described in "Fixing
|
||||
Weight Decay Regularization in Adam" by Loshchilov & Hutter
|
||||
(https://arxiv.org/abs/1711.05101)
|
||||
([pdf](https://arxiv.org/pdf/1711.05101.pdf)).
|
||||
|
||||
It computes the update step of `train.AdamOptimizer` and additionally decays
|
||||
the variable. Note that this is different from adding L2 regularization on
|
||||
the variables to the loss: it regularizes variables with large
|
||||
gradients more than L2 regularization would, which was shown to yield better
|
||||
training loss and generalization error in the paper above.
|
||||
|
||||
For further information see the documentation of the Adam Optimizer.
|
||||
|
||||
Note that this optimizer can also be instantiated as
|
||||
```python
|
||||
extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)(weight_decay=weight_decay)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, weight_decay, learning_rate=0.001, beta1=0.9, beta2=0.999,
|
||||
epsilon=1e-8, use_locking=False, name="AdamW"):
|
||||
"""Construct a new AdamW optimizer.
|
||||
|
||||
For further information see the documentation of the Adam Optimizer.
|
||||
|
||||
Args:
|
||||
weight_decay: A `Tensor` or a floating point value. The weight decay.
|
||||
learning_rate: A Tensor or a floating point value. The learning rate.
|
||||
beta1: A float value or a constant float tensor.
|
||||
The exponential decay rate for the 1st moment estimates.
|
||||
beta2: A float value or a constant float tensor.
|
||||
The exponential decay rate for the 2nd moment estimates.
|
||||
epsilon: A small constant for numerical stability. This epsilon is
|
||||
"epsilon hat" in the Kingma and Ba paper (in the formula just before
|
||||
Section 2.1), not the epsilon in Algorithm 1 of the paper.
|
||||
use_locking: If True use locks for update operations.
|
||||
name: Optional name for the operations created when applying gradients.
|
||||
Defaults to "Adam".
|
||||
"""
|
||||
super(AdamWOptimizer, self).__init__(
|
||||
weight_decay, learning_rate=learning_rate, beta1=beta1, beta2=beta2,
|
||||
epsilon=epsilon, use_locking=use_locking, name=name)
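The new `weight_decay_optimizers.py` above exports `AdamWOptimizer`, `MomentumWOptimizer`, and the `extend_with_decoupled_weight_decay` factory. Below is a minimal usage sketch (illustrative only, not part of this change) showing the `decay_var_list` argument that `DecoupledWeightDecayExtension` adds to `minimize`:

```python
import tensorflow as tf

# Toy problem: `w` is decayed, `b` is excluded via decay_var_list.
w = tf.Variable([2.0, -3.0], name="w")
b = tf.Variable([0.5], name="b")
loss = tf.reduce_sum(tf.square(w)) + tf.reduce_sum(tf.square(b))

# Decoupled update: w <- w - weight_decay * w, followed by the usual Adam step.
opt = tf.contrib.opt.AdamWOptimizer(weight_decay=0.01, learning_rate=0.001)
train_op = opt.minimize(loss, decay_var_list=[w])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(5):
    sess.run(train_op)
  print(sess.run([w, b]))  # only w has been shrunk by the extra decay term
```

Because the decay is applied to the variable directly rather than added to `loss`, it interacts with Adam's adaptive scaling differently from L2 regularization, which is the point of the paper cited in the docstrings.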
|
@ -0,0 +1,188 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for optimizers with weight decay."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.opt.python.training import weight_decay_optimizers
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import adam
|
||||
|
||||
WEIGHT_DECAY = 0.01
|
||||
|
||||
|
||||
def adamw_update_numpy(param, g_t, t, m, v, lr=0.001, beta1=0.9,
|
||||
beta2=0.999, epsilon=1e-8):
|
||||
lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t)
|
||||
|
||||
m_t = beta1 * m + (1 - beta1) * g_t
|
||||
v_t = beta2 * v + (1 - beta2) * g_t * g_t
|
||||
|
||||
param_t = (param - lr_t * m_t / (np.sqrt(v_t) + epsilon) -
|
||||
(param * WEIGHT_DECAY))
|
||||
return param_t, m_t, v_t
|
||||
|
||||
|
||||
def momentumw_update_numpy(param, g_t, m, lr=0.001, momentum=0.9, **_):
|
||||
# v, t are not needed for momentum optimizer
|
||||
m = momentum * m + g_t
|
||||
param_t = param - lr * m - param * WEIGHT_DECAY
|
||||
return param_t, m, None
|
||||
|
||||
|
||||
class WeightDecayOptimizerTest(test.TestCase):
|
||||
|
||||
def doTest(self, optimizer, update_fn, optimizer_name, slot_name,
|
||||
use_resource=False, do_sparse=False):
|
||||
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
|
||||
with self.test_session(graph=ops.Graph()):
|
||||
# Initialize variables for numpy implementation.
|
||||
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
|
||||
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
|
||||
grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
|
||||
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
|
||||
grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
|
||||
|
||||
if use_resource:
|
||||
var0 = resource_variable_ops.ResourceVariable(
|
||||
var0_np, name="var0_%d" % i)
|
||||
var1 = resource_variable_ops.ResourceVariable(
|
||||
var1_np, name="var1_%d" % i)
|
||||
else:
|
||||
var0 = variables.Variable(var0_np)
|
||||
var1 = variables.Variable(var1_np)
|
||||
|
||||
if do_sparse:
|
||||
grads0_np_indices = np.array([0, 1], dtype=np.int32)
|
||||
grads0 = ops.IndexedSlices(constant_op.constant(grads0_np),
|
||||
constant_op.constant(grads0_np_indices),
|
||||
constant_op.constant([2]))
|
||||
grads1_np_indices = np.array([0, 1], dtype=np.int32)
|
||||
grads1 = ops.IndexedSlices(constant_op.constant(grads1_np),
|
||||
constant_op.constant(grads1_np_indices),
|
||||
constant_op.constant([2]))
|
||||
else:
|
||||
grads0 = constant_op.constant(grads0_np)
|
||||
grads1 = constant_op.constant(grads1_np)
|
||||
|
||||
opt = optimizer()
|
||||
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
|
||||
|
||||
if not context.executing_eagerly():
|
||||
with ops.Graph().as_default():
|
||||
# Shouldn't return non-slot variables from other graphs.
|
||||
self.assertEqual(0, len(opt.variables()))
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
# Fetch params to validate initial values
|
||||
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
|
||||
self.assertAllClose([3.0, 4.0], self.evaluate(var1))
|
||||
|
||||
# Run 3 steps of the optimizer
|
||||
for t in range(1, 4):
|
||||
if not context.executing_eagerly():
|
||||
self.evaluate(update)
|
||||
elif t > 1:
|
||||
opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
|
||||
|
||||
var0_np, m0, v0 = update_fn(var0_np, grads0_np, t=t, m=m0, v=v0)
|
||||
var1_np, m1, v1 = update_fn(var1_np, grads1_np, t=t, m=m1, v=v1)
|
||||
|
||||
# Validate updated params
|
||||
self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
|
||||
self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
|
||||
if use_resource:
|
||||
self.assertEqual("var0_%d/%s:0" % (i, optimizer_name),
|
||||
opt.get_slot(var=var0, name=slot_name).name)
|
||||
|
||||
|
||||
class AdamWOptimizerTest(WeightDecayOptimizerTest):
|
||||
|
||||
@staticmethod
|
||||
def get_optimizer():
|
||||
return weight_decay_optimizers.AdamWOptimizer(WEIGHT_DECAY)
|
||||
|
||||
def testSparse(self):
|
||||
self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m",
|
||||
use_resource=False, do_sparse=True)
|
||||
|
||||
def testResourceSparse(self):
|
||||
self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m",
|
||||
use_resource=True, do_sparse=True)
|
||||
|
||||
def testBasic(self):
|
||||
self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m",
|
||||
use_resource=False)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes(reset_test=True)
|
||||
def testResourceBasic(self):
|
||||
self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m",
|
||||
use_resource=True)
|
||||
|
||||
|
||||
class MomentumWOptimizerTest(WeightDecayOptimizerTest):
|
||||
|
||||
@staticmethod
|
||||
def get_optimizer():
|
||||
return weight_decay_optimizers.MomentumWOptimizer(WEIGHT_DECAY, 0.001, 0.9)
|
||||
|
||||
def testSparse(self):
|
||||
self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW",
|
||||
"momentum", use_resource=False, do_sparse=True)
|
||||
|
||||
def testResourceSparse(self):
|
||||
self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW",
|
||||
"momentum", use_resource=True, do_sparse=True)
|
||||
|
||||
def testBasic(self):
|
||||
self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW",
|
||||
"momentum", use_resource=False)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes(reset_test=True)
|
||||
def testResourceBasic(self):
|
||||
self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW",
|
||||
"momentum", use_resource=True)
|
||||
|
||||
|
||||
class ExtendWithWeightDecayTest(WeightDecayOptimizerTest):
|
||||
|
||||
@staticmethod
|
||||
def get_optimizer():
|
||||
adamw = weight_decay_optimizers.extend_with_decoupled_weight_decay(
|
||||
adam.AdamOptimizer)
|
||||
return adamw(WEIGHT_DECAY)
|
||||
|
||||
def testBasic(self):
|
||||
self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m",
|
||||
use_resource=False)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes(reset_test=True)
|
||||
def testResourceBasic(self):
|
||||
self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m",
|
||||
use_resource=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
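As a quick sanity check of the reference rule in `momentumw_update_numpy` above, here is a standalone sketch (not part of the test file) of a single decoupled-momentum step computed by hand:

```python
import numpy as np

WEIGHT_DECAY = 0.01

def momentumw_update_numpy(param, g_t, m, lr=0.001, momentum=0.9, **_):
  # same reference rule as in the test: momentum step plus decoupled decay
  m = momentum * m + g_t
  param_t = param - lr * m - param * WEIGHT_DECAY
  return param_t, m, None

param = np.array([1.0, 2.0])
grad = np.array([0.1, 0.1])
m = np.zeros_like(param)

param, m, _ = momentumw_update_numpy(param, grad, m)
print(param)  # [0.9899 1.9799]: 1% decoupled shrinkage on top of -lr * m
```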
|
@ -28,7 +28,6 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
|
||||
|
||||
def conjugate_gradient(operator,
|
||||
|
@ -49,7 +49,6 @@ tf_cuda_cc_test(
|
||||
tf_custom_op_library(
|
||||
name = "python/ops/_trt_engine_op.so",
|
||||
srcs = [
|
||||
"ops/trt_calib_op.cc",
|
||||
"ops/trt_engine_op.cc",
|
||||
],
|
||||
deps = [
|
||||
@ -76,11 +75,9 @@ tf_cuda_library(
|
||||
cc_library(
|
||||
name = "trt_engine_op_kernel",
|
||||
srcs = [
|
||||
"kernels/trt_calib_op.cc",
|
||||
"kernels/trt_engine_op.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"kernels/trt_calib_op.h",
|
||||
"kernels/trt_engine_op.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
@ -89,20 +86,22 @@ cc_library(
|
||||
":trt_logging",
|
||||
":trt_plugins",
|
||||
":trt_resources",
|
||||
":trt_conversion",
|
||||
":utils",
|
||||
"//tensorflow/core:gpu_headers_lib",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
"//tensorflow/core:stream_executor_headers_lib",
|
||||
"//tensorflow/core/grappler/costs:graph_properties",
|
||||
] + if_tensorrt([
|
||||
"@local_config_tensorrt//:nv_infer",
|
||||
]) + tf_custom_op_library_additional_deps(),
|
||||
# TODO(laigd)
|
||||
# TODO(laigd): fix this by merging header file in cc file.
|
||||
alwayslink = 1, # buildozer: disable=alwayslink-with-hdrs
|
||||
)
|
||||
|
||||
tf_gen_op_libs(
|
||||
op_lib_names = [
|
||||
"trt_engine_op",
|
||||
"trt_calib_op",
|
||||
],
|
||||
)
|
||||
|
||||
@ -122,7 +121,6 @@ tf_gen_op_wrapper_py(
|
||||
name = "trt_engine_op",
|
||||
gen_locally = True,
|
||||
deps = [
|
||||
":trt_calib_op_op_lib",
|
||||
":trt_engine_op_op_lib",
|
||||
":trt_logging",
|
||||
":trt_shape_function",
|
||||
@ -140,7 +138,6 @@ tf_custom_op_py_library(
|
||||
kernels = [
|
||||
":trt_engine_op_kernel",
|
||||
":trt_engine_op_op_lib",
|
||||
":trt_calib_op_op_lib",
|
||||
":trt_shape_function",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
@ -191,7 +188,6 @@ tf_py_wrap_cc(
|
||||
deps = [
|
||||
":trt_conversion",
|
||||
":trt_engine_op_kernel",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//third_party/python_runtime:headers",
|
||||
],
|
||||
)
|
||||
@ -211,6 +207,7 @@ tf_cuda_library(
|
||||
],
|
||||
deps = [
|
||||
":trt_logging",
|
||||
":utils",
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
@ -237,12 +234,12 @@ tf_cuda_library(
|
||||
":trt_plugins",
|
||||
":trt_logging",
|
||||
":trt_resources",
|
||||
":utils",
|
||||
"//tensorflow/core/grappler/clusters:cluster",
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:gpu_runtime",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/core:graph",
|
||||
@ -343,3 +340,8 @@ py_test(
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "utils",
|
||||
hdrs = ["convert/utils.h"],
|
||||
)
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -30,29 +30,60 @@ namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
namespace convert {
|
||||
|
||||
// This method converts an already generated calibration graph which was used in
|
||||
// calibration runs to an inference graph
|
||||
tensorflow::Status ConvertCalibGraphToInferGraph(
|
||||
const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* new_graph_def);
|
||||
struct ConversionParams {
|
||||
ConversionParams()
|
||||
: input_graph_def(nullptr),
|
||||
max_batch_size(1),
|
||||
max_workspace_size_bytes(1 << 30),
|
||||
output_graph_def(nullptr),
|
||||
precision_mode(1),
|
||||
minimum_segment_size(3),
|
||||
graph_properties(nullptr),
|
||||
cluster(nullptr),
|
||||
is_dyn_op(false),
|
||||
fixed_input_size(true),
|
||||
max_cached_engines(1) {}
|
||||
const tensorflow::GraphDef* input_graph_def;
|
||||
const std::vector<string>* output_names;
|
||||
size_t max_batch_size;
|
||||
size_t max_workspace_size_bytes;
|
||||
tensorflow::GraphDef* output_graph_def;
|
||||
int precision_mode;
|
||||
int minimum_segment_size;
|
||||
const tensorflow::grappler::GraphProperties* graph_properties;
|
||||
const tensorflow::grappler::Cluster* cluster;
|
||||
bool is_dyn_op; // Whether to create engine on conversion or execution time
|
||||
bool fixed_input_size; // Assume non-batch ranks of input tensors are fixed
|
||||
int max_cached_engines; // maximum number of cached engines
|
||||
std::vector<int> cached_engine_batches; // list of cached engines
|
||||
};
|
||||
|
||||
// max_batch_size: maximum batch size which can be used for inference for
|
||||
// optimization targets inference run with max batch size.
|
||||
// max_workspace_size_bytes: The upper bound of memory allowance for
|
||||
// engine building.
|
||||
// This method extracts calibration information from the resource managers
|
||||
// and puts them into engine nodedefs.
|
||||
tensorflow::Status ConvertCalibGraphToInferGraph(
|
||||
const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* new_graph_def,
|
||||
bool is_dyn_op);
|
||||
|
||||
// - max_batch_size: maximum batch size which can be used for inference for
|
||||
// optimization targets inference run with max batch size.
|
||||
// - max_workspace_size_bytes: The upper bound of memory allowance for engine
|
||||
// building.
|
||||
tensorflow::Status ConvertGraphDefToTensorRT(
|
||||
const tensorflow::GraphDef& graph_def,
|
||||
const std::vector<string>& output_names, size_t max_batch_size,
|
||||
size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def,
|
||||
int precision_mode, int minimum_segment_size);
|
||||
int precision_mode = 1, int minimum_segment_size = 3,
|
||||
bool is_dyn_op = false, int max_cached_engines = 1,
|
||||
std::vector<int> cached_engine_batches = {});
|
||||
|
||||
// Method to call from optimization pass
|
||||
tensorflow::Status ConvertAfterShapes(
|
||||
const tensorflow::GraphDef& graph, const std::vector<string>& output_names,
|
||||
size_t max_batch_size, size_t max_workspace_size_bytes,
|
||||
tensorflow::GraphDef* new_graph_def, int precision_mode,
|
||||
int minimum_segment_size,
|
||||
const tensorflow::grappler::GraphProperties& graph_properties,
|
||||
const tensorflow::grappler::Cluster* cluster);
|
||||
tensorflow::Status ConvertAfterShapes(ConversionParams& params);
|
||||
|
||||
// Return compile time TensorRT library version information.
|
||||
std::vector<int> GetLinkedTensorRTVersion();
|
||||
|
||||
// Return runtime time TensorRT library version information.
|
||||
std::vector<int> GetLoadedTensorRTVersion();
|
||||
} // namespace convert
|
||||
} // namespace tensorrt
|
||||
} // namespace tensorflow
|
||||
|
@ -14,7 +14,6 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
|
||||
#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <list>
|
||||
@ -25,7 +24,9 @@ limitations under the License.
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/contrib/tensorrt/convert/utils.h"
|
||||
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
|
||||
#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
|
||||
#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
|
||||
#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h" // NOLINT
|
||||
@ -37,6 +38,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -54,8 +56,11 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
namespace convert {
|
||||
using ::tensorflow::str_util::Split;
|
||||
|
||||
using ::tensorflow::strings::StrAppend;
|
||||
using ::tensorflow::strings::StrCat;
|
||||
|
||||
namespace {
|
||||
|
||||
inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype,
|
||||
@ -121,12 +126,10 @@ static std::vector<std::pair<int, int>> CreateSamePadding(
|
||||
|
||||
string GetCommonNameScope(const string& op_name_a, const string& op_name_b) {
|
||||
size_t last_scope_separator = 0;
|
||||
for (size_t i = 0; i < std::min(op_name_a.size(), op_name_b.size()); ++i) {
|
||||
if (op_name_a[i] != op_name_b[i]) {
|
||||
break;
|
||||
} else if (op_name_a[i] == '/') {
|
||||
last_scope_separator = i + 1;
|
||||
}
|
||||
const size_t min_size = std::min(op_name_a.size(), op_name_b.size());
|
||||
for (size_t i = 0; i < min_size; ++i) {
|
||||
if (op_name_a[i] != op_name_b[i]) break;
|
||||
if (op_name_a[i] == '/') last_scope_separator = i + 1;
|
||||
}
|
||||
return op_name_a.substr(0, last_scope_separator);
|
||||
}
|
||||
@ -417,20 +420,6 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
|
||||
}
|
||||
}
|
||||
|
||||
struct InferDeleter {
|
||||
template <typename T>
|
||||
void operator()(T* obj) const {
|
||||
if (obj) {
|
||||
obj->destroy();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline std::shared_ptr<T> infer_object(T* obj) {
|
||||
return std::shared_ptr<T>(obj, InferDeleter());
|
||||
}
|
||||
|
||||
class Converter;
|
||||
|
||||
using OpConverter =
|
||||
@ -444,7 +433,7 @@ class Converter {
|
||||
OpConverter plugin_converter_;
|
||||
nvinfer1::INetworkDefinition* trt_network_;
|
||||
std::list<std::vector<uint8_t>> temp_bufs_;
|
||||
tensorflow::tensorrt::TRTWeightStore* weight_store_;
|
||||
TRTWeightStore* weight_store_;
|
||||
bool fp16_;
|
||||
void register_op_converters();
|
||||
tensorflow::Status get_inputs(const tensorflow::NodeDef& node_def,
|
||||
@ -486,11 +475,11 @@ class Converter {
|
||||
|
||||
public:
|
||||
explicit Converter(nvinfer1::INetworkDefinition* trt_network,
|
||||
tensorflow::tensorrt::TRTWeightStore* ws, bool fp16)
|
||||
TRTWeightStore* ws, bool fp16)
|
||||
: trt_network_(trt_network), weight_store_(ws), fp16_(fp16) {
|
||||
this->register_op_converters();
|
||||
}
|
||||
tensorflow::tensorrt::TRTWeightStore* weight_store() { return weight_store_; }
|
||||
TRTWeightStore* weight_store() { return weight_store_; }
|
||||
TRT_ShapedWeights get_temp_weights(tensorflow::DataType type,
|
||||
nvinfer1::Dims shape) {
|
||||
TRT_ShapedWeights weights(type, nullptr, shape);
|
||||
@ -2140,559 +2129,265 @@ void Converter::register_op_converters() {
|
||||
|
||||
} // namespace
|
||||
|
||||
tensorflow::Status ConvertCalibrationNodeToEngineNode(
|
||||
tensorflow::Graph& graph, tensorflow::Node* c_node) {
|
||||
const auto ndef = c_node->def();
|
||||
tensorflow::Status ConvertGraphDefToEngine(
|
||||
const tensorflow::GraphDef& gdef, int precision_mode, int max_batch_size,
|
||||
size_t max_workspace_size_bytes,
|
||||
const std::vector<tensorflow::PartialTensorShape>& input_shapes,
|
||||
Logger* logger, nvinfer1::IGpuAllocator* allocator,
|
||||
TRTInt8Calibrator* calibrator,
|
||||
TrtUniquePtrType<nvinfer1::ICudaEngine>* engine,
|
||||
bool* convert_successfully) {
|
||||
engine->reset();
|
||||
if (convert_successfully) *convert_successfully = false;
|
||||
|
||||
TFAttrs attrs(ndef);
|
||||
std::vector<string> segment_nodes(
|
||||
attrs.get<std::vector<string>>("segment_nodes"));
|
||||
std::vector<string> output_nodes(
|
||||
attrs.get<std::vector<string>>("segment_output_names"));
|
||||
std::vector<string> input_names(
|
||||
attrs.get<std::vector<string>>("input_names"));
|
||||
string res_name = attrs.get<string>("resource_name");
|
||||
VLOG(1) << "Node name " << c_node->name() << " res_name " << res_name;
|
||||
string engine_name = "my_trt_op";
|
||||
{
|
||||
const auto node_id = tensorflow::str_util::Split(res_name, "_");
|
||||
engine_name += node_id.back();
|
||||
}
|
||||
std::map<string, tensorflow::Node*> node_maps;
|
||||
|
||||
for (auto n : graph.op_nodes()) {
|
||||
node_maps.insert({n->name(), n});
|
||||
}
|
||||
std::set<int> subgraph_ids;
|
||||
for (const auto internal_node : segment_nodes) {
|
||||
subgraph_ids.insert(node_maps.at(internal_node)->id());
|
||||
}
|
||||
if (VLOG_IS_ON(2)) {
|
||||
string node_names = StrCat(c_node->name(), " segment nodes= ");
|
||||
|
||||
for (const auto& node_name : segment_nodes) {
|
||||
StrAppend(&node_names, node_name, ", ");
|
||||
}
|
||||
VLOG(2) << node_names;
|
||||
}
|
||||
|
||||
VLOG(1) << "Output Nodes:";
|
||||
std::vector<tensorflow::DataType> out_types;
|
||||
std::vector<const tensorflow::Edge*> out_edges;
|
||||
|
||||
for (auto& i : output_nodes) {
|
||||
auto node_port = tensorflow::str_util::Split(i, ":");
|
||||
VLOG(1) << " " << i << " in graph " << node_maps.count(i);
|
||||
auto out_node_name = node_port.at(0);
|
||||
if (node_port.size() > 1) {
|
||||
VLOG(1) << "Multi port output" << node_port.at(0) << " "
|
||||
<< node_port.at(1) << " size=" << node_port.size();
|
||||
}
|
||||
auto node_it = node_maps.find(out_node_name);
|
||||
if (node_it != node_maps.end()) {
|
||||
tensorflow::Node* out_node = node_it->second;
|
||||
int port = 0;
|
||||
if (node_port.size() == 2) {
|
||||
port = std::strtoul(node_port.at(1).c_str(), nullptr, 10);
|
||||
out_types.push_back(out_node->output_type(port));
|
||||
} else {
|
||||
out_types.push_back(out_node->output_type(0));
|
||||
}
|
||||
for (auto out_edge : out_node->out_edges()) {
|
||||
if (subgraph_ids.count(out_edge->dst()->id()))
|
||||
continue; // skip internal edges;
|
||||
if (out_edge->src_output() == port) {
|
||||
out_edges.push_back(out_edge);
|
||||
VLOG(1) << "OUTPUT EDGE " << out_edge->src()->name() << ":"
|
||||
<< out_edge->src_output() << " -> " << out_edge->dst()->name()
|
||||
<< ":" << out_edge->dst_input();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
LOG(WARNING) << " couldn't find output node " << out_node_name;
|
||||
}
|
||||
}
|
||||
if (VLOG_IS_ON(1)) {
|
||||
VLOG(1) << c_node->name() << " Input Nodes:";
|
||||
for (auto& i : input_names) {
|
||||
VLOG(1) << " Input " << i << " in graph " << node_maps.count(i);
|
||||
}
|
||||
}
|
||||
auto trt_rm = tensorflow::tensorrt::TRTResourceManager::instance();
|
||||
auto resmgr = trt_rm->getManager("TRTCalibOps");
|
||||
tensorflow::tensorrt::TRTCalibrationResource* calib_res = nullptr;
|
||||
auto status = resmgr->Lookup(res_name, res_name, &calib_res);
|
||||
if (!status.ok() || !calib_res->calibrator_) {
|
||||
return tensorflow::errors::FailedPrecondition(
|
||||
"You must run calibration"
|
||||
" and inference conversion in the same process");
|
||||
}
|
||||
|
||||
calib_res->calibrator_->setDone();
|
||||
calib_res->thr_->join();
|
||||
delete calib_res->thr_;
|
||||
if (!calib_res->engine_) {
|
||||
LOG(ERROR) << "Calibration failed!, engine does not exist. Did you run "
|
||||
"calibration graph?";
|
||||
return tensorflow::errors::FailedPrecondition(
|
||||
"Calibration graph needs to be executed on"
|
||||
" calibration data before convertsion to inference graph");
|
||||
}
|
||||
auto weight_rmgr = trt_rm->getManager("WeightStore");
|
||||
TF_CHECK_OK(weight_rmgr->Delete<tensorflow::tensorrt::TRTWeightStore>(
|
||||
res_name, res_name));
|
||||
auto engine_plan = calib_res->engine_->serialize();
|
||||
calib_res->engine_->destroy();
|
||||
calib_res->network_->destroy();
|
||||
calib_res->builder_->destroy();
|
||||
calib_res->thr_ = nullptr;
|
||||
calib_res->engine_ = nullptr;
|
||||
calib_res->builder_ = nullptr;
|
||||
tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp");
|
||||
std::vector<tensorflow::NodeDefBuilder::NodeOut> income_edges;
|
||||
income_edges.resize(c_node->num_inputs());
|
||||
for (const auto in_edge : c_node->in_edges()) {
|
||||
auto src = in_edge->src();
|
||||
int dest_port = in_edge->dst_input();
|
||||
VLOG(1) << "Incoming connection " << src->name() << ":"
|
||||
<< in_edge->src_output() << " -> " << c_node->name() << ":"
|
||||
<< dest_port;
|
||||
income_edges.at(dest_port) = {src->name(), in_edge->src_output(),
|
||||
c_node->input_type(dest_port)};
|
||||
}
|
||||
tensorflow::gtl::ArraySlice<tensorflow::NodeDefBuilder::NodeOut> input_list(
|
||||
income_edges);
|
||||
if (VLOG_IS_ON(2)) {
|
||||
for (const auto& inp : input_list) {
|
||||
VLOG(2) << " Input from inputlist " << inp.node << ":" << inp.index << " "
|
||||
<< tensorflow::DataTypeString(inp.data_type);
|
||||
}
|
||||
}
|
||||
op_builder.Input(input_list);
|
||||
tensorflow::NodeDef engine_node;
|
||||
const char* engine_plan_data = static_cast<const char*>(engine_plan->data());
|
||||
string engine_plan_string(engine_plan_data,
|
||||
engine_plan_data + engine_plan->size());
|
||||
status = op_builder.Attr("serialized_engine", engine_plan_string)
|
||||
.Attr("input_nodes", input_names)
|
||||
.Attr("output_nodes", output_nodes)
|
||||
.Attr("OutT", out_types)
|
||||
.Finalize(&engine_node);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Engine Node creation failed";
|
||||
return status;
|
||||
}
|
||||
auto trt_engine_node = graph.AddNode(engine_node, &status);
|
||||
TF_RETURN_IF_ERROR(status);
|
||||
std::map<string, int> port_map;
|
||||
for (size_t t = 0; t < output_nodes.size(); t++) {
|
||||
port_map.insert({output_nodes.at(t), t});
|
||||
}
|
||||
for (auto& i : out_edges) {
|
||||
string s(i->src()->name());
|
||||
if (i->src_output()) StrAppend(&s, ":", i->src_output());
|
||||
int out_port = port_map.at(s);
|
||||
VLOG(1) << "Connecting " << trt_engine_node->name() << ":" << out_port
|
||||
<< " -> " << i->dst()->name() << ":" << i->dst_input();
|
||||
TF_RETURN_IF_ERROR(
|
||||
graph.UpdateEdge(trt_engine_node, out_port, i->dst(), i->dst_input()));
|
||||
}
|
||||
for (const auto ed : trt_engine_node->in_edges()) {
|
||||
VLOG(1) << "In Edge " << ed->src()->name() << ":" << ed->src_output()
|
||||
<< " -> " << ed->dst()->name() << ":" << ed->dst_input();
|
||||
}
|
||||
for (const auto ed : trt_engine_node->out_edges()) {
|
||||
VLOG(1) << "Out Edge " << ed->src()->name() << ":" << ed->src_output()
|
||||
<< " -> " << ed->dst()->name() << ":" << ed->dst_input();
|
||||
}
|
||||
VLOG(1) << "Segment nodes:";
|
||||
for (auto& i : segment_nodes) {
|
||||
VLOG(1) << " " << i << " in graph " << node_maps.count(i);
|
||||
auto it = node_maps.find(i);
|
||||
if (it != node_maps.end()) {
|
||||
graph.RemoveNode(it->second);
|
||||
}
|
||||
}
|
||||
graph.RemoveNode(c_node);
|
||||
return tensorflow::Status::OK();
|
||||
}
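
// NOTE (added for clarity, not part of the original change): the routine
// above rewires the graph in place; it looks up the finished calibration
// resource, serializes the engine, and splices a TRTEngineOp over the old
// segment. A minimal sketch of a hypothetical driver that converts every
// remaining calibration op; the op type string "TRTCalibOp" matches the node
// built by InjectCalibrationNode below.
tensorflow::Status ConvertAllCalibrationNodes(tensorflow::Graph* graph) {
  std::vector<tensorflow::Node*> calib_nodes;
  for (tensorflow::Node* n : graph->op_nodes()) {
    if (n->type_string() == "TRTCalibOp") calib_nodes.push_back(n);
  }
  for (tensorflow::Node* n : calib_nodes) {
    TF_RETURN_IF_ERROR(
        tensorflow::tensorrt::convert::ConvertCalibrationNodeToEngineNode(
            *graph, n));
  }
  return tensorflow::Status::OK();
}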

tensorflow::Status ReverseTopologicalSort(
    const tensorrt::convert::SubGraphParams& s,
    std::list<tensorflow::Node*>* order) {
  std::vector<tensorflow::Node*> order_vec;
  tensorflow::GetPostOrder(s.graph, &order_vec);
  // Select just the subgraph
  for (tensorflow::Node* node : order_vec) {
    if (s.subgraph_node_ids.count(node->id())) {
      // We want topological order to construct the
      // network layer by layer
      order->push_front(node);
    }
  }
  return tensorflow::Status::OK();
}

tensorflow::Status SetInputList(
    const tensorrt::convert::SubGraphParams& s,
    tensorflow::NodeDefBuilder* op_builder,
    const std::vector<string>* input_names,
    std::vector<tensorflow::DataType>* input_dtypes) {
  std::vector<tensorflow::NodeDefBuilder::NodeOut> income_edges;
  VLOG(2) << "input edge size: " << input_names->size();
  for (size_t i = 0; i < input_names->size(); ++i) {
    VLOG(2) << "input edges: " << i << " " << input_names->at(i);
    int output_idx = s.input_inds.at(i).second;
    // We wired up the input here already; it is redundant to do it again in
    // ConvertSubGraphToTensorRT (convert_graph.cc).
    auto incoming_edge = tensorflow::NodeDefBuilder::NodeOut(
        input_names->at(i), output_idx, input_dtypes->at(i));
    income_edges.push_back(incoming_edge);
  }
  tensorflow::gtl::ArraySlice<tensorflow::NodeDefBuilder::NodeOut> input_list(
      income_edges);
  op_builder->Input(input_list);
  return tensorflow::Status::OK();
}

string SubgraphNameScopeGenerator(const std::list<tensorflow::Node*>* order) {
  string subgraph_name_scope;
  if (!order->empty()) {
    subgraph_name_scope = order->front()->name();
  }
  for (const tensorflow::Node* node : *order) {
    subgraph_name_scope = GetCommonNameScope(subgraph_name_scope, node->name());
  }
  // TODO(sami,ben,jie): proper naming!
  return subgraph_name_scope;
}
|
||||
|
||||
tensorflow::Status ConvertSubgraph(
|
||||
Converter& converter, tensorrt::convert::SubGraphParams& s,
|
||||
std::list<tensorflow::Node*>* order, std::vector<string>* input_names,
|
||||
std::vector<tensorflow::DataType>* input_dtypes,
|
||||
std::vector<string>* output_names,
|
||||
std::vector<tensorflow::DataType>* output_dtypes,
|
||||
const string& engine_name) {
|
||||
std::set<string> added_tensors;
|
||||
for (const std::pair<int, int>& input : s.input_inds) {
|
||||
VLOG(2) << "parsing input. Node id= " << input.first;
|
||||
int node_id = input.first;
|
||||
int output_idx = input.second;
|
||||
tensorflow::Node* node = s.graph.FindNodeId(node_id);
|
||||
auto node_name = node->name();
|
||||
// input_names should use the node name in the graph
|
||||
// here it should be the input tensor name -> matching the binding
|
||||
// insert original node name without port
|
||||
auto tensor_name = node_name;
|
||||
if (output_idx != 0) {
|
||||
tensor_name = StrCat(tensor_name, ":", output_idx);
|
||||
}
|
||||
|
||||
VLOG(2) << "input name: " << node_name << " tensor_name: " << tensor_name
|
||||
<< " idx: " << output_idx;
|
||||
|
||||
auto shape_inference_node_name = node_name;
|
||||
auto shape_inference_output_idx = output_idx;
|
||||
// rewire the shape inference to original node in the graph
|
||||
if (s.output_edge_map->count(tensor_name)) {
|
||||
shape_inference_node_name = s.output_edge_map->at(tensor_name).second;
|
||||
shape_inference_output_idx = s.output_edge_map->at(tensor_name).first;
|
||||
}
|
||||
if (shape_inference_output_idx < 0) continue;
|
||||
VLOG(2) << "shapeinference name: " << shape_inference_node_name
|
||||
<< " idx: " << shape_inference_output_idx;
|
||||
|
||||
if (!s.graph_properties.HasOutputProperties(shape_inference_node_name))
|
||||
return tensorflow::errors::Internal("failed to find input node: " +
|
||||
shape_inference_node_name);
|
||||
|
||||
auto op_info_vec =
|
||||
s.graph_properties.GetOutputProperties(shape_inference_node_name);
|
||||
if (static_cast<int>(op_info_vec.size()) <= shape_inference_output_idx)
|
||||
return tensorflow::errors::Internal(
|
||||
"accessing output index of: ", shape_inference_output_idx,
|
||||
", at node: ", shape_inference_node_name,
|
||||
" with output entry from shape_map: ", op_info_vec.size());
|
||||
|
||||
auto op_info = op_info_vec.at(shape_inference_output_idx);
|
||||
tensorflow::DataType tf_dtype = op_info.dtype();
|
||||
|
||||
nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT);
|
||||
auto type_status = ConvertDType(tf_dtype, &dtype);
|
||||
if (type_status != tensorflow::Status::OK()) {
|
||||
LOG(WARNING) << "Type conversion failed for " << node_name;
|
||||
return type_status;
|
||||
}
|
||||
|
||||
VLOG(2) << "Accessing output index of: " << output_idx
|
||||
<< ", at node: " << node_name
|
||||
<< " with output entry from shape_map: " << op_info_vec.size();
|
||||
// TODO(ben,jie): update TRT input format/dimension
|
||||
nvinfer1::DimsCHW input_dim_pseudo_chw;
|
||||
for (int i = 0; i < 3; i++) input_dim_pseudo_chw.d[i] = 1;
|
||||
|
||||
// TODO(jie): TRT 3.x only support 4 dimensional input tensor.
|
||||
// update the code once TRT 4.0 comes out.
|
||||
if (op_info.shape().dim_size() != 4) {
|
||||
string err_str = "Require 4 dimensional input.";
|
||||
StrAppend(&err_str, " Got ", op_info.shape().dim_size(), " ",
|
||||
shape_inference_node_name);
|
||||
return tensorflow::errors::Unimplemented(err_str);
|
||||
}
|
||||
|
||||
for (int i = 1; i < op_info.shape().dim_size(); i++) {
|
||||
VLOG(2) << "dimension: " << i
|
||||
<< " , size: " << op_info.shape().dim(i).size();
|
||||
input_dim_pseudo_chw.d[i - 1] = op_info.shape().dim(i).size();
|
||||
}
|
||||
|
||||
// TODO(ben,jie): proper way to restore input tensor name?
|
||||
auto input_tensor_name = node_name;
|
||||
if (output_idx != 0) {
|
||||
input_tensor_name = StrCat(node_name, ":", output_idx);
|
||||
}
|
||||
if (added_tensors.count(input_tensor_name)) continue;
|
||||
added_tensors.insert(input_tensor_name);
|
||||
input_names->push_back(input_tensor_name);
|
||||
input_dtypes->push_back(tf_dtype);
|
||||
nvinfer1::ITensor* input_tensor = converter.network()->addInput(
|
||||
input_tensor_name.c_str(), dtype, input_dim_pseudo_chw);
|
||||
|
||||
if (!input_tensor)
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
"Failed to create Input layer");
|
||||
VLOG(2) << "Input tensor name :" << input_tensor_name;
|
||||
|
||||
if (!converter.insert_input_tensor(input_tensor_name, input_tensor))
|
||||
return tensorflow::errors::AlreadyExists(
|
||||
"Output tensor already exists for op: " + input_tensor_name);
|
||||
}
|
||||
|
||||
for (const tensorflow::Node* node : *order) {
|
||||
const tensorflow::NodeDef& node_def = node->def();
|
||||
VLOG(2) << "Converting node: " << node_def.name() << " , " << node_def.op();
|
||||
TF_RETURN_IF_ERROR(converter.convert_node(node_def));
|
||||
}
|
||||
|
||||
VLOG(2) << "Finished conversion";
|
||||
|
||||
// Gather output metadata
|
||||
int trt_engine_op_output_idx = 0;
|
||||
added_tensors.clear();
|
||||
for (const std::pair<int, int>& output : s.output_inds) {
|
||||
int node_id = output.first;
|
||||
int output_idx = output.second;
|
||||
tensorflow::Node* node = s.graph.FindNodeId(node_id);
|
||||
string op_name = node->name();
|
||||
string tensor_name = op_name;
|
||||
|
||||
s.output_edge_map->insert(
|
||||
{trt_engine_op_output_idx == 0
|
||||
? engine_name
|
||||
: StrCat(engine_name, ":", trt_engine_op_output_idx),
|
||||
{output_idx, tensor_name}});
|
||||
trt_engine_op_output_idx++;
|
||||
if (output_idx != 0)
|
||||
tensorflow::strings::StrAppend(&tensor_name, ":", output_idx);
|
||||
VLOG(2) << "Output tensor name: " << tensor_name;
|
||||
if (added_tensors.count(tensor_name)) continue;
|
||||
added_tensors.insert(tensor_name);
|
||||
output_names->push_back(tensor_name);
|
||||
auto tensor_or_weights = converter.get_tensor(tensor_name);
|
||||
if (!tensor_or_weights.is_tensor()) {
|
||||
return tensorflow::errors::InvalidArgument("Output node '" + tensor_name +
|
||||
"' is weights not tensor");
|
||||
}
|
||||
nvinfer1::ITensor* tensor = tensor_or_weights.tensor();
|
||||
if (!tensor) {
|
||||
return tensorflow::errors::NotFound("Output tensor not found: " +
|
||||
tensor_name);
|
||||
}
|
||||
converter.network()->markOutput(*tensor);
|
||||
tensorflow::DataType tf_dtype = node->output_type(output_idx);
|
||||
output_dtypes->push_back(tf_dtype);
|
||||
nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT;
|
||||
TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype));
|
||||
tensor->setType(trt_dtype);
|
||||
}
|
||||
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {
|
||||
// Visit nodes in reverse topological order and construct the TRT network.
|
||||
// Toposort
|
||||
std::list<tensorflow::Node*> order;
|
||||
TF_RETURN_IF_ERROR(ReverseTopologicalSort(s, &order));
|
||||
|
||||
static int static_id = 0;
|
||||
string subgraph_name_scope = SubgraphNameScopeGenerator(&order);
|
||||
// TODO(sami,ben,jie): proper naming!
|
||||
string calib_op_name =
|
||||
StrCat(subgraph_name_scope, "my_trt_calib_op_", static_id);
|
||||
string engine_name = StrCat(subgraph_name_scope, "my_trt_op", static_id);
|
||||
static_id++;
|
||||
|
||||
auto trt_rmgr = tensorflow::tensorrt::TRTResourceManager::instance();
|
||||
auto op_rmgr = trt_rmgr->getManager("TRTCalibOps");
|
||||
auto op_res = new tensorflow::tensorrt::TRTCalibrationResource();
|
||||
TF_CHECK_OK(op_rmgr->Create(calib_op_name, calib_op_name, op_res));
|
||||
op_res->logger_ = new tensorflow::tensorrt::Logger();
|
||||
cudaSetDevice(s.cuda_gpu_id_);
|
||||
op_res->builder_ = nvinfer1::createInferBuilder(*(op_res->logger_));
|
||||
op_res->allocator_ = s.allocator_;
|
||||
// Create the builder.
|
||||
TrtUniquePtrType<nvinfer1::IBuilder> builder(
|
||||
nvinfer1::createInferBuilder(*logger));
|
||||
builder->setMaxBatchSize(max_batch_size);
|
||||
// TODO(aaroey): use the allocator to allocate the TRT workspace.
|
||||
builder->setMaxWorkspaceSize(max_workspace_size_bytes);
|
||||
#if NV_TENSORRT_MAJOR > 3
|
||||
op_res->builder_->setGpuAllocator(s.allocator_.get());
|
||||
builder->setGpuAllocator(allocator);
|
||||
#endif
|
||||
if (!op_res->builder_) {
|
||||
return tensorflow::errors::Internal(
|
||||
"failed to create TensorRT builder object");
|
||||
if (precision_mode == FP16MODE) {
|
||||
builder->setHalf2Mode(true);
|
||||
} else if (precision_mode == INT8MODE) {
|
||||
builder->setInt8Mode(true);
|
||||
builder->setInt8Calibrator(calibrator);
|
||||
}
|
||||
|
||||
op_res->network_ = op_res->builder_->createNetwork();
|
||||
if (!op_res->network_) {
|
||||
return tensorflow::errors::Internal(
|
||||
"failed to create TensorRT network object");
|
||||
}
|
||||
|
||||
// Build the network
|
||||
auto weight_rmgr = trt_rmgr->getManager("WeightStore");
|
||||
auto ws = new tensorflow::tensorrt::TRTWeightStore();
|
||||
TF_CHECK_OK(weight_rmgr->Create(calib_op_name, calib_op_name, ws));
|
||||
Converter converter(op_res->network_, ws, s.precision_mode == FP16MODE);
|
||||
|
||||
std::vector<string> input_names;
|
||||
std::vector<tensorflow::DataType> input_dtypes;
|
||||
std::vector<string> output_names;
|
||||
std::vector<tensorflow::DataType> output_dtypes;
|
||||
TF_RETURN_IF_ERROR(ConvertSubgraph(converter, s, &order, &input_names,
|
||||
&input_dtypes, &output_names,
|
||||
&output_dtypes, engine_name));
|
||||
|
||||
VLOG(2) << "Finished processing outputs";
|
||||
|
||||
// Build the engine
|
||||
op_res->builder_->setMaxBatchSize(s.max_batch_size);
|
||||
op_res->builder_->setMaxWorkspaceSize(s.max_workspace_size_bytes);
|
||||
VLOG(0) << "Max batch size= " << s.max_batch_size
|
||||
<< " max workspace size= " << s.max_workspace_size_bytes;
|
||||
|
||||
// Build the TRT op
|
||||
// TODO(sami,ben,jie): proper naming!
|
||||
tensorflow::NodeDefBuilder op_builder(calib_op_name, "TRTCalibOp");
|
||||
TF_RETURN_IF_ERROR(SetInputList(s, &op_builder, &input_names, &input_dtypes));
|
||||
|
||||
std::vector<string> segment_names;
|
||||
segment_names.reserve(s.subgraph_node_ids.size());
|
||||
for (int i : s.subgraph_node_ids) {
|
||||
auto node = s.graph.FindNodeId(i);
|
||||
segment_names.push_back(node->name());
|
||||
}
|
||||
LOG(INFO) << "finished op preparation";
|
||||
|
||||
auto status = op_builder.Attr("segment_nodes", segment_names)
|
||||
.Attr("input_names", input_names)
|
||||
.Attr("segment_output_names", output_names)
|
||||
.Attr("resource_name", calib_op_name)
|
||||
.Finalize(s.trt_node);
|
||||
|
||||
LOG(INFO) << status.ToString();
|
||||
LOG(INFO) << "finished op building";
|
||||
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
|
||||
tensorrt::convert::SubGraphParams& s) {
|
||||
// Visit nodes in reverse topological order and construct the TRT network.
|
||||
std::list<tensorflow::Node*> order;
|
||||
TF_RETURN_IF_ERROR(ReverseTopologicalSort(s, &order));
|
||||
|
||||
static int static_id = 0;
|
||||
string subgraph_name_scope = SubgraphNameScopeGenerator(&order);
|
||||
string engine_name = StrCat(subgraph_name_scope, "my_trt_op", static_id++);
|
||||
|
||||
tensorflow::tensorrt::Logger trt_logger;
|
||||
cudaSetDevice(s.cuda_gpu_id_);
|
||||
auto trt_builder = infer_object(nvinfer1::createInferBuilder(trt_logger));
|
||||
if (!trt_builder) {
|
||||
return tensorflow::errors::Internal(
|
||||
"Failed to create TensorRT builder object");
|
||||
}
|
||||
#if NV_TENSORRT_MAJOR > 3
|
||||
trt_builder->setGpuAllocator(s.allocator_.get());
|
||||
#endif
|
||||
auto trt_network = infer_object(trt_builder->createNetwork());
|
||||
// Create the network.
|
||||
auto trt_network =
|
||||
TrtUniquePtrType<nvinfer1::INetworkDefinition>(builder->createNetwork());
|
||||
if (!trt_network) {
|
||||
return tensorflow::errors::Internal(
|
||||
"Failed to create TensorRT network object");
|
||||
}
|
||||
|
||||
auto trt_rmgr = tensorflow::tensorrt::TRTResourceManager::instance();
|
||||
auto weight_rmgr = trt_rmgr->getManager("WeightStore");
|
||||
auto ws = new tensorflow::tensorrt::TRTWeightStore();
|
||||
TF_CHECK_OK(weight_rmgr->Create(engine_name, engine_name, ws));
|
||||
auto ws = std::unique_ptr<TRTWeightStore>(new TRTWeightStore());
|
||||
|
||||
// Build the network
|
||||
Converter converter(trt_network.get(), ws, s.precision_mode == FP16MODE);
|
||||
VLOG(1) << "Starting engine conversion ";
|
||||
Converter converter(trt_network.get(), ws.get(), precision_mode == FP16MODE);
|
||||
std::vector<std::pair<string, string>> output_tensors;
|
||||
// Graph nodes are already topologically sorted during construction
|
||||
for (const auto& node_def : gdef.node()) {
|
||||
string node_name = node_def.name();
|
||||
VLOG(1) << "Converting op name=" << node_name << ", op=" << node_def.op();
|
||||
if (tensorflow::str_util::StartsWith(node_name, kInputPHName) &&
|
||||
(node_def.op() == "Placeholder")) {
|
||||
nvinfer1::DimsCHW input_dim_pseudo_chw;
|
||||
for (int i = 0; i < 8; i++) input_dim_pseudo_chw.d[i] = 0;
|
||||
nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT);
|
||||
auto type_status =
|
||||
ConvertDType(node_def.attr().at("dtype").type(), &dtype);
|
||||
if (type_status != tensorflow::Status::OK()) {
|
||||
LOG(WARNING) << "Type conversion failed for " << node_name;
|
||||
return type_status;
|
||||
}
|
||||
int32 slot_number = -1;
|
||||
if (!tensorflow::strings::safe_strto32(node_name.c_str() + 8,
|
||||
&slot_number)) {
|
||||
LOG(ERROR) << "Failed to parse slot number from " << node_name
|
||||
<< " +8= " << node_name.c_str() + 8;
|
||||
}
|
||||
auto shape = input_shapes.at(slot_number);
|
||||
if (shape.dims() > 8) {
|
||||
LOG(ERROR) << "Tensor rank is greater than 8 for " << node_name
|
||||
<< " at input slot " << slot_number;
|
||||
return tensorflow::errors::OutOfRange(
|
||||
"Input tensor rank is greater than 8");
|
||||
}
|
||||
if (VLOG_IS_ON(1)) {
|
||||
string dim_str("dims=");
|
||||
StrAppend(&dim_str, "[ ", shape.dim_size(0));
|
||||
for (int i = 1; i < shape.dims(); i++) {
|
||||
StrAppend(&dim_str, ", ", shape.dim_size(i));
|
||||
}
|
||||
StrAppend(&dim_str, " ]");
|
||||
VLOG(1) << dim_str;
|
||||
}
|
||||
for (int i = 1; i < shape.dims(); i++) {
|
||||
input_dim_pseudo_chw.d[i - 1] = shape.dim_size(i);
|
||||
}
|
||||
|
||||
std::vector<string> input_names;
|
||||
std::vector<tensorflow::DataType> input_dtypes;
|
||||
std::vector<string> output_names;
|
||||
std::vector<tensorflow::DataType> output_dtypes;
|
||||
TF_RETURN_IF_ERROR(ConvertSubgraph(converter, s, &order, &input_names,
|
||||
&input_dtypes, &output_names,
|
||||
&output_dtypes, engine_name));
|
||||
|
||||
VLOG(2) << "Finished output";
|
||||
|
||||
// Build the engine
|
||||
trt_builder->setMaxBatchSize(s.max_batch_size);
|
||||
trt_builder->setMaxWorkspaceSize(s.max_workspace_size_bytes);
|
||||
VLOG(0) << "Max batch size= " << s.max_batch_size
|
||||
<< " max workspace size= " << s.max_workspace_size_bytes;
|
||||
if (s.precision_mode == FP16MODE) {
|
||||
trt_builder->setHalf2Mode(true);
|
||||
VLOG(0) << "Using FP16 precision mode";
|
||||
}
|
||||
LOG(INFO) << "starting build engine";
|
||||
string engine_plan_string;
|
||||
{
|
||||
auto trt_engine =
|
||||
infer_object(trt_builder->buildCudaEngine(*converter.network()));
|
||||
VLOG(0) << "Built network";
|
||||
if (trt_engine.get() == nullptr) {
|
||||
return tensorflow::errors::Internal("Engine building failure");
|
||||
input_dim_pseudo_chw.nbDims = shape.dims() - 1;
|
||||
nvinfer1::ITensor* input_tensor = converter.network()->addInput(
|
||||
node_name.c_str(), dtype, input_dim_pseudo_chw);
|
||||
if (!input_tensor) {
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
"Failed to create Input layer tensor ", node_name,
|
||||
" rank=", shape.dims() - 1);
|
||||
}
|
||||
VLOG(1) << "Input tensor name :" << node_name;
|
||||
if (!converter.insert_input_tensor(node_name, input_tensor)) {
|
||||
return tensorflow::errors::AlreadyExists(
|
||||
"Output tensor already exists for op: " + node_name);
|
||||
}
|
||||
} else if (tensorflow::str_util::StartsWith(node_name, kOutputPHName) &&
|
||||
(node_def.op() == "Identity")) {
|
||||
int32 slot_number = -1;
|
||||
if (!tensorflow::strings::safe_strto32(node_name.c_str() + 9,
|
||||
&slot_number)) {
|
||||
LOG(ERROR) << "Failed to parse slot number from " << node_name
|
||||
<< " +9=" << node_name.c_str() + 9;
|
||||
}
|
||||
if (output_tensors.size() <= slot_number) {
|
||||
output_tensors.resize(slot_number + 1);
|
||||
}
|
||||
output_tensors.at(slot_number) = {node_def.input(0), node_name};
|
||||
} else {
|
||||
VLOG(2) << "Converting node: " << node_def.name() << " , "
|
||||
<< node_def.op();
|
||||
TF_RETURN_IF_ERROR(converter.convert_node(node_def));
|
||||
}
|
||||
auto engine_plan = infer_object(trt_engine->serialize());
|
||||
VLOG(0) << "Serialized engine";
|
||||
const char* engine_plan_data =
|
||||
static_cast<const char*>(engine_plan->data());
|
||||
engine_plan_string =
|
||||
string(engine_plan_data, engine_plan_data + engine_plan->size());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(weight_rmgr->Delete<tensorflow::tensorrt::TRTWeightStore>(
|
||||
engine_name, engine_name));
|
||||
LOG(INFO) << "finished engine " << engine_name << " containing "
|
||||
<< s.subgraph_node_ids.size() << " nodes";
|
||||
for (const auto& output : output_tensors) {
|
||||
auto tensor_or_weights = converter.get_tensor(output.first);
|
||||
if (!tensor_or_weights.is_tensor()) {
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
"Output node '" + output.first + "' is weights not tensor");
|
||||
}
|
||||
nvinfer1::ITensor* tensor = tensor_or_weights.tensor();
|
||||
tensor->setName(output.second.c_str());
|
||||
if (!tensor) {
|
||||
return tensorflow::errors::NotFound("Output tensor not found: " +
|
||||
output.first);
|
||||
}
|
||||
VLOG(1) << "Marking output tensor " << output.first << ", as output tensor "
|
||||
<< output.second;
|
||||
|
||||
// Build the TRT op
|
||||
tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp");
|
||||
TF_RETURN_IF_ERROR(SetInputList(s, &op_builder, &input_names, &input_dtypes));
|
||||
converter.network()->markOutput(*tensor);
|
||||
}
|
||||
if (convert_successfully) *convert_successfully = true;
|
||||
|
||||
VLOG(0) << "Finished op preparation";
|
||||
// Build the engine.
|
||||
VLOG(1) << "Starting engine creation";
|
||||
engine->reset(builder->buildCudaEngine(*converter.network()));
|
||||
if (engine->get() == nullptr) {
|
||||
return tensorflow::errors::Internal("Failed to build TensorRT engine");
|
||||
}
|
||||
VLOG(1) << "Finished conversion";
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
auto status = op_builder.Attr("serialized_engine", engine_plan_string)
|
||||
.Attr("input_nodes", input_names)
|
||||
.Attr("output_nodes", output_names)
|
||||
.Attr("OutT", output_dtypes)
|
||||
.Device(s.device_name_)
|
||||
.Finalize(s.trt_node);
|
||||
tensorflow::Status ConvertSegmentToGraphDef(
|
||||
const tensorflow::Graph* graph,
|
||||
const tensorflow::grappler::GraphProperties& graph_properties,
|
||||
const std::vector<int>& subgraph_node_ids, // In topological order
|
||||
std::vector<EngineConnection>* connections,
|
||||
tensorflow::GraphDef* segment_def, string* common_scope) {
|
||||
std::set<string> marker_nodes;
|
||||
// Update connection shapes/data types and add corresponding input/output
|
||||
// nodes in the segment graphdef.
|
||||
for (size_t i = 0; i < connections->size(); ++i) {
|
||||
auto& connection = connections->at(i);
|
||||
auto outside_node = graph->FindNodeId(connection.outside_id);
|
||||
if (!outside_node) {
|
||||
// This should never happen, unless the original graph is problematic.
|
||||
return tensorflow::errors::NotFound(
|
||||
"Cannot find node with id ", connection.outside_id, " in the graph.");
|
||||
}
|
||||
// Updates the shape and data types of input/output connections.
|
||||
tensorflow::DataType input_type = tensorflow::DT_FLOAT;
|
||||
tensorflow::PartialTensorShape partial_shape;
|
||||
if (connection.is_input_edge) {
|
||||
if (graph_properties.HasOutputProperties(connection.outside_node_name)) {
|
||||
auto output_params =
|
||||
graph_properties.GetOutputProperties(connection.outside_node_name);
|
||||
auto out_shape = output_params.at(connection.outside_port);
|
||||
input_type = out_shape.dtype();
|
||||
std::vector<tensorflow::int64> dims;
|
||||
partial_shape = out_shape.shape();
|
||||
connection.outside_shape = partial_shape;
|
||||
} else {
|
||||
VLOG(0) << "Unknown output shape" << outside_node->name();
|
||||
input_type = graph->FindNodeId(connection.outside_id)
|
||||
->output_type(connection.outside_port);
|
||||
}
|
||||
connection.connection_type = input_type;
|
||||
|
||||
VLOG(0) << status.ToString() << " finished op building for " << engine_name
|
||||
<< " on device " << s.device_name_;
|
||||
} else { // output edge
|
||||
if (graph_properties.HasInputProperties(connection.outside_node_name)) {
|
||||
auto input_params =
|
||||
graph_properties.GetInputProperties(connection.outside_node_name);
|
||||
auto in_shape = input_params.at(connection.outside_port);
|
||||
input_type = in_shape.dtype();
|
||||
partial_shape = in_shape.shape();
|
||||
connection.inside_shape = partial_shape;
|
||||
} else {
|
||||
input_type = graph->FindNodeId(connection.inside_id)
|
||||
->output_type(connection.outside_port);
|
||||
}
|
||||
connection.connection_type = input_type;
|
||||
}
|
||||
|
||||
// Add dummy input/output nodes to the segment graphdef.
|
||||
if (connection.is_input_edge) {
|
||||
const string node_name = StrCat(kInputPHName, connection.port_number);
|
||||
if (marker_nodes.count(node_name)) {
|
||||
VLOG(1) << "Reusing input " << node_name << " for the edge "
|
||||
<< connection.outside_node_name << ":"
|
||||
<< connection.outside_port << " -> "
|
||||
<< connection.inside_node_name << ":" << connection.inside_port;
|
||||
continue;
|
||||
}
|
||||
marker_nodes.insert(node_name);
|
||||
auto seg_node = segment_def->add_node();
|
||||
tensorflow::NodeDefBuilder builder(node_name, "Placeholder");
|
||||
auto status = builder.Attr("shape", partial_shape)
|
||||
.Attr("dtype", input_type)
|
||||
.Finalize(seg_node);
|
||||
VLOG(1) << "Constructing input " << node_name << " for the edge "
|
||||
<< connection.outside_node_name << ":" << connection.outside_port
|
||||
<< " -> " << connection.inside_node_name << ":"
|
||||
<< connection.inside_port;
|
||||
} else {
|
||||
const string node_name = StrCat(kOutputPHName, connection.port_number);
|
||||
if (marker_nodes.count(node_name)) {
|
||||
VLOG(1) << "Reusing output " << node_name << " for the edge "
|
||||
<< connection.inside_node_name << ":" << connection.inside_port
|
||||
<< " -> " << connection.outside_node_name << ":"
|
||||
<< connection.outside_port;
|
||||
continue;
|
||||
}
|
||||
marker_nodes.insert(node_name);
|
||||
auto seg_node = segment_def->add_node();
|
||||
tensorflow::NodeDefBuilder builder(node_name, "Identity");
|
||||
auto status = builder.Input(connection.inside_node_name, 0, input_type)
|
||||
.Finalize(seg_node);
|
||||
VLOG(1) << "Constructing output " << node_name << " for the edge "
|
||||
<< connection.inside_node_name << ":" << connection.inside_port
|
||||
<< " -> " << connection.outside_node_name << ":"
|
||||
<< connection.outside_port;
|
||||
}
|
||||
} // for each connection.
|
||||
|
||||
std::unordered_map<int, int> old_to_new_id_map;
|
||||
// Copy internal nodes to new graphdef
|
||||
string local_scope = graph->FindNodeId(*subgraph_node_ids.begin())->name();
|
||||
for (const auto node_id : subgraph_node_ids) {
|
||||
const auto node = graph->FindNodeId(node_id);
|
||||
local_scope = GetCommonNameScope(local_scope, node->name());
|
||||
old_to_new_id_map[node_id] = segment_def->node_size();
|
||||
auto snode = segment_def->add_node();
|
||||
snode->CopyFrom(node->def());
|
||||
VLOG(1) << "Copying " << snode->name() << " to subgraph";
|
||||
}
|
||||
// Update the inputs of the new input nodes to point to placeholder nodes.
|
||||
for (int i = 0; i < connections->size(); ++i) {
|
||||
auto& connection = connections->at(i);
|
||||
if (!connection.is_input_edge) continue;
|
||||
auto snode =
|
||||
segment_def->mutable_node(old_to_new_id_map[connection.inside_id]);
|
||||
const string placeholder_name =
|
||||
StrCat(kInputPHName, connection.port_number);
|
||||
VLOG(1) << "Updating " << snode->name() << ":" << connection.inside_port
|
||||
<< " from " << snode->input(connection.inside_port) << " to "
|
||||
<< placeholder_name;
|
||||
snode->set_input(connection.inside_port, placeholder_name);
|
||||
}
|
||||
*common_scope = local_scope;
|
||||
VLOG(0) << "Segment @scope '" << local_scope << "', converted to graph";
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
|
@ -22,69 +22,112 @@ limitations under the License.
#include <utility>
#include <vector>

#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h"
#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/lib/core/status.h"

#if GOOGLE_CUDA
#if GOOGLE_TENSORRT

namespace tensorflow {
namespace tensorrt {
static const char* kInputPHName = "InputPH_";
static const char* kOutputPHName = "OutputPH_";
namespace convert {

// TODO(aaroey): use an enum instead.
const int FP32MODE = 0;
const int FP16MODE = 1;
const int INT8MODE = 2;

struct SubGraphParams {
  SubGraphParams(
      tensorflow::Graph& inp_graph,
      const std::set<int>& subgraph_node_id_numbers,
      const std::vector<std::pair<int, int>>& input_indices,
      const std::vector<std::pair<int, int>>& output_indices,
      size_t max_supported_batch_size, size_t max_consumed_workspace_size_bytes,
      const tensorflow::grappler::GraphProperties& current_graph_properties,
      std::unordered_map<string, std::pair<int, string>>* output_edges,
      tensorflow::NodeDef* constructed_trt_node,
      int engine_precision_mode = FP32MODE, const string& device_name = "",
      std::shared_ptr<nvinfer1::IGpuAllocator> allocator = nullptr,
      int cuda_gpu_id = 0)
      : graph(inp_graph),
        subgraph_node_ids(subgraph_node_id_numbers),
        input_inds(input_indices),
        output_inds(output_indices),
        max_batch_size(max_supported_batch_size),
        max_workspace_size_bytes(max_consumed_workspace_size_bytes),
        graph_properties(current_graph_properties),
        output_edge_map(output_edges),
        trt_node(constructed_trt_node),
        precision_mode(engine_precision_mode),
        device_name_(device_name),
        allocator_(allocator),
        cuda_gpu_id_(cuda_gpu_id) {}
struct EngineConnection {
  EngineConnection(const string& outside, int out_id, int out_port,
                   const string& inside, int in_id, int in_port,
                   bool input_edge, int port)
      : outside_node_name(outside),
        outside_id(out_id),
        outside_port(out_port),
        inside_node_name(inside),
        inside_id(in_id),
        inside_port(in_port),
        is_input_edge(input_edge),
        port_number(port) {}

  tensorflow::Graph& graph;
  const std::set<int>& subgraph_node_ids;
  const std::vector<std::pair<int, int>>& input_inds;   // {node_id, output_idx}
  const std::vector<std::pair<int, int>>& output_inds;  // {node_id, output_idx}
  size_t max_batch_size;
  size_t max_workspace_size_bytes;
  const tensorflow::grappler::GraphProperties& graph_properties;
  std::unordered_map<string, std::pair<int, string>>* output_edge_map;
  tensorflow::NodeDef* trt_node;
  const int precision_mode;
  const string device_name_;
  std::shared_ptr<nvinfer1::IGpuAllocator> allocator_;
  const int cuda_gpu_id_;
  const string outside_node_name;
  const int outside_id;
  const int outside_port;
  tensorflow::PartialTensorShape outside_shape;

  const string inside_node_name;
  const int inside_id;
  const int inside_port;
  tensorflow::PartialTensorShape inside_shape;

  tensorflow::DataType connection_type;
  bool is_input_edge;

  // The port number of the TRT node connecting to this edge.
  int port_number;
};
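
// Sketch (added for illustration, not part of the original change): each
// connection is materialized inside the segment GraphDef as a marker node
// whose name is derived from its port number, using the kInputPHName /
// kOutputPHName prefixes declared above. Input edges become Placeholder nodes
// ("InputPH_<port>"), output edges become Identity nodes ("OutputPH_<port>").
inline string MarkerNodeName(const EngineConnection& conn) {
  return conn.is_input_edge
             ? tensorflow::strings::StrCat(kInputPHName, conn.port_number)
             : tensorflow::strings::StrCat(kOutputPHName, conn.port_number);
}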

// TODO(sami): Replace references with const reference or pointers
tensorflow::Status ConvertSubGraphToTensorRTNodeDef(SubGraphParams& params);
tensorflow::Status InjectCalibrationNode(SubGraphParams& params);
tensorflow::Status ConvertCalibrationNodeToEngineNode(tensorflow::Graph& graph,
                                                      tensorflow::Node* c_node);
struct EngineInfo {
  EngineInfo()
      : engine_type(EngineType::TRTStatic),
        max_workspace_size_bytes(0),
        precision_mode(FP32MODE) {}

  string engine_name;
  string device;
  tensorflow::GraphDef segment_graph_def;

  // The segment nodes that are on one side of the edges are topologically
  // sorted.
  std::vector<EngineConnection> connections;

  enum class EngineType { TRTStatic = 0, TRTDynamic = 1 };
  EngineType engine_type;
  int64 max_workspace_size_bytes;
  int maximum_cached_engines;
  std::vector<int> cached_engine_batches;
  int precision_mode;
};

// Constructs a graphdef from the segment in the given graph. Adds placeholder
// nodes for input edges (InputPH_*) and identity nodes for output edges
// (OutputPH_*). This function needs to be called before the TensorRT nodes are
// inserted, in order to correctly get sizes from the original graph.
//
// - subgraph_node_ids: the node ids of the subgraph, must be sorted in
//   topological order.
// - segment_def: the output GraphDef, whose non-input/output nodedefs will be
//   sorted in topological order.
tensorflow::Status ConvertSegmentToGraphDef(
    const tensorflow::Graph* graph,
    const tensorflow::grappler::GraphProperties& graph_properties,
    const std::vector<int>& subgraph_node_ids,
    std::vector<EngineConnection>* connections,
    tensorflow::GraphDef* segment_def, string* common_scope);
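
// Sketch (added for illustration, not part of the original change): how a
// caller might populate an EngineInfo for one segment. The segmentation step
// that produces `node_ids` and `connections` is not shown; the "my_trt_op"
// prefix mirrors the engine names used elsewhere in this change.
inline tensorflow::Status BuildEngineInfo(
    const tensorflow::Graph* graph,
    const tensorflow::grappler::GraphProperties& graph_properties,
    const std::vector<int>& node_ids,  // topologically sorted
    std::vector<EngineConnection>* connections, EngineInfo* info) {
  string scope;
  TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef(graph, graph_properties, node_ids,
                                              connections,
                                              &info->segment_graph_def, &scope));
  info->engine_name = tensorflow::strings::StrCat(scope, "my_trt_op");
  return tensorflow::Status::OK();
}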

// Converts a given subgraph to a TRT engine saved in 'engine'. Returns ok iff
// 'builder' successfully builds the engine. If the result is not ok, 'engine'
// will be set to nullptr.
// Once returned, 'builder' is not needed any more and can be safely destroyed.
//
// - convert_successfully: indicates whether the conversion to the TensorRT
//   network is successful. This is different than successfully building the
//   engine: building can still fail afterwards.
tensorflow::Status ConvertGraphDefToEngine(
    const tensorflow::GraphDef& gdef, int precision_mode, int max_batch_size,
    size_t max_workspace_size_bytes,
    const std::vector<tensorflow::PartialTensorShape>& input_shapes,
    Logger* logger, nvinfer1::IGpuAllocator* allocator,
    TRTInt8Calibrator* calibrator,
    TrtUniquePtrType<nvinfer1::ICudaEngine>* engine,
    bool* convert_successfully);
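
// Sketch (added for illustration, not part of the original change): driving
// a static engine build from a previously extracted segment. `info` is
// assumed to be filled in as in the BuildEngineInfo sketch above.
inline tensorflow::Status BuildStaticEngine(
    const EngineInfo& info, int max_batch_size,
    const std::vector<tensorflow::PartialTensorShape>& input_shapes,
    Logger* trt_logger, nvinfer1::IGpuAllocator* allocator,
    TrtUniquePtrType<nvinfer1::ICudaEngine>* engine) {
  bool convert_ok = false;
  tensorflow::Status status = ConvertGraphDefToEngine(
      info.segment_graph_def, info.precision_mode, max_batch_size,
      info.max_workspace_size_bytes, input_shapes, trt_logger, allocator,
      /*calibrator=*/nullptr, engine, &convert_ok);
  // On failure *engine stays null; convert_ok distinguishes a conversion
  // failure from a failure while building the engine itself.
  return status;
}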

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow

@ -45,8 +45,24 @@ tensorflow::Status TRTOptimizationPass::Init(
  if (params.count("max_batch_size")) {
    maximum_batch_size_ = params.at("max_batch_size").i();
  }
  if (params.count("max_workspace_size_bytes"))
  is_dynamic_op_ = false;
  if (params.count("is_dynamic_op")) {
    is_dynamic_op_ = params.at("is_dynamic_op").b();
  }
  if (params.count("cached_engine_batches")) {
    auto batch_vec = params.at("cached_engine_batches").list();
    batches_.reserve(batch_vec.i_size());
    for (const auto i : batch_vec.i()) {
      batches_.push_back(i);
    }
  }
  max_cached_batches_ = 1;
  if (params.count("maximum_cached_engines")) {
    max_cached_batches_ = params.at("maximum_cached_engines").i();
  }
  if (params.count("max_workspace_size_bytes")) {
    maximum_workspace_size_ = params.at("max_workspace_size_bytes").i();
  }
  if (params.count("precision_mode")) {
    string pm = Uppercase(params.at("precision_mode").s());
    if (pm == "FP32") {
@ -175,6 +191,17 @@ tensorflow::Status TRTOptimizationPass::Optimize(
  if (VLOG_IS_ON(1)) {
    PrintDebugInfo(cluster, item);
  }
  // This is a hack to work around an optimizer issue. MetaOptimizer calls
  // optimization passes on function objects as well; we should not modify
  // generated funcdefs! This is fragile, but we don't have any other option
  // until the framework fixes it.
  if (item.id != "tf_graph") {
    LOG(WARNING) << name_
                 << " is probably called on funcdef! This optimizer must *NOT* "
                    "be called on function objects.";
    *optimized_graph = item.graph;
    return tensorflow::Status::OK();
  }
  int max_dim = -1;
  if (item.feed.size()) {
    for (const auto& f : item.feed) {
@ -204,11 +231,22 @@ tensorflow::Status TRTOptimizationPass::Optimize(
  }
  tensorflow::grappler::GraphProperties static_graph_properties(item);
  TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true));
  auto status = tensorflow::tensorrt::convert::ConvertAfterShapes(
      item.graph, item.fetch, maximum_batch_size_, maximum_workspace_size_,
      optimized_graph, precision_mode_, minimum_segment_size_,
      static_graph_properties, cluster);
  tensorflow::tensorrt::convert::ConversionParams cp;
  cp.input_graph_def = &item.graph;
  cp.output_names = &item.fetch;
  cp.max_batch_size = maximum_batch_size_;
  cp.max_workspace_size_bytes = maximum_workspace_size_;
  cp.output_graph_def = optimized_graph;
  cp.precision_mode = precision_mode_;
  cp.minimum_segment_size = minimum_segment_size_;
  cp.graph_properties = &static_graph_properties;
  cp.cluster = cluster;
  cp.is_dyn_op = is_dynamic_op_;
  cp.cached_engine_batches = batches_;
  cp.max_cached_engines = max_cached_batches_;
  auto status = tensorflow::tensorrt::convert::ConvertAfterShapes(cp);
  VLOG(2) << optimized_graph->DebugString();
  VLOG(1) << "Returning from " << name_;
  return status;
}

@ -61,6 +61,9 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer {
  int minimum_segment_size_;
  int precision_mode_;
  int maximum_batch_size_;
  bool is_dynamic_op_;
  std::vector<int> batches_;
  int max_cached_batches_;
  int64_t maximum_workspace_size_;
};

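For context, the parameters parsed in Init() above arrive through the grappler
custom-optimizer parameter_map. A hedged sketch of how a client could populate
them; the registration name "TensorRTOptimizer" and the sample values are
assumptions, not taken from this change:

// Sketch only: configure the pass through RewriterConfig.
tensorflow::RewriterConfig rewriter_config;
auto* trt_optimizer = rewriter_config.add_custom_optimizers();
trt_optimizer->set_name("TensorRTOptimizer");  // assumed registration name
auto* parameters = trt_optimizer->mutable_parameter_map();
(*parameters)["max_batch_size"].set_i(8);
(*parameters)["max_workspace_size_bytes"].set_i(1 << 30);
(*parameters)["precision_mode"].set_s("FP16");
(*parameters)["is_dynamic_op"].set_b(false);
(*parameters)["maximum_cached_engines"].set_i(1);
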
37
tensorflow/contrib/tensorrt/convert/utils.h
Normal file
@ -0,0 +1,37 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_
#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_

#include <memory>

namespace tensorflow {
namespace tensorrt {

template <typename T>
struct TrtDestroyer {
  void operator()(T* t) {
    if (t) t->destroy();
  }
};

template <typename T>
using TrtUniquePtrType = std::unique_ptr<T, TrtDestroyer<T>>;

}  // namespace tensorrt
}  // namespace tensorflow

#endif  // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_
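
For reference, a minimal usage sketch of the helper above (not part of the
file): TrtDestroyer lets std::unique_ptr release TensorRT objects through
their destroy() method, so builders and networks cannot leak on early returns.
A Logger instance is assumed to be available.

// Sketch: both objects are released via destroy() when they go out of scope.
tensorflow::tensorrt::Logger trt_logger;
TrtUniquePtrType<nvinfer1::IBuilder> builder(
    nvinfer1::createInferBuilder(trt_logger));
TrtUniquePtrType<nvinfer1::INetworkDefinition> network(
    builder->createNetwork());
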
@ -14,8 +14,16 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
|
||||
#include "tensorflow/contrib/tensorrt/convert/utils.h"
|
||||
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
|
||||
#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
|
||||
#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
|
||||
#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
|
||||
#include "tensorflow/core/framework/graph_to_functiondef.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
@ -25,144 +33,556 @@ limitations under the License.
|
||||
#include "cuda/include/cuda_runtime_api.h"
|
||||
|
||||
namespace tensorflow {
|
||||
static ::tensorflow::tensorrt::Logger logger;
|
||||
using IRuntime = nvinfer1::IRuntime;
|
||||
using Dims = nvinfer1::Dims;
|
||||
|
||||
namespace tensorrt {
|
||||
static Logger logger;
|
||||
using ::nvinfer1::IRuntime;
|
||||
using ::tensorflow::strings::StrAppend;
|
||||
using ::tensorflow::strings::StrCat;
|
||||
|
||||
TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
// read serialized_engine
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("serialized_engine", &serialized_engine_));
|
||||
// A helper class to call done() when destructed for asynchronous execution.
|
||||
// Helps simultaneous execution of native and TRT engines.
|
||||
class AsyncHelper : public tensorflow::core::RefCounted {
|
||||
public:
|
||||
AsyncHelper(tensorflow::AsyncOpKernel::DoneCallback done) { done_ = done; }
|
||||
~AsyncHelper() override { done_(); }
|
||||
|
||||
// register input output node name in trt_sub_graph
|
||||
OP_REQUIRES_OK(context, context->GetAttr("input_nodes", &input_nodes_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("output_nodes", &output_nodes_));
|
||||
private:
|
||||
tensorflow::AsyncOpKernel::DoneCallback done_;
|
||||
};
|
||||
|
||||
#define TYPECASE(dt, X, Y) \
|
||||
case dt: { \
|
||||
return (void*)X->flat<tensorflow::EnumToDataType<dt>::Type>().data(); \
|
||||
}
|
||||
|
||||
void* GetTensorAddress(const Tensor* tensor_ptr) {
|
||||
auto tensor_type = tensor_ptr->dtype();
|
||||
switch (tensor_type) {
|
||||
TYPECASE(tensorflow::DT_FLOAT, tensor_ptr, dest_ptr);
|
||||
TYPECASE(tensorflow::DT_HALF, tensor_ptr, dest_ptr);
|
||||
TYPECASE(tensorflow::DT_INT8, tensor_ptr, dest_ptr);
|
||||
default: {
|
||||
LOG(ERROR) << "Unsupported Data type "
|
||||
<< tensorflow::DataTypeString(tensor_type);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void TRTEngineOp::Compute(OpKernelContext* context) {
|
||||
// TODO(samikama) runtime should be taken from a resourcemanager as well.
|
||||
// Only engine should be in the op and context and runtime should be taken
|
||||
// from resourcemanager
|
||||
|
||||
if (!trt_execution_context_ptr_) {
|
||||
IRuntime* infer = nvinfer1::createInferRuntime(logger);
|
||||
#if NV_TENSORRT_MAJOR > 3
|
||||
auto device = context->device();
|
||||
auto dev_allocator =
|
||||
device->GetAllocator(tensorflow::AllocatorAttributes());
|
||||
if (!dev_allocator) {
|
||||
LOG(FATAL) << "Can't find device allocator for gpu device "
|
||||
<< device->name();
|
||||
}
|
||||
allocator_ = std::make_shared<TRTDeviceAllocator>(dev_allocator);
|
||||
infer->setGpuAllocator(allocator_.get());
|
||||
#endif
|
||||
trt_engine_ptr_.reset(infer->deserializeCudaEngine(
|
||||
serialized_engine_.c_str(), serialized_engine_.size(),
|
||||
PluginFactoryTensorRT::GetInstance()));
|
||||
trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext());
|
||||
// Runtime is safe to delete after engine creation
|
||||
infer->destroy();
|
||||
serialized_engine_.clear();
|
||||
tensorflow::Status TRTEngineOp::ConstructFunctionHandle(OpKernelContext* ctx) {
|
||||
VLOG(1) << "Constructing function handle";
|
||||
auto lib = ctx->function_library();
|
||||
if (lib == nullptr) {
|
||||
return tensorflow::errors::Internal("Context function library is null");
|
||||
}
|
||||
int num_binding = context->num_inputs() + context->num_outputs();
|
||||
std::vector<void*> buffers(num_binding);
|
||||
auto fdef = lib->GetFunctionLibraryDefinition()->Find(funcdef_name_);
|
||||
if (fdef == nullptr) {
|
||||
return tensorflow::errors::Internal("Native FunctionDef ", funcdef_name_,
|
||||
" can't be found in function library");
|
||||
}
|
||||
tensorflow::FunctionLibraryRuntime::InstantiateOptions inst_ops;
|
||||
inst_ops.overlay_lib = nullptr;
|
||||
inst_ops.state_handle = "";
|
||||
inst_ops.target = ctx->device()->name();
|
||||
native_func_ = 0;
|
||||
auto status = lib->Instantiate(funcdef_name_, AttrSlice(&fdef->attr()),
|
||||
inst_ops, &native_func_);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << " Instantiating native function " << funcdef_name_
|
||||
<< " failed!";
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
size_t binding_index;
|
||||
int num_batch = 0;
|
||||
for (int i = 0; i < context->num_inputs(); i++) {
|
||||
// Grab the input tensor
|
||||
binding_index = trt_engine_ptr_->getBindingIndex(input_nodes_[i].c_str());
|
||||
TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
|
||||
: AsyncOpKernel(context) {
|
||||
// read serialized_engine
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("serialized_segment", &serialized_segment_));
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("workspace_size_bytes", &workspace_size_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("static_engine", &static_engine_));
|
||||
if (!static_engine_) {
|
||||
if (!segment_graph_.ParseFromString(serialized_segment_)) {
|
||||
LOG(ERROR) << "Parsing segment graph failed!";
|
||||
context->SetStatus(tensorflow::errors::InvalidArgument(
|
||||
"Failed to parse segment graphdef!"));
|
||||
return;
|
||||
}
|
||||
serialized_segment_.resize(0);
|
||||
}
|
||||
VLOG(1) << "Constructing " << name();
|
||||
string precision_string;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("precision_mode", &precision_string));
|
||||
string calibration_data;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("calibration_data", &calibration_data));
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("segment_funcdef_name", &funcdef_name_));
|
||||
if (precision_string == "FP32") {
|
||||
precision_mode_ = convert::FP32MODE;
|
||||
} else if (precision_string == "FP16") {
|
||||
precision_mode_ = convert::FP16MODE;
|
||||
} else if (precision_string == "INT8") {
|
||||
precision_mode_ = convert::INT8MODE;
|
||||
}
|
||||
calibration_mode_ =
|
||||
(precision_mode_ == convert::INT8MODE && calibration_data.size() == 0);
|
||||
if (calibration_data.size()) {
|
||||
calibrator_.reset(new TRTInt8Calibrator(calibration_data));
|
||||
calibration_data.resize(0);
|
||||
}
|
||||
native_func_ = tensorflow::kInvalidHandle;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count",
|
||||
&max_cached_engines_));
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("fixed_input_size", &fixed_input_size_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("cached_engine_batches",
|
||||
&cached_engine_batches_));
|
||||
std::sort(cached_engine_batches_.begin(), cached_engine_batches_.end());
|
||||
if (VLOG_IS_ON(1)) {
|
||||
string s("Engine Batches= ");
|
||||
for (auto i : cached_engine_batches_) {
|
||||
StrAppend(&s, i, " ");
|
||||
}
|
||||
VLOG(1) << s;
|
||||
}
|
||||
}
|
||||
|
||||
const Tensor& input_tensor = context->input(i);
|
||||
const TensorShape& input_shape = input_tensor.shape();
|
||||
if (i == 0) {
|
||||
num_batch = input_shape.dim_size(0);
|
||||
if (num_batch > trt_engine_ptr_->getMaxBatchSize()) {
|
||||
LOG(FATAL) << "input tensor batch larger than max_batch_size: "
|
||||
<< trt_engine_ptr_->getMaxBatchSize();
|
||||
}
|
||||
} else if (num_batch != input_shape.dim_size(0)) {
|
||||
LOG(FATAL) << "input data inconsistent batch size";
|
||||
void TRTEngineOp::ExecuteNativeSegment(tensorflow::OpKernelContext* ctx,
|
||||
AsyncHelper* helper) {
|
||||
if (!calibration_mode_) {
|
||||
VLOG(1) << "Executing native engine";
|
||||
}
|
||||
std::vector<Tensor> inputs;
|
||||
std::vector<Tensor>* outputs = new std::vector<Tensor>();
|
||||
if (native_func_ == tensorflow::kInvalidHandle) {
|
||||
auto status = ConstructFunctionHandle(ctx);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Couldn't construct function handle " << funcdef_name_;
|
||||
ctx->SetStatus(status);
|
||||
return;
|
||||
}
|
||||
}
|
||||
auto lib = ctx->function_library();
|
||||
tensorflow::FunctionLibraryRuntime::Options opts;
|
||||
opts.step_id = ctx->step_id();
|
||||
opts.rendezvous = ctx->rendezvous();
|
||||
opts.cancellation_manager = ctx->cancellation_manager();
|
||||
opts.runner = ctx->runner();
|
||||
for (int i = 0; i < ctx->num_inputs(); i++) {
|
||||
inputs.push_back(ctx->input(i));
|
||||
}
|
||||
helper->Ref(); // Increment count for calculating native graph
|
||||
VLOG(1) << "Executing native segment " << name();
|
||||
lib->Run(opts, native_func_, inputs, outputs,
|
||||
[ctx, outputs, helper](const tensorflow::Status& s) {
|
||||
tensorflow::core::ScopedUnref sc(helper);
|
||||
VLOG(1) << "Native Segment completed";
|
||||
if (!s.ok()) {
|
||||
ctx->SetStatus(s);
|
||||
return;
|
||||
}
|
||||
for (size_t t = 0; t < outputs->size(); ++t) {
|
||||
ctx->set_output(t, outputs->at(t));
|
||||
}
|
||||
delete outputs;
|
||||
});
|
||||
}
|
||||
|
||||
void TRTEngineOp::ExecuteCalibration(tensorflow::OpKernelContext* ctx,
|
||||
AsyncHelper* helper) {
|
||||
helper->Ref();
|
||||
tensorflow::core::ScopedUnref sc(helper);
|
||||
// TODO(aaroey): remove the ResourceMgr singleton.
|
||||
auto trt_rm = TRTResourceManager::instance();
|
||||
auto res_mgr = trt_rm->getManager("TRTCalibration");
|
||||
TRTCalibrationResource* calib_res = nullptr;
|
||||
auto status = res_mgr->LookupOrCreate(
|
||||
funcdef_name_, "Calibrator", &calib_res,
|
||||
{[ctx, this](TRTCalibrationResource** cr) -> tensorflow::Status {
|
||||
return this->AllocateCalibrationResources(ctx, cr);
|
||||
}});
|
||||
if (!status.ok()) {
|
||||
ctx->SetStatus(status);
|
||||
return;
|
||||
}
|
||||
int num_inputs = ctx->num_inputs();
|
||||
// Pass input data to calibrator
|
||||
std::unordered_map<string, void*> input_data;
|
||||
for (int i = 0; i < num_inputs; i++) {
|
||||
const Tensor& t = ctx->input(i);
|
||||
void* data_address = GetTensorAddress(&t);
|
||||
if (data_address == nullptr) {
|
||||
ctx->SetStatus(tensorflow::errors::InvalidArgument(
|
||||
"Unsupported data type encountered in input ", i));
|
||||
return;
|
||||
}
|
||||
// Check the allocated buffer is sufficient for input
|
||||
const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx);
|
||||
CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes());
|
||||
input_data.emplace(StrCat(kInputPHName, i), data_address);
|
||||
}
|
||||
VLOG(2) << "Filled map for sending";
|
||||
// copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
|
||||
const cudaStream_t* stream = CHECK_NOTNULL(
|
||||
reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
|
||||
->stream()
|
||||
->implementation()
|
||||
->CudaStreamMemberHack()));
|
||||
calib_res->calibrator_->setBatch(input_data, *stream);
|
||||
VLOG(2) << "Passed calibration data";
|
||||
ExecuteNativeSegment(ctx, helper);
|
||||
}
|
||||
|
||||
int TRTEngineOp::GetEngineBatch(tensorflow::OpKernelContext* ctx) {
|
||||
int num_batch = ctx->input(0).shape().dim_size(0);
|
||||
int smallest_engine = 0;
|
||||
for (const auto i : cached_engine_batches_) {
|
||||
if (i >= num_batch) {
|
||||
smallest_engine = i;
|
||||
break;
|
||||
}
|
||||
auto dtype = trt_engine_ptr_->getBindingDataType(binding_index);
|
||||
}
|
||||
// TODO(sami): Need an LRU here
|
||||
if (smallest_engine == 0) {
|
||||
if (max_cached_engines_ > cached_engine_batches_.size()) {
|
||||
smallest_engine = num_batch;
|
||||
cached_engine_batches_.push_back(num_batch);
|
||||
VLOG(1) << "Running with batch size " << num_batch;
|
||||
} else {
|
||||
string s("Engine buffer is full. buffer limit= ");
|
||||
StrAppend(&s, max_cached_engines_, ", current entries= ");
|
||||
for (auto i : cached_engine_batches_) StrAppend(&s, i, ", ");
|
||||
StrAppend(&s, "Requested batch= ", num_batch);
|
||||
LOG(ERROR) << s;
|
||||
ctx->SetStatus(tensorflow::errors::ResourceExhausted(
|
||||
"Requested batch size is not available and engine cache is full"));
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
return smallest_engine;
|
||||
}
|
||||
|
||||
void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
|
||||
tensorflow::AsyncOpKernel::DoneCallback done) {
|
||||
auto helper = new AsyncHelper(done);
|
||||
tensorflow::core::ScopedUnref sc(helper);
|
||||
if (calibration_mode_) {
|
||||
ExecuteCalibration(ctx, helper);
|
||||
return;
|
||||
}
|
||||
const int smallest_engine = GetEngineBatch(ctx);
|
||||
if (smallest_engine < 0) return; // GetEngineBatch already set the status.
|
||||
|
||||
const int num_batch = ctx->input(0).shape().dim_size(0);
|
||||
auto& engine_ctx_pair = GetEngine(smallest_engine, ctx);
|
||||
auto& trt_engine_ptr = engine_ctx_pair.first;
|
||||
if (!trt_engine_ptr) {
|
||||
LOG(WARNING) << "Engine retrieval for batch size " << num_batch
|
||||
<< " failed Running native segment";
|
||||
ExecuteNativeSegment(ctx, helper);
|
||||
return;
|
||||
}
|
||||
|
||||
const int num_binding = ctx->num_inputs() + ctx->num_outputs();
|
||||
std::vector<void*> buffers(num_binding);
|
||||
for (int i = 0; i < ctx->num_inputs(); i++) {
|
||||
const string inp_name = StrCat(kInputPHName, i);
|
||||
const size_t binding_index =
|
||||
trt_engine_ptr->getBindingIndex(inp_name.c_str());
|
||||
|
||||
const Tensor& input_tensor = ctx->input(i);
|
||||
const TensorShape& input_shape = input_tensor.shape();
|
||||
if (num_batch != input_shape.dim_size(0)) {
|
||||
LOG(ERROR) << "input data inconsistent batch size";
|
||||
ctx->SetStatus(tensorflow::errors::FailedPrecondition(
|
||||
"Different batch sizes between input tensors"));
|
||||
return;
|
||||
}
|
||||
auto dtype = trt_engine_ptr->getBindingDataType(binding_index);
|
||||
switch (dtype) {
|
||||
case nvinfer1::DataType::kFLOAT:
|
||||
buffers[binding_index] = (void*)(input_tensor.flat<float>().data());
|
||||
break;
|
||||
case nvinfer1::DataType::kHALF:
|
||||
LOG(FATAL) << "half size is not supported yet!";
|
||||
break;
|
||||
LOG(ERROR) << "FP16 inputs are not supported yet!";
|
||||
ctx->SetStatus(tensorflow::errors::InvalidArgument(
|
||||
"FP16 inputs are not supported!"));
|
||||
return;
|
||||
case nvinfer1::DataType::kINT8:
|
||||
LOG(FATAL) << "int8 is not supported yet!";
|
||||
break;
|
||||
LOG(ERROR) << "INT8 inputs are not supported yet!";
|
||||
ctx->SetStatus(tensorflow::errors::InvalidArgument(
|
||||
"INT8 inputs are not supported!"));
|
||||
return;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown data type: " << int(dtype);
|
||||
break;
|
||||
LOG(ERROR) << "Unknown TRT data type: " << int(dtype);
|
||||
ctx->SetStatus(tensorflow::errors::InvalidArgument(
|
||||
"Unknown output TRT data type! ", static_cast<int>(dtype)));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < static_cast<int>(output_nodes_.size()); i++) {
|
||||
// It is unfortunate that we have to reallocate the output buffer on every run.
|
||||
for (int i = 0; i < ctx->num_outputs(); i++) {
|
||||
// Create an output tensor
|
||||
binding_index = trt_engine_ptr_->getBindingIndex(output_nodes_[i].c_str());
|
||||
const string output_name = StrCat(kOutputPHName, i);
|
||||
const size_t binding_index =
|
||||
trt_engine_ptr->getBindingIndex(output_name.c_str());
|
||||
Tensor* output_tensor = nullptr;
|
||||
|
||||
TensorShape output_shape;
|
||||
if (binding_index != -1) {
|
||||
auto dims = trt_engine_ptr_->getBindingDimensions(binding_index);
|
||||
auto dims = trt_engine_ptr->getBindingDimensions(binding_index);
|
||||
std::vector<int> trt_shape(dims.nbDims + 1);
|
||||
trt_shape[0] = num_batch;
|
||||
for (int j = 0; j < dims.nbDims; j++) trt_shape[j + 1] = dims.d[j];
|
||||
OP_REQUIRES_OK(context,
|
||||
TensorShapeUtils::MakeShape(
|
||||
trt_shape.data(), trt_shape.size(), &output_shape));
|
||||
OP_REQUIRES_OK(
|
||||
ctx, TensorShapeUtils::MakeShape(trt_shape.data(), trt_shape.size(),
|
||||
&output_shape));
|
||||
} else {
|
||||
LOG(FATAL) << "output node not found, at " << output_nodes_[i];
|
||||
break;
|
||||
LOG(ERROR) << "output node not found, at " << output_name;
|
||||
ctx->SetStatus(tensorflow::errors::Internal("output ", output_name,
|
||||
" couldn't be found!"));
|
||||
return;
|
||||
}
|
||||
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(i, output_shape, &output_tensor));
|
||||
auto dtype = trt_engine_ptr_->getBindingDataType(binding_index);
|
||||
auto status = ctx->allocate_output(i, output_shape, &output_tensor);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Allocating output failed with " << status;
|
||||
ctx->SetStatus(status);
|
||||
return;
|
||||
}
|
||||
auto dtype = trt_engine_ptr->getBindingDataType(binding_index);
|
||||
switch (dtype) {
|
||||
case nvinfer1::DataType::kFLOAT:
|
||||
buffers[binding_index] =
|
||||
reinterpret_cast<void*>(output_tensor->flat<float>().data());
|
||||
break;
|
||||
case nvinfer1::DataType::kHALF:
|
||||
LOG(FATAL) << "half size is not supported yet!";
|
||||
break;
|
||||
LOG(ERROR) << "half size is not supported yet!";
|
||||
ctx->SetStatus(tensorflow::errors::InvalidArgument(
|
||||
"Half outputs are not supported!"));
|
||||
return;
|
||||
case nvinfer1::DataType::kINT8:
|
||||
LOG(FATAL) << "int8 is not supported yet!";
|
||||
break;
|
||||
LOG(ERROR) << "int8 is not supported yet!";
|
||||
ctx->SetStatus(tensorflow::errors::InvalidArgument(
|
||||
"INT8 outputs are not supported!"));
|
||||
return;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown data type: " << int(dtype);
|
||||
break;
|
||||
LOG(ERROR) << "Unknown TRT data type: " << static_cast<int>(dtype);
|
||||
ctx->SetStatus(tensorflow::errors::InvalidArgument(
|
||||
"Unsupported output data type! ", static_cast<int>(dtype)));
|
||||
return;
|
||||
}
|
||||
}
|
||||
// Copied from cuda_kernel_helper; it seems to be valid only in *.cu.cc files.
|
||||
const cudaStream_t* stream = CHECK_NOTNULL(
|
||||
reinterpret_cast<const cudaStream_t*>(context->op_device_context()
|
||||
reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
|
||||
->stream()
|
||||
->implementation()
|
||||
->CudaStreamMemberHack()));
|
||||
|
||||
// TODO(jie): trt enqueue does not return error
|
||||
auto ret = trt_execution_context_ptr_->enqueue(num_batch, &buffers[0],
|
||||
*stream, nullptr);
|
||||
VLOG(2) << "enqueue returns: " << ret;
|
||||
auto& trt_execution_context_ptr = engine_ctx_pair.second;
|
||||
auto ret = trt_execution_context_ptr->enqueue(num_batch, &buffers[0], *stream,
|
||||
nullptr);
|
||||
if (!ret) {
|
||||
LOG(ERROR) << "Failed to enqueue batch for TRT engine: " << name();
|
||||
ctx->SetStatus(tensorflow::errors::Internal(
|
||||
"Failed to enqueue batch for TRT engine: ", name()));
|
||||
}
|
||||
// sync should be done by TF.
|
||||
}
|
||||
|
||||
TRTEngineOp::~TRTEngineOp() {
|
||||
// Order matters!
|
||||
trt_execution_context_ptr_.reset();
|
||||
trt_engine_ptr_.reset();
|
||||
// We need to manually destroy the engine and execution context before
|
||||
// the allocator is destructed.
|
||||
for (auto& eng : engine_map_) {
|
||||
eng.second.first.reset();
|
||||
eng.second.second.reset();
|
||||
}
|
||||
allocator_.reset();
|
||||
}
|
||||
|
||||
nvinfer1::IGpuAllocator* TRTEngineOp::GetAllocator(OpKernelContext* ctx) {
|
||||
if (allocator_) return allocator_.get();
|
||||
auto device = ctx->device();
|
||||
auto alloc = device->GetAllocator(tensorflow::AllocatorAttributes());
|
||||
if (!alloc) {
|
||||
LOG(ERROR) << "Can't find device allocator for gpu device "
|
||||
<< device->name();
|
||||
ctx->SetStatus(tensorflow::errors::Internal(
|
||||
"Can't get device allocator for device ", device->name()));
|
||||
return nullptr;
|
||||
}
|
||||
allocator_.reset(new TRTDeviceAllocator(alloc));
|
||||
return allocator_.get();
|
||||
}
|
||||
|
||||
TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size,
|
||||
OpKernelContext* ctx) {
|
||||
static EngineCtxPair null_pair = {
|
||||
TrtUniquePtrType<nvinfer1::ICudaEngine>(nullptr),
|
||||
TrtUniquePtrType<nvinfer1::IExecutionContext>(nullptr)};
|
||||
// TODO(sami): This method needs to be rewritten to use the resource manager
// and an optional LRU mechanism.
|
||||
tensorflow::mutex_lock lock(engine_mutex_);
|
||||
|
||||
if (static_engine_) {
|
||||
if (engine_map_.size()) {
|
||||
if (engine_map_.begin()->first >= batch_size) {
|
||||
return engine_map_.begin()->second;
|
||||
}
|
||||
return null_pair;
|
||||
}
|
||||
TrtUniquePtrType<IRuntime> infer(nvinfer1::createInferRuntime(logger));
|
||||
#if NV_TENSORRT_MAJOR > 3
|
||||
auto allocator = GetAllocator(ctx);
|
||||
if (allocator == nullptr) {
|
||||
// GetAllocator already set the Status.
|
||||
return null_pair;
|
||||
}
|
||||
infer->setGpuAllocator(allocator);
|
||||
#endif
|
||||
TrtUniquePtrType<nvinfer1::ICudaEngine> static_engine(
|
||||
infer->deserializeCudaEngine(serialized_segment_.c_str(),
|
||||
serialized_segment_.size(), nullptr));
|
||||
auto raw_static_engine = static_engine.get();
|
||||
const auto max_batch_size = raw_static_engine->getMaxBatchSize();
|
||||
engine_map_[max_batch_size] = {
|
||||
std::move(static_engine),
|
||||
TrtUniquePtrType<nvinfer1::IExecutionContext>(
|
||||
raw_static_engine->createExecutionContext())};
|
||||
// Runtime is safe to delete after engine creation
|
||||
serialized_segment_.clear();
|
||||
if (max_batch_size < batch_size) return null_pair;
|
||||
return engine_map_.at(max_batch_size);
|
||||
} // static_engine_
|
||||
|
||||
// Handle the dynamic engine case.
|
||||
auto engine_it = engine_map_.find(batch_size);
|
||||
if (engine_it == engine_map_.end() &&
|
||||
engine_map_.size() < (size_t)max_cached_engines_) {
|
||||
nvinfer1::IGpuAllocator* allocator = nullptr;
|
||||
#if NV_TENSORRT_MAJOR > 3
|
||||
allocator = GetAllocator(ctx);
|
||||
if (allocator == nullptr) {
|
||||
// GetAllocator already set the Status.
|
||||
return null_pair;
|
||||
}
|
||||
#endif
|
||||
std::vector<tensorflow::PartialTensorShape> shapes;
|
||||
for (int i = 0; i < ctx->num_inputs(); ++i) {
|
||||
shapes.emplace_back(ctx->input(i).shape());
|
||||
}
|
||||
TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
|
||||
bool convert_successfully = false;
|
||||
VLOG(0) << name() << " Constructing a new engine with batch size "
|
||||
<< batch_size;
|
||||
// Up to this point, calibrator_ can never be empty, since otherwise it
|
||||
// means calibration_mode_ is true and this path won't get executed.
|
||||
auto status = convert::ConvertGraphDefToEngine(
|
||||
segment_graph_, precision_mode_, batch_size, workspace_size_, shapes,
|
||||
&logger, allocator, calibrator_.get(), &engine, &convert_successfully);
|
||||
if (!status.ok()) {
|
||||
if (convert_successfully) {
|
||||
// This means the engine failed to build even though the network was built
|
||||
// successfully, probably due to internal issues. In this case we don't
|
||||
// retry in the future.
|
||||
engine_map_[batch_size] = {nullptr, nullptr};
|
||||
}
|
||||
LOG(ERROR) << "Engine creation for batch size " << batch_size
|
||||
<< " failed " << status;
|
||||
ctx->SetStatus(tensorflow::errors::Internal("Engine creation failed!"));
|
||||
return null_pair;
|
||||
}
|
||||
VLOG(1) << "Conversion is done";
|
||||
TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
|
||||
engine->createExecutionContext());
|
||||
engine_map_[batch_size] = {std::move(engine), std::move(exec_context)};
|
||||
}
|
||||
return engine_map_.at(batch_size);
|
||||
}
|
||||
|
||||
tensorflow::Status TRTEngineOp::AllocateCalibrationResources(
|
||||
tensorflow::OpKernelContext* ctx, TRTCalibrationResource** cr) {
|
||||
auto cres = new TRTCalibrationResource();
|
||||
*cr = cres;
|
||||
// Get the allocator.
|
||||
auto alloc = ctx->device()->GetAllocator(tensorflow::AllocatorAttributes());
|
||||
if (!alloc) {
|
||||
LOG(WARNING) << "Can't get device allocator will not be able to "
|
||||
"allocate memory from TensorFlow memory pool";
|
||||
cres->allocator_.reset(new TRTCudaAllocator);
|
||||
} else {
|
||||
cres->allocator_.reset(new TRTDeviceAllocator(alloc));
|
||||
}
|
||||
// Get the input shapes.
|
||||
const int batch_size = ctx->input(0).dim_size(0);
|
||||
const int num_inputs = ctx->num_inputs();
|
||||
std::vector<tensorflow::PartialTensorShape> shapes;
|
||||
dev_tensors_.resize(num_inputs);
|
||||
VLOG(1) << " Constructing calibrator";
|
||||
for (int i = 0; i < num_inputs; i++) {
|
||||
// allocate workspace on device for inputs
|
||||
const tensorflow::Tensor& t = ctx->input(i);
|
||||
shapes.emplace_back(t.shape());
|
||||
Tensor* device_tensor;
|
||||
TF_RETURN_IF_ERROR(ctx->allocate_persistent(
|
||||
t.dtype(), t.shape(), &dev_tensors_.at(i), &device_tensor));
|
||||
CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes());
|
||||
void* device_address = GetTensorAddress(device_tensor);
|
||||
if (device_address == nullptr) {
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
"Unsupported data type encountered in input ", i);
|
||||
}
|
||||
device_buffers_.emplace(
|
||||
StrCat(kInputPHName, i),
|
||||
std::pair<void*, size_t>(device_address, device_tensor->TotalBytes()));
|
||||
}
|
||||
cres->calibrator_.reset(
|
||||
new TRTInt8Calibrator(device_buffers_, batch_size, name()));
|
||||
const string label(name());
|
||||
auto segment_graph = &segment_graph_;
|
||||
const int cuda_gpu_id = ctx->device()->tensorflow_gpu_device_info()->gpu_id;
|
||||
if (cuda_gpu_id < 0) {
|
||||
LOG(ERROR) << "Can't get gpu_device_info from context->device()";
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
"Context->device doesn't contain device info!");
|
||||
}
|
||||
const int64 workspace_size_bytes = workspace_size_;
|
||||
cres->thr_.reset(new std::thread([cres, label, segment_graph, shapes,
|
||||
cuda_gpu_id, workspace_size_bytes]() {
|
||||
VLOG(0) << "Starting calibration thread on device " << cuda_gpu_id
|
||||
<< ", Calibration Resource @ " << cres;
|
||||
auto err = cudaSetDevice(cuda_gpu_id);
|
||||
if (err != cudaSuccess) {
|
||||
// TODO(aaroey): should return error here.
|
||||
LOG(ERROR) << "Couldn't set cuda device to " << cuda_gpu_id
|
||||
<< " in calibration thread";
|
||||
}
|
||||
// ConvertGraphDefToEngine() will try to build the engine. This thread
|
||||
// will loop inside buildCudaEngine() consuming the calibration data
|
||||
// that is set by the TF op, and drive the builder until the calibrator
// returns false. The engine is discarded after the calibration table is
// generated.
//
// TODO(aaroey): maybe set the max batch size using the python
// calibration wrapper class.
|
||||
auto s = convert::ConvertGraphDefToEngine(
|
||||
*segment_graph, convert::INT8MODE, cres->calibrator_->getBatchSize(),
|
||||
workspace_size_bytes, shapes, &cres->logger_, cres->allocator_.get(),
|
||||
cres->calibrator_.get(), &cres->engine_,
|
||||
/*convert_successfully=*/nullptr);
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Calibration failed: " << s;
|
||||
cres->calibrator_->setDone(); // Ignore further pushes
|
||||
}
|
||||
VLOG(1) << "Calibration loop terminated " << label;
|
||||
}));
|
||||
VLOG(1) << "initialized calibrator resource";
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp);
|
||||
|
||||
} // namespace tensorrt
|
||||
|
@ -19,9 +19,14 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/contrib/tensorrt/convert/utils.h"
|
||||
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
|
||||
#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_TENSORRT
|
||||
@ -30,32 +35,95 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
class Logger;
|
||||
|
||||
class TRTInt8Calibrator;
|
||||
class TRTCalibrationResource;
|
||||
class AsyncHelper;
|
||||
// TODO(Sami): Remove this file?
|
||||
class TRTEngineOp : public OpKernel {
|
||||
|
||||
// This op constructs a TRT engine on the fly; if engine construction
// fails, it executes the equivalent subgraph as a TensorFlow function.
|
||||
class TRTEngineOp : public AsyncOpKernel {
|
||||
public:
|
||||
explicit TRTEngineOp(OpKernelConstruction* context);
|
||||
|
||||
void Compute(OpKernelContext* context) override;
|
||||
void ComputeAsync(OpKernelContext* context,
|
||||
AsyncOpKernel::DoneCallback done) override;
|
||||
~TRTEngineOp();
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
struct Destroyer {
|
||||
void operator()(T* d) { d->destroy(); }
|
||||
};
|
||||
// Execute calibration
|
||||
void ExecuteCalibration(OpKernelContext* ctx, AsyncHelper* helper);
|
||||
|
||||
// Construct a function handle for executing native funcdef graph
|
||||
Status ConstructFunctionHandle(OpKernelContext* ctx);
|
||||
|
||||
// Execute replaced native segment as function Op.
|
||||
void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper);
|
||||
|
||||
// Allocate necessary resources for calibration
|
||||
Status AllocateCalibrationResources(OpKernelContext* ctx,
|
||||
TRTCalibrationResource** cr);
|
||||
|
||||
template <typename T>
|
||||
using destroyed_ptr = std::unique_ptr<T, Destroyer<T>>;
|
||||
destroyed_ptr<nvinfer1::ICudaEngine> trt_engine_ptr_;
|
||||
// TODO(samikama): context should go to a resource manager!
|
||||
destroyed_ptr<nvinfer1::IExecutionContext> trt_execution_context_ptr_;
|
||||
typedef std::pair<TrtUniquePtrType<nvinfer1::ICudaEngine>,
|
||||
TrtUniquePtrType<nvinfer1::IExecutionContext>>
|
||||
EngineCtxPair;
|
||||
EngineCtxPair& GetEngine(int batch_size, OpKernelContext* ctx);
|
||||
|
||||
// Return engine batch closest to input batch.
|
||||
int GetEngineBatch(OpKernelContext* ctx);
|
||||
|
||||
nvinfer1::IGpuAllocator* GetAllocator(OpKernelContext* ctx);
|
||||
|
||||
// Map to keep engines and their execution contexts keyed by batch size.
|
||||
std::unordered_map<int, EngineCtxPair> engine_map_;
|
||||
std::vector<string> input_nodes_;
|
||||
std::vector<string> output_nodes_;
|
||||
std::shared_ptr<nvinfer1::IGpuAllocator> allocator_;
|
||||
string serialized_engine_;
|
||||
|
||||
// keep device allocator for TRT.
|
||||
std::unique_ptr<TRTDeviceAllocator> allocator_;
|
||||
|
||||
// serialized protobuf segment or trt engine depending on static_engine_ flag.
|
||||
string serialized_segment_;
|
||||
|
||||
// Name of the function for TF native execution of the segment.
|
||||
string funcdef_name_;
|
||||
|
||||
// GraphDef representation of the segment.
|
||||
GraphDef segment_graph_;
|
||||
|
||||
// Lookup table for temporary staging areas of input tensors for calibration.
|
||||
std::unordered_map<string, std::pair<void*, size_t>> device_buffers_;
|
||||
|
||||
// Temporary staging areas for calibration inputs.
|
||||
std::vector<PersistentTensor> dev_tensors_;
|
||||
|
||||
// Engine Precision mode.
|
||||
int precision_mode_;
|
||||
|
||||
// Whether engine is constructed during the conversion or needs to be
|
||||
// constructed from protobuf segment.
|
||||
bool static_engine_;
|
||||
|
||||
// Whether to calibrate INT8 engine.
|
||||
bool calibration_mode_;
|
||||
|
||||
// Whether non-batch ranks of the inputs are assumed to be fixed or not for
|
||||
// engine construction.
|
||||
bool fixed_input_size_;
|
||||
|
||||
// Batches of the cached engines
|
||||
std::vector<int> cached_engine_batches_;
|
||||
|
||||
// Maximum number of cached engines
|
||||
int max_cached_engines_;
|
||||
|
||||
int64 workspace_size_;
|
||||
mutex engine_mutex_;
|
||||
FunctionLibraryRuntime::Handle native_func_;
|
||||
|
||||
// The finalized calibrator for inference.
|
||||
std::unique_ptr<TRTInt8Calibrator> calibrator_;
|
||||
};
|
||||
|
||||
} // namespace tensorrt
|
||||
|
@ -28,11 +28,19 @@ extern Status TRTEngineOpShapeInference(InferenceContext* c);
|
||||
}
|
||||
|
||||
REGISTER_OP("TRTEngineOp")
|
||||
.Attr("serialized_engine: string")
|
||||
.Attr("input_nodes: list(string)")
|
||||
.Attr("output_nodes: list(string)")
|
||||
.Attr("InT: list({float32})")
|
||||
.Attr("OutT: list({float32})")
|
||||
.Attr("serialized_segment: string")
|
||||
.Attr("input_shapes: list(shape)")
|
||||
.Attr("output_shapes: list(shape)")
|
||||
.Attr("segment_funcdef_name: string")
|
||||
.Attr("InT: list({int8,float16,float32})")
|
||||
.Attr("OutT: list({int8,float16,float32})")
|
||||
.Attr("static_engine: bool = true")
|
||||
.Attr("fixed_input_size: bool = true")
|
||||
.Attr("cached_engine_batches: list(int) = []")
|
||||
.Attr("max_cached_engines_count: int = 1")
|
||||
.Attr("workspace_size_bytes: int")
|
||||
.Attr("precision_mode: {'FP32', 'FP16', 'INT8', 'INT8CALIB'}")
|
||||
.Attr("calibration_data: string = ''")
|
||||
.Input("in_tensor: InT")
|
||||
.Output("out_tensor: OutT")
|
||||
.SetShapeFn(shape_inference::TRTEngineOpShapeInference);
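
For reference, a converted GraphDef can be inspected from Python for nodes carrying the attributes registered above; a small hedged sketch (the `trt_graph` variable is hypothetical and would come from create_inference_graph below):

for node in trt_graph.node:
  if node.op == "TRTEngineOp":
    print(node.name,
          node.attr["precision_mode"].s,         # b"FP32" / b"FP16" / b"INT8"
          node.attr["static_engine"].b,
          node.attr["max_cached_engines_count"].i)
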
|
||||
|
@ -21,6 +21,8 @@ from __future__ import print_function
|
||||
# pylint: disable=unused-import,line-too-long
|
||||
import six as _six
|
||||
from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert
|
||||
from tensorflow.contrib.tensorrt.wrap_conversion import get_linked_tensorrt_version
|
||||
from tensorflow.contrib.tensorrt.wrap_conversion import get_loaded_tensorrt_version
|
||||
from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
@ -29,7 +31,9 @@ from tensorflow.python.framework import errors_impl as _impl
|
||||
from tensorflow.python.framework import meta_graph
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.grappler import tf_optimizer
|
||||
from tensorflow.python.platform import tf_logging
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
# pylint: enable=unused-import,line-too-long
|
||||
|
||||
|
||||
@ -40,7 +44,10 @@ def create_inference_graph(input_graph_def,
|
||||
max_batch_size=1,
|
||||
max_workspace_size_bytes=2 << 20,
|
||||
precision_mode="FP32",
|
||||
minimum_segment_size=3):
|
||||
minimum_segment_size=3,
|
||||
is_dynamic_op=False,
|
||||
maximum_cached_engines=1,
|
||||
cached_engine_batches=[]):
|
||||
"""Python wrapper for the TRT transformation.
|
||||
|
||||
Args:
|
||||
@ -51,6 +58,10 @@ def create_inference_graph(input_graph_def,
|
||||
precision_mode: one of 'FP32', 'FP16' and 'INT8'
|
||||
minimum_segment_size: the minimum number of nodes required for a subgraph to
|
||||
be replaced by TRTEngineOp.
|
||||
is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
|
||||
network and engine at run time.
|
||||
maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
|
||||
cached_engine_batches: batch sizes used to pre-create cached engines.
|
||||
|
||||
Returns:
|
||||
New GraphDef with TRTEngineOps placed in graph replacing subgraphs.
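
A minimal usage sketch of the new arguments, assuming `frozen_graph_def` is a frozen GraphDef with an output node named "output" (both are placeholders, not part of this change):

import tensorflow.contrib.tensorrt as trt

trt_graph = trt.create_inference_graph(
    input_graph_def=frozen_graph_def,
    outputs=["output"],
    max_batch_size=8,
    max_workspace_size_bytes=1 << 25,
    precision_mode="FP16",
    minimum_segment_size=3,
    is_dynamic_op=True,            # build TRT engines at run time
    maximum_cached_engines=2,      # keep at most two engines per op
    cached_engine_batches=[4, 8])  # pre-create engines for these batch sizes
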
|
||||
@ -65,6 +76,30 @@ def create_inference_graph(input_graph_def,
|
||||
"It should be one of {}").format(
|
||||
precision_mode, "{'FP32', 'FP16', 'INT8'}"))
|
||||
mode = supported_precision_modes[precision_mode.upper()]
|
||||
compiled_version = get_linked_tensorrt_version()
|
||||
loaded_version = get_loaded_tensorrt_version()
|
||||
version_mismatch = False
|
||||
if loaded_version[0] < compiled_version[0]:
|
||||
tf_logging.error(
|
||||
"TensorRT version mismatch. Tensorflow was compiled against " +
|
||||
"TensorRT %s but library loaded from environment is TensorRT %s" %
|
||||
(".".join([str(x) for x in compiled_version]),
|
||||
".".join([str(x) for x in loaded_version])) +
|
||||
". Please make sure that correct version of TensorRT " +
|
||||
"is available in the system and added to ldconfig or LD_LIBRARY_PATH"
|
||||
)
|
||||
raise RuntimeError("Incompatible TensorRT library version")
|
||||
for i in zip(loaded_version, compiled_version):
|
||||
if i[0] != i[1]:
|
||||
tf_logging.warn("TensorRT mismatch. Compiled against version " +
|
||||
"%s, but loaded %s. Things may not work" %
|
||||
(".".join([str(x) for x in compiled_version]),
|
||||
".".join([str(x) for x in loaded_version])))
|
||||
version_mismatch = True
|
||||
break
|
||||
if not version_mismatch:
|
||||
tf_logging.info("Running against TensorRT version %s" % ".".join(
|
||||
[str(x) for x in loaded_version]))
|
||||
|
||||
def py2bytes(inp):
|
||||
return inp
|
||||
@ -100,7 +135,9 @@ def create_inference_graph(input_graph_def,
|
||||
# pair or strings where first one is encoded status and the second
|
||||
# one is the transformed graphs protobuf string.
|
||||
out = trt_convert(input_graph_def_str, out_names, max_batch_size,
|
||||
max_workspace_size_bytes, mode, minimum_segment_size)
|
||||
max_workspace_size_bytes, mode, minimum_segment_size,
|
||||
is_dynamic_op, maximum_cached_engines,
|
||||
cached_engine_batches)
|
||||
status = to_string(out[0])
|
||||
output_graph_def_string = out[1]
|
||||
del input_graph_def_str # Save some memory
|
||||
@ -120,11 +157,12 @@ def create_inference_graph(input_graph_def,
|
||||
return output_graph_def
|
||||
|
||||
|
||||
def calib_graph_to_infer_graph(calibration_graph_def):
|
||||
def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
|
||||
"""Convert an existing calibration graph to inference graph.
|
||||
|
||||
Args:
|
||||
calibration_graph_def: the calibration GraphDef object with calibration data
|
||||
is_dynamic_op: whether to create dynamic engines from the calibration graph
|
||||
Returns:
|
||||
New GraphDef with TRTEngineOps placed in graph replacing calibration nodes.
|
||||
Raises:
|
||||
@ -141,9 +179,16 @@ def calib_graph_to_infer_graph(calibration_graph_def):
|
||||
to_string = py2string
|
||||
else:
|
||||
to_string = py3string
|
||||
|
||||
is_calib_graph = False
|
||||
for n in calibration_graph_def.node:
|
||||
if n.op == "TRTEngineOp":
|
||||
is_calib_graph = is_calib_graph or not n.attr["calibration_data"].s
|
||||
if not is_calib_graph:
|
||||
tf_logging.error(
|
||||
"Not a calib graph. Doesn't seem to contain any calibration nodes.")
|
||||
return None
|
||||
graph_str = calibration_graph_def.SerializeToString()
|
||||
out = calib_convert(graph_str)
|
||||
out = calib_convert(graph_str, is_dynamic_op)
|
||||
status = to_string(out[0])
|
||||
output_graph_def_string = out[1]
|
||||
del graph_str # Save some memory
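
For context, the INT8 path this function completes looks roughly like the following (a hedged sketch; the calibration feeding loop is elided and `frozen_graph_def` is a placeholder — the full flow appears in the test script further below):

import tensorflow.contrib.tensorrt as trt

# 1) Build a calibration graph.
calib_graph = trt.create_inference_graph(
    input_graph_def=frozen_graph_def,
    outputs=["output"],
    max_batch_size=8,
    max_workspace_size_bytes=1 << 25,
    precision_mode="INT8",
    minimum_segment_size=2)
# 2) Run the calibration graph on representative inputs so the calibrator
#    records activation ranges (session code omitted).
# 3) Replace the calibration nodes with real INT8 engines.
int8_graph = trt.calib_graph_to_infer_graph(calib_graph, is_dynamic_op=False)
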
|
||||
|
@ -50,7 +50,7 @@ TRTDeviceAllocator::TRTDeviceAllocator(tensorflow::Allocator* allocator)
|
||||
}
|
||||
|
||||
void TRTDeviceAllocator::free(void* memory) {
|
||||
VLOG(2) << "Deallocating " << memory;
|
||||
VLOG(2) << "Deallocating @ " << memory;
|
||||
allocator_->DeallocateRaw(memory);
|
||||
}
|
||||
|
||||
|
@ -16,7 +16,6 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_
|
||||
#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_
|
||||
|
||||
|
||||
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
|
||||
@ -52,7 +51,9 @@ class TRTDeviceAllocator : public nvinfer1::IGpuAllocator {
|
||||
// Allocator implementation wrapping TF device allocators.
|
||||
public:
|
||||
TRTDeviceAllocator(tensorflow::Allocator* allocator);
|
||||
virtual ~TRTDeviceAllocator() {}
|
||||
virtual ~TRTDeviceAllocator() {
|
||||
VLOG(1) << "Destroying allocator attached to " << allocator_->Name();
|
||||
}
|
||||
void* allocate(uint64_t size, uint64_t alignment, uint32_t flags) override;
|
||||
void free(void* memory) override;
|
||||
|
||||
|
@ -16,7 +16,6 @@ limitations under the License.
|
||||
#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -37,15 +36,22 @@ TRTInt8Calibrator::TRTInt8Calibrator(
|
||||
: batch_size_(batch_size),
|
||||
done_(false),
|
||||
dev_buffers_(dev_buffers),
|
||||
calib_running_(false),
|
||||
calib_running_(true),
|
||||
batch_is_set_(false),
|
||||
engine_name_(engine_name) {}
|
||||
|
||||
TRTInt8Calibrator::TRTInt8Calibrator(const string& calib_data)
|
||||
: batch_size_(0),
|
||||
done_(false),
|
||||
calib_running_(false),
|
||||
batch_is_set_(false),
|
||||
calibration_table_(calib_data) {}
|
||||
|
||||
bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data,
|
||||
const cudaStream_t stream) {
|
||||
tensorflow::mutex_lock lock(cond_mtx_);
|
||||
while ((calib_running_ || batch_is_set_) &&
|
||||
!done_) { // wait while calibration is running
|
||||
// wait while calibration is running.
|
||||
while ((calib_running_ || batch_is_set_) && !done_) {
|
||||
cond_.wait(lock);
|
||||
}
|
||||
if (done_) return false;
|
||||
@ -59,8 +65,6 @@ bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data,
|
||||
}
|
||||
const auto& d = devptr->second;
|
||||
|
||||
// TODO(aaroey): we should not use sync copy on default stream. Make sure
|
||||
// stream->ThenMemcpy() is used in future PRs.
|
||||
// TODO(sami,aaroey): Need to figure out a way to ensure synchronization
|
||||
// between stream, perhaps using a tensor?
|
||||
auto status = cudaMemcpyAsync(d.first, it.second, d.second,
|
||||
@ -84,13 +88,11 @@ bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
|
||||
tensorflow::mutex_lock lock(cond_mtx_);
|
||||
calib_running_ = false;
|
||||
cond_.notify_all();
|
||||
while ((!batch_is_set_ && !done_)) { // wait until new batch arrives
|
||||
// wait until new batch arrives
|
||||
while ((!batch_is_set_ && !done_)) {
|
||||
cond_.wait(lock);
|
||||
|
||||
}
|
||||
if (done_) {
|
||||
return false;
|
||||
}
|
||||
if (done_) return false;
|
||||
|
||||
for (int i = 0; i < num_bindings; i++) {
|
||||
auto it = dev_buffers_.find(names[i]);
|
||||
@ -107,7 +109,9 @@ bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
|
||||
}
|
||||
|
||||
const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) {
|
||||
return nullptr;
|
||||
if (calibration_table_.empty()) return nullptr;
|
||||
length = calibration_table_.size();
|
||||
return calibration_table_.data();
|
||||
}
|
||||
|
||||
void TRTInt8Calibrator::setDone() {
|
||||
@ -117,7 +121,11 @@ void TRTInt8Calibrator::setDone() {
|
||||
}
|
||||
|
||||
void TRTInt8Calibrator::writeCalibrationCache(const void* ptr,
|
||||
std::size_t length) {}
|
||||
std::size_t length) {
|
||||
calibration_table_ = string(static_cast<const char*>(ptr), length);
|
||||
VLOG(1) << "Got calibration data for " << engine_name_ << " @" << ptr
|
||||
<< " length=" << length;
|
||||
}
|
||||
TRTInt8Calibrator::~TRTInt8Calibrator() {
|
||||
VLOG(1) << "Destroying calibrator for " << engine_name_;
|
||||
}
|
||||
|
@ -39,29 +39,48 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
|
||||
TRTInt8Calibrator(
|
||||
const std::unordered_map<string, std::pair<void*, size_t>>& dev_buffers,
|
||||
int batch_size, string engine_name);
|
||||
|
||||
TRTInt8Calibrator(const string& calibration_data);
|
||||
|
||||
~TRTInt8Calibrator();
|
||||
|
||||
int getBatchSize() const override;
|
||||
|
||||
bool getBatch(void* bindings[], const char* names[],
|
||||
int num_bindings) override;
|
||||
|
||||
bool setBatch(const std::unordered_map<string, void*>& data,
|
||||
const cudaStream_t stream);
|
||||
|
||||
void setDone();
|
||||
|
||||
// If this returns a non-null pointer, calibration is skipped.
|
||||
const void* readCalibrationCache(std::size_t& length) override;
|
||||
|
||||
void writeCalibrationCache(const void* ptr, std::size_t length) override;
|
||||
~TRTInt8Calibrator();
|
||||
|
||||
const string& getCalibrationTableAsString() { return calibration_table_; }
|
||||
|
||||
private:
|
||||
const int batch_size_;
|
||||
tensorflow::mutex cond_mtx_; // mutex for condition_variable
|
||||
tensorflow::condition_variable cond_; // condition variable to implement
|
||||
// producer-consumer queue for
|
||||
// calibration
|
||||
|
||||
// mutex for condition_variable
|
||||
tensorflow::mutex cond_mtx_;
|
||||
|
||||
// condition variable to implement producer-consumer queue for calibration
|
||||
tensorflow::condition_variable cond_;
|
||||
|
||||
// Is calibration finished?
|
||||
bool done_;
|
||||
const std::unordered_map<string, std::pair<void*, size_t>>
|
||||
dev_buffers_; // map to keep tensorrt input buffers and sizes keyed with
|
||||
// buffer names
|
||||
|
||||
// Map to keep tensorrt input buffers and sizes keyed with buffer names
|
||||
const std::unordered_map<string, std::pair<void*, size_t>> dev_buffers_;
|
||||
|
||||
bool calib_running_;
|
||||
bool batch_is_set_;
|
||||
|
||||
string engine_name_;
|
||||
string calibration_table_;
|
||||
};
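
The setBatch/getBatch pair above is a one-slot producer-consumer handshake guarded by a condition variable: the TF op blocks until the previous batch has been consumed, and the builder thread blocks until a new batch arrives or setDone() is called. A minimal Python sketch of the same idea (illustrative only, not the TensorFlow implementation):

import threading


class BatchExchange(object):
  """One-slot producer/consumer queue, mirroring the calibrator handshake."""

  def __init__(self):
    self._cond = threading.Condition()
    self._batch = None
    self._done = False

  def set_batch(self, batch):  # producer: the TF op feeding inputs
    with self._cond:
      while self._batch is not None and not self._done:
        self._cond.wait()
      if self._done:
        return False
      self._batch = batch
      self._cond.notify_all()
      return True

  def get_batch(self):  # consumer: the TRT builder's calibration loop
    with self._cond:
      while self._batch is None and not self._done:
        self._cond.wait()
      if self._done:
        return None
      batch, self._batch = self._batch, None
      self._cond.notify_all()
      return batch

  def set_done(self):  # stop both sides once calibration is finished
    with self._cond:
      self._done = True
      self._cond.notify_all()
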
|
||||
|
||||
} // namespace tensorrt
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/contrib/tensorrt/convert/utils.h"
|
||||
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
|
||||
#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h"
|
||||
#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h"
|
||||
@ -34,50 +35,48 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
|
||||
class TRTCalibrationResource : public tensorflow::ResourceBase {
|
||||
public:
|
||||
TRTCalibrationResource()
|
||||
: calibrator_(nullptr),
|
||||
builder_(nullptr),
|
||||
network_(nullptr),
|
||||
engine_(nullptr),
|
||||
logger_(nullptr),
|
||||
thr_(nullptr) {}
|
||||
|
||||
~TRTCalibrationResource() {
|
||||
VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString();
|
||||
builder_.reset();
|
||||
engine_.reset();
|
||||
// We need to manually destroy the builder and engine before the allocator
|
||||
// is destroyed.
|
||||
allocator_.reset();
|
||||
}
|
||||
|
||||
string DebugString() override {
|
||||
std::stringstream oss;
|
||||
oss << " Calibrator = " << std::hex << calibrator_ << std::dec << std::endl
|
||||
<< " Builder = " << std::hex << builder_ << std::dec << std::endl
|
||||
<< " Network = " << std::hex << network_ << std::dec << std::endl
|
||||
<< " Engine = " << std::hex << engine_ << std::dec << std::endl
|
||||
<< " Logger = " << std::hex << logger_ << std::dec << std::endl
|
||||
<< " Allocator = " << std::hex << allocator_.get() << std::dec
|
||||
<< std::endl
|
||||
<< " Thread = " << std::hex << thr_ << std::dec << std::endl;
|
||||
using std::dec;
|
||||
using std::endl;
|
||||
using std::hex;
|
||||
oss << " Calibrator = " << hex << calibrator_.get() << dec << endl
|
||||
<< " Builder = " << hex << builder_.get() << dec << endl
|
||||
<< " Engine = " << hex << engine_.get() << dec << endl
|
||||
<< " Logger = " << hex << &logger_ << dec << endl
|
||||
<< " Allocator = " << hex << allocator_.get() << dec << endl
|
||||
<< " Thread = " << hex << thr_.get() << dec << endl;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
TRTInt8Calibrator* calibrator_;
|
||||
nvinfer1::IBuilder* builder_;
|
||||
nvinfer1::INetworkDefinition* network_;
|
||||
nvinfer1::ICudaEngine* engine_;
|
||||
std::shared_ptr<nvinfer1::IGpuAllocator> allocator_;
|
||||
tensorflow::tensorrt::Logger* logger_;
|
||||
std::unique_ptr<TRTInt8Calibrator> calibrator_;
|
||||
TrtUniquePtrType<nvinfer1::IBuilder> builder_;
|
||||
TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
|
||||
std::unique_ptr<nvinfer1::IGpuAllocator> allocator_;
|
||||
tensorflow::tensorrt::Logger logger_;
|
||||
// TODO(sami): Use threadpool threads!
|
||||
std::thread* thr_;
|
||||
std::unique_ptr<std::thread> thr_;
|
||||
};
|
||||
|
||||
class TRTWeightStore : public tensorflow::ResourceBase {
|
||||
class TRTWeightStore {
|
||||
public:
|
||||
TRTWeightStore() {}
|
||||
|
||||
virtual ~TRTWeightStore() { VLOG(1) << "Destroying store" << DebugString(); }
|
||||
|
||||
string DebugString() override {
|
||||
string DebugString() {
|
||||
std::stringstream oss;
|
||||
size_t len_bytes = 0;
|
||||
for (const auto& v : store_) {
|
||||
|
@ -29,8 +29,9 @@ namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
namespace segment {
|
||||
|
||||
// vector of segments, each entry contains a device name and a set of nodes in
|
||||
// segment
|
||||
// Vector of segments, each entry contains a set of node names and a device name
|
||||
// in the segment.
|
||||
// TODO(aaroey): use node pointer instead of node name.
|
||||
using SegmentNodesVector = std::vector<std::pair<std::set<string>, string>>;
|
||||
|
||||
struct SegmentOptions {
|
||||
@ -48,6 +49,8 @@ struct SegmentOptions {
|
||||
// in the vector describes a subgraph by giving a set of the names of
|
||||
// all the NodeDefs in that subgraph.
|
||||
// @return the status.
|
||||
//
|
||||
// TODO(aaroey): remove this method.
|
||||
tensorflow::Status SegmentGraph(
|
||||
const tensorflow::GraphDef& gdef,
|
||||
const std::function<bool(const tensorflow::Node*)>& candidate_fn,
|
||||
|
@ -29,61 +29,35 @@ namespace tensorflow {
|
||||
namespace shape_inference {
|
||||
|
||||
tensorflow::Status TRTEngineOpShapeInference(InferenceContext* context) {
|
||||
tensorflow::tensorrt::Logger logger;
|
||||
string serialized_engine;
|
||||
TF_RETURN_IF_ERROR(context->GetAttr("serialized_engine", &serialized_engine));
|
||||
nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger);
|
||||
nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine(
|
||||
serialized_engine.c_str(), serialized_engine.size(),
|
||||
tensorrt::PluginFactoryTensorRT::GetInstance());
|
||||
|
||||
int num_batch = -1;
|
||||
std::vector<::tensorflow::DataType> input_type;
|
||||
TF_RETURN_IF_ERROR(context->GetAttr("InT", &input_type));
|
||||
for (size_t i = 0; i < context->num_inputs(); i++) {
|
||||
// Check if input shape is legit
|
||||
auto input_shape = context->input(i);
|
||||
for (int j = 0; j < context->Rank(input_shape); j++) {
|
||||
auto dim_handler = context->Dim(input_shape, j);
|
||||
if (j == 0) {
|
||||
if (i == 0) {
|
||||
num_batch = context->Value(dim_handler);
|
||||
} else if (num_batch != context->Value(dim_handler)) {
|
||||
// TODO(jie): TensorRT engine requires consistent batch between inputs
|
||||
// tensors. Segmenter should be aware of this.
|
||||
LOG(FATAL) << "TensorRT engine requires consistent batch size";
|
||||
}
|
||||
}
|
||||
}
|
||||
std::vector<tensorflow::TensorShape> shapes;
|
||||
for (int i = 0; i < context->num_outputs(); ++i) {
|
||||
context->set_output(i, context->UnknownShape());
|
||||
}
|
||||
|
||||
// Arrange input here
|
||||
std::vector<string> input_nodes;
|
||||
TF_RETURN_IF_ERROR(context->GetAttr("input_nodes", &input_nodes));
|
||||
|
||||
// Arrange output here
|
||||
std::vector<string> output_nodes;
|
||||
TF_RETURN_IF_ERROR(context->GetAttr("output_nodes", &output_nodes));
|
||||
for (size_t i = 0; i < output_nodes.size(); i++) {
|
||||
int binding_index = trt_engine->getBindingIndex(output_nodes[i].c_str());
|
||||
ShapeHandle output_shape;
|
||||
std::vector<DimensionHandle> dim_vec;
|
||||
dim_vec.emplace_back(context->MakeDim(num_batch));
|
||||
if (binding_index != -1) {
|
||||
auto dims = trt_engine->getBindingDimensions(binding_index);
|
||||
for (int j = 0; j < dims.nbDims; j++) {
|
||||
dim_vec.emplace_back(context->MakeDim(dims.d[j]));
|
||||
}
|
||||
} else {
|
||||
LOG(FATAL) << "TensorRT engine cannot find binding: " << output_nodes[i];
|
||||
}
|
||||
output_shape = context->MakeShape(dim_vec);
|
||||
context->set_output(i, output_shape);
|
||||
auto status = context->GetAttr("input_shapes", &shapes);
|
||||
// It is ok not to have shapes.
|
||||
if (!status.ok()) return Status::OK();
|
||||
if ((int)shapes.size() != context->num_inputs()) return Status::OK();
|
||||
bool different_input = false;
|
||||
for (int i = 0; i < context->num_inputs(); ++i) {
|
||||
if (shapes.at(i) != context->input_tensor(i)->shape())
|
||||
different_input = true;
|
||||
}
|
||||
if (different_input) return Status::OK();
|
||||
shapes.resize(0);
|
||||
status = context->GetAttr("output_shapes", &shapes);
|
||||
if (!status.ok()) return Status::OK();
|
||||
if ((int)shapes.size() != context->num_outputs()) return Status::OK();
|
||||
std::vector<ShapeHandle> shape_handles(shapes.size());
|
||||
for (size_t i = 0; i < shapes.size(); ++i) {
|
||||
status =
|
||||
context->MakeShapeFromTensorShape(shapes.at(i), &shape_handles.at(i));
|
||||
if (!status.ok()) return Status::OK();
|
||||
}
|
||||
for (int i = 0; i < context->num_outputs(); ++i) {
|
||||
context->set_output(i, shape_handles.at(i));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace shape_inference
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
import six as _six
|
||||
|
||||
# normally we should do import tensorflow as tf and then
|
||||
# tf.placeholder, tf.constant, tf.nn.conv2d etc but
|
||||
@ -35,10 +36,75 @@ from tensorflow.python.framework import dtypes as dtypes
|
||||
from tensorflow.python.framework import importer as importer
|
||||
from tensorflow.python.framework import ops as ops
|
||||
from tensorflow.python.ops import array_ops as aops
|
||||
from tensorflow.python.ops import math_ops as mops
|
||||
from tensorflow.python.ops import nn as nn
|
||||
from tensorflow.python.ops import nn_ops as nn_ops
|
||||
|
||||
|
||||
def py2bytes(inp):
|
||||
return inp
|
||||
|
||||
|
||||
def py3bytes(inp):
|
||||
return inp.encode("utf-8", errors="surrogateescape")
|
||||
|
||||
|
||||
def py2string(inp):
|
||||
return inp
|
||||
|
||||
|
||||
def py3string(inp):
|
||||
return inp.decode("utf-8")
|
||||
|
||||
|
||||
if _six.PY2:
|
||||
to_bytes = py2bytes
|
||||
to_string = py2string
|
||||
else:
|
||||
to_bytes = py3bytes
|
||||
to_string = py3string
|
||||
|
||||
|
||||
def get_multi_engine_graph_def(mode="FP32"):
|
||||
"""Create a simple graph and return its graph_def."""
|
||||
dtype = dtypes.float32
|
||||
if mode.upper() == "FP16":
|
||||
dtype = dtypes.float16
|
||||
else:
|
||||
pass
|
||||
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
x = aops.placeholder(shape=[None, 3, 7, 5], name="input", dtype=dtype)
|
||||
with g.name_scope("Global_scope"):
|
||||
with g.name_scope("first_scope"):
|
||||
e = cop.constant(
|
||||
np.random.randn(3, 2, 3, 4), name="weights", dtype=dtype)
|
||||
conv = nn.conv2d(
|
||||
input=x,
|
||||
filter=e,
|
||||
data_format="NCHW",
|
||||
strides=[1, 1, 1, 1],
|
||||
padding="VALID",
|
||||
name="conv")
|
||||
b = cop.constant(np.random.randn(1, 4, 1, 1), name="bias1", dtype=dtype)
|
||||
t = conv * b
|
||||
|
||||
b = cop.constant(np.random.randn(1, 4, 1, 1), name="bias2", dtype=dtype)
|
||||
q = conv / b
|
||||
edge = mops.sin(q)
|
||||
edge1 = mops.cos(conv)
|
||||
with g.name_scope("test_scope"):
|
||||
de = edge + edge1
|
||||
t -= edge1
|
||||
q *= edge
|
||||
t += q
|
||||
t -= de
|
||||
k = aops.squeeze(t, name="output")
|
||||
print(k.dtype)
|
||||
return g.as_graph_def()
|
||||
|
||||
|
||||
def get_simple_graph_def():
|
||||
"""Create a simple graph and return its graph_def."""
|
||||
g = ops.Graph()
|
||||
@ -65,7 +131,9 @@ def get_simple_graph_def():
|
||||
def execute_graph(gdef, dumm_inp):
|
||||
"""Run given graphdef once."""
|
||||
print("executing")
|
||||
gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
|
||||
gpu_options = None
|
||||
if trt.trt_convert.get_linked_tensorrt_version()[0] == 3:
|
||||
gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
|
||||
sessconfig = cpb2.ConfigProto(gpu_options=gpu_options)
|
||||
ops.reset_default_graph()
|
||||
g = ops.Graph()
|
||||
@ -83,7 +151,9 @@ def execute_graph(gdef, dumm_inp):
|
||||
# for calibration. For this test script it is random data.
|
||||
def execute_calibration(gdef, dumm_inp):
|
||||
"""Run given calibration graph multiple times."""
|
||||
gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
|
||||
gpu_options = None
|
||||
if trt.trt_convert.get_linked_tensorrt_version()[0] == 3:
|
||||
gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
|
||||
ops.reset_default_graph()
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
@ -100,12 +170,17 @@ def execute_calibration(gdef, dumm_inp):
|
||||
return val
|
||||
|
||||
|
||||
def user(run_graph=execute_graph, run_calibration=execute_calibration):
|
||||
def user(multi_engine,
|
||||
run_graph=execute_graph,
|
||||
run_calibration=execute_calibration):
|
||||
"""Example function that converts a graph to TFTRT graph."""
|
||||
|
||||
inp_dims = (100, 24, 24, 2)
|
||||
if multi_engine:
|
||||
inp_dims = (2, 3, 7, 5)
|
||||
orig_graph = get_multi_engine_graph_def()
|
||||
else:
|
||||
inp_dims = (100, 24, 24, 2)
|
||||
orig_graph = get_simple_graph_def() # use a frozen graph for inference
|
||||
dummy_input = np.random.random_sample(inp_dims)
|
||||
orig_graph = get_simple_graph_def() # use a frozen graph for inference
|
||||
# Get optimized graph
|
||||
trt_graph = trt.create_inference_graph(
|
||||
input_graph_def=orig_graph,
|
||||
@ -113,8 +188,10 @@ def user(run_graph=execute_graph, run_calibration=execute_calibration):
|
||||
max_batch_size=inp_dims[0],
|
||||
max_workspace_size_bytes=1 << 25,
|
||||
precision_mode="FP32", # TRT Engine precision "FP32","FP16" or "INT8"
|
||||
minimum_segment_size=2 # minimum number of nodes in an engine
|
||||
)
|
||||
minimum_segment_size=2, # minimum number of nodes in an engine
|
||||
is_dynamic_op=False,
|
||||
maximum_cached_engines=1,
|
||||
cached_engine_batches=[])
|
||||
o1 = run_graph(orig_graph, dummy_input)
|
||||
o2 = run_graph(trt_graph, dummy_input)
|
||||
o3 = run_graph(trt_graph, dummy_input)
|
||||
@ -126,40 +203,51 @@ def user(run_graph=execute_graph, run_calibration=execute_calibration):
|
||||
max_batch_size=inp_dims[0],
|
||||
max_workspace_size_bytes=1 << 25,
|
||||
precision_mode="FP16", # TRT Engine precision "FP32","FP16" or "INT8"
|
||||
minimum_segment_size=2 # minimum number of nodes in an engine
|
||||
)
|
||||
minimum_segment_size=2, # minimum number of nodes in an engine
|
||||
is_dynamic_op=False,
|
||||
maximum_cached_engines=1,
|
||||
cached_engine_batches=[])
|
||||
int8_calib_gdef = trt.create_inference_graph(
|
||||
input_graph_def=orig_graph,
|
||||
outputs=["output"],
|
||||
max_batch_size=inp_dims[0],
|
||||
max_workspace_size_bytes=1 << 25,
|
||||
precision_mode="INT8", # TRT Engine precision "FP32","FP16" or "INT8"
|
||||
minimum_segment_size=2 # minimum number of nodes in an engine
|
||||
)
|
||||
minimum_segment_size=2, # minimum number of nodes in an engine
|
||||
is_dynamic_op=False,
|
||||
maximum_cached_engines=1,
|
||||
cached_engine_batches=[])
|
||||
o4 = run_graph(fp16_graph, dummy_input)
|
||||
_ = run_calibration(int8_calib_gdef, dummy_input)
|
||||
int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef)
|
||||
o5 = run_graph(int8_graph, dummy_input)
|
||||
assert np.allclose(o1, o4)
|
||||
assert np.allclose(o1, o5)
|
||||
print("Is FP32 == FP16? %s (False is possible)" % np.allclose(o1, o4))
|
||||
print("Is FP32 == INT8? %s (False is possible)" % np.allclose(o1, o5))
|
||||
print("Pass")
|
||||
|
||||
|
||||
def auto():
|
||||
def auto(multi_engine):
|
||||
"""Run the conversion as an optimization pass."""
|
||||
inp_dims = (100, 24, 24, 2)
|
||||
if multi_engine:
|
||||
inp_dims = (2, 3, 7, 5)
|
||||
orig_graph = get_multi_engine_graph_def()
|
||||
else:
|
||||
inp_dims = (100, 24, 24, 2)
|
||||
orig_graph = get_simple_graph_def() # use a frozen graph for inference
|
||||
dummy_input = np.random.random_sample(inp_dims)
|
||||
orig_graph = get_simple_graph_def()
|
||||
opt_config = rwpb2.RewriterConfig()
|
||||
opt_config.meta_optimizer_iterations = opt_config.ONE
|
||||
opt_config.optimizers.extend(["constfold", "layout"])
|
||||
custom_op = opt_config.custom_optimizers.add()
|
||||
custom_op.name = "TensorRTOptimizer"
|
||||
custom_op.parameter_map["minimum_segment_size"].i = 3
|
||||
custom_op.parameter_map["precision_mode"].s = "FP32"
|
||||
custom_op.parameter_map["precision_mode"].s = to_bytes("FP32")
|
||||
custom_op.parameter_map["max_batch_size"].i = inp_dims[0]
|
||||
custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25
|
||||
print(custom_op)
|
||||
gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
|
||||
gpu_options = None
|
||||
if trt.trt_convert.get_linked_tensorrt_version()[0] == 3:
|
||||
gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
|
||||
graph_options = cpb2.GraphOptions(rewrite_options=opt_config)
|
||||
sessconfig = cpb2.ConfigProto(
|
||||
gpu_options=gpu_options, graph_options=graph_options)
|
||||
@ -168,7 +256,7 @@ def auto():
|
||||
ops.reset_default_graph()
|
||||
with g.as_default():
|
||||
inp, out = importer.import_graph_def(
|
||||
graph_def=orig_graph, return_elements=["input", "output"])
|
||||
graph_def=orig_graph, return_elements=["input", "output"], name="")
|
||||
inp = inp.outputs[0]
|
||||
out = out.outputs[0]
|
||||
with csess.Session(config=sessconfig, graph=g) as sess:
|
||||
@ -186,8 +274,14 @@ if "__main__" in __name__:
|
||||
action="store_true",
|
||||
help="Do TRT conversion automatically",
|
||||
default=False)
|
||||
P.add_argument(
|
||||
"--multi-engine",
|
||||
"-m",
|
||||
action="store_true",
|
||||
help="Use a graph that will result in 2 engines",
|
||||
default=False)
|
||||
flags, unparsed = P.parse_known_args()
|
||||
if flags.automatic:
|
||||
auto()
|
||||
auto(flags.multi_engine)
|
||||
else:
|
||||
user()
|
||||
user(flags.multi_engine)
|
||||
|
@ -48,12 +48,53 @@ PyObject* pair_helper(std::pair<string, string>* in) {
|
||||
}
|
||||
return tuple;
|
||||
}
|
||||
|
||||
struct version_struct{
|
||||
int vmajor;
|
||||
int vminor;
|
||||
int vpatch;
|
||||
};
|
||||
|
||||
PyObject* version_helper(version_struct* in) {
|
||||
PyObject *tuple(nullptr);
|
||||
tuple = Py_BuildValue("(iii)", in->vmajor, in->vminor, in->vpatch);
|
||||
if (!tuple) {
|
||||
if (!PyErr_Occurred()) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
"Tuple creation from version structure failed!");
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
return tuple;
|
||||
}
|
||||
/* Define converters for vector<int> */
|
||||
template<>
|
||||
bool _PyObjAs(PyObject *pyobj, int* dest) {
|
||||
*dest = PyLong_AsLong(pyobj);
|
||||
return true;
|
||||
}
|
||||
|
||||
template<>
|
||||
PyObject *_PyObjFrom(const int& src) {
|
||||
return PyLong_FromLong(src);
|
||||
}
|
||||
|
||||
%}
|
||||
|
||||
_LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong);
|
||||
|
||||
%typemap(out) std::pair<string, string> {
|
||||
PyObject *tuple = pair_helper(&$1);
|
||||
if (!tuple) SWIG_fail;
|
||||
$result = tuple;
|
||||
}
|
||||
|
||||
%typemap(out) version_struct {
|
||||
PyObject *tuple = version_helper(&$1);
|
||||
if (!tuple) SWIG_fail;
|
||||
$result = tuple;
|
||||
}
|
||||
|
||||
%{
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
@ -65,6 +106,8 @@ PyObject* pair_helper(std::pair<string, string>* in) {
|
||||
%unignore tensorflow;
|
||||
%unignore trt_convert;
|
||||
%unignore calib_convert;
|
||||
%unignore get_linked_tensorrt_version;
|
||||
%unignore get_loaded_tensorrt_version;
|
||||
|
||||
%{
|
||||
|
||||
@ -74,7 +117,10 @@ std::pair<string, string> trt_convert(
|
||||
size_t max_batch_size,
|
||||
size_t max_workspace_size_bytes,
|
||||
int precision_mode,
|
||||
int minimum_segment_size
|
||||
int minimum_segment_size,
|
||||
bool is_dyn_op,
|
||||
int max_cached_engines,
|
||||
std::vector<int> cached_engine_batches
|
||||
// Unfortunately we can't use TF_Status here since it
|
||||
// is in c/c_api and brings in a lot of other libraries
|
||||
// which in turn declare ops. These ops are included
|
||||
@ -102,11 +148,12 @@ std::pair<string, string> trt_convert(
|
||||
out_status = "InvalidArgument;Size of the output_names vector is 0";
|
||||
return std::pair<string, string>{out_status, ""};
|
||||
}
|
||||
tensorflow::GraphDef outGraph;
|
||||
tensorflow::GraphDef out_graph;
|
||||
tensorflow::Status conversion_status =
|
||||
tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT(
|
||||
graph_def, output_names, max_batch_size, max_workspace_size_bytes,
|
||||
&outGraph, precision_mode, minimum_segment_size);
|
||||
&out_graph, precision_mode, minimum_segment_size,
|
||||
is_dyn_op, max_cached_engines, cached_engine_batches);
|
||||
if (!conversion_status.ok()) {
|
||||
auto retCode = (int)conversion_status.code();
|
||||
char buff[2000];
|
||||
@ -116,7 +163,7 @@ std::pair<string, string> trt_convert(
|
||||
return std::pair<string, string>{out_status, ""};
|
||||
}
|
||||
string result;
|
||||
if (!outGraph.SerializeToString(&result)) {
|
||||
if (!out_graph.SerializeToString(&result)) {
|
||||
out_status = "InvalidArgument;Couldn't serialize output as a GraphDef";
|
||||
return std::pair<string, string>{out_status, ""};
|
||||
}
|
||||
@ -128,7 +175,8 @@ std::pair<string, string> trt_convert(
|
||||
#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||
}
|
||||
|
||||
std::pair<string, string> calib_convert(string graph_def_string // const tensorflow::GraphDef&
|
||||
std::pair<string, string> calib_convert(
|
||||
string graph_def_string, bool is_dyn_op
|
||||
// unfortunately we can't use TF_Status here since it
|
||||
// is in c/c_api and brings in a lot of other libraries
|
||||
// which in turn declare ops. These ops are included
|
||||
@ -147,11 +195,11 @@ std::pair<string, string> calib_convert(string graph_def_string // const tenso
|
||||
out_status = "InvalidArgument;Couldn't interpret input as a GraphDef";
|
||||
return std::pair<string, string>{out_status, ""};
|
||||
}
|
||||
|
||||
tensorflow::GraphDef outGraph;
|
||||
graph_def_string.resize(0);
|
||||
tensorflow::GraphDef out_graph;
|
||||
tensorflow::Status conversion_status =
|
||||
tensorflow::tensorrt::convert::ConvertCalibGraphToInferGraph(graph_def,
|
||||
&outGraph);
|
||||
tensorflow::tensorrt::convert::ConvertCalibGraphToInferGraph(
|
||||
graph_def, &out_graph, is_dyn_op);
|
||||
if (!conversion_status.ok()) {
|
||||
auto retCode = (int)conversion_status.code();
|
||||
char buff[2000];
|
||||
@ -161,7 +209,7 @@ std::pair<string, string> calib_convert(string graph_def_string // const tenso
|
||||
return std::pair<string, string>{out_status, ""};
|
||||
}
|
||||
string result;
|
||||
if (!outGraph.SerializeToString(&result)) {
|
||||
if (!out_graph.SerializeToString(&result)) {
|
||||
out_status = "InvalidArgument;Couldn't serialize output as a GraphDef";
|
||||
return std::pair<string, string>{out_status, ""};
|
||||
}
|
||||
@ -172,15 +220,39 @@ std::pair<string, string> calib_convert(string graph_def_string // const tenso
|
||||
return std::pair<string, string>{"9;TensorRT is not enabled!", ""};
|
||||
#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||
}
|
||||
|
||||
version_struct get_linked_tensorrt_version(){
|
||||
// Return the version at the link time.
|
||||
const auto &lv = tensorflow::tensorrt::convert::GetLinkedTensorRTVersion();
|
||||
version_struct s;
|
||||
s.vmajor = lv[0];
|
||||
s.vminor = lv[1];
|
||||
s.vpatch = lv[2];
|
||||
return s;
|
||||
}
|
||||
version_struct get_loaded_tensorrt_version(){
|
||||
// Return the version from the loaded library.
|
||||
const auto &lv = tensorflow::tensorrt::convert::GetLoadedTensorRTVersion();
|
||||
version_struct s;
|
||||
s.vmajor = lv[0];
|
||||
s.vminor = lv[1];
|
||||
s.vpatch = lv[2];
|
||||
return s;
|
||||
}
|
||||
|
||||
%}
|
||||
|
||||
std::pair<string, string> calib_convert(string graph_def_string);
|
||||
std::pair<string, string> calib_convert(string graph_def_string, bool is_dyn_op);
|
||||
|
||||
std::pair<string, string> trt_convert(string graph_def_string,
|
||||
std::vector<string> output_names,
|
||||
size_t max_batch_size,
|
||||
size_t max_workspace_size_bytes,
|
||||
int precision_mode, int minimum_segment_size);
|
||||
|
||||
int precision_mode, int minimum_segment_size,
|
||||
bool is_dyn_op,
|
||||
int max_cached_engines,
|
||||
std::vector<int> cached_engine_batches);
|
||||
version_struct get_linked_tensorrt_version();
|
||||
version_struct get_loaded_tensorrt_version();
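
On the Python side these surface as functions returning (major, minor, patch) tuples; a small usage sketch (the printed values are examples only and depend on the local TensorRT installation):

from tensorflow.contrib.tensorrt.wrap_conversion import (
    get_linked_tensorrt_version, get_loaded_tensorrt_version)

print("linked against TensorRT", get_linked_tensorrt_version())  # e.g. (3, 0, 4)
print("loaded TensorRT library", get_loaded_tensorrt_version())
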
|
||||
|
||||
%unignoreall
|
||||
|
@ -49,11 +49,11 @@ tf_cc_binary(
|
||||
":tpu_profiler_analysis_proto_cc",
|
||||
":tpu_profiler_proto_cc",
|
||||
":version",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
||||
"//tensorflow/core/platform/cloud:gcs_file_system",
|
||||
"@grpc//:grpc++",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -53,12 +53,12 @@ cc_library(
|
||||
":grpc_verbs_service_impl",
|
||||
":rdma_mgr",
|
||||
":verbs_service_proto_cc",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/distributed_runtime:session_mgr",
|
||||
"//tensorflow/core/distributed_runtime/rpc:async_service_interface",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_call",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
||||
"@grpc//:grpc++",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -69,7 +69,7 @@ cc_library(
|
||||
hdrs = ["grpc_verbs_service_impl.h"],
|
||||
deps = [
|
||||
":verbs_service_proto_cc",
|
||||
"@grpc//:grpc++",
|
||||
"//tensorflow:grpc++",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -4,6 +4,7 @@
|
||||
# The following targets can be used to access ApiDefs:
|
||||
# :base_api_def
|
||||
# :python_api_def
|
||||
# :java_api_def
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:private"],
|
||||
@ -29,6 +30,12 @@ filegroup(
|
||||
visibility = ["//tensorflow:internal"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "java_api_def",
|
||||
srcs = glob(["java_api/*"]),
|
||||
visibility = ["//tensorflow:internal"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "excluded_ops_lib",
|
||||
srcs = ["excluded_ops.cc"],
|
||||
|
@@ -68,7 +68,7 @@ END
    name: "area_range"
    description: <<END
The cropped area of the image must contain a fraction of the
supplied image within in this range.
supplied image within this range.
END
  }
  attr {

@@ -68,7 +68,7 @@ END
    name: "area_range"
    description: <<END
The cropped area of the image must contain a fraction of the
supplied image within in this range.
supplied image within this range.
END
  }
  attr {

@@ -11,7 +11,7 @@ END
    name: "stride"
    description: <<END
A scalar representing the steps moving the sliding window
forward in one iteration. It must be in `[1, window_size)`.
forward in one iteration. It must be positive.
END
  }
  summary: "Creates a dataset that passes a sliding window over `input_dataset`."
4  tensorflow/core/api_def/java_api/api_def_Assert.pbtxt  Normal file
@@ -0,0 +1,4 @@
op {
  graph_op_name: "Assert" #TODO(karllessard) escape that reserved name
  visibility: HIDDEN
}
4  tensorflow/core/api_def/java_api/api_def_Const.pbtxt  Normal file
@@ -0,0 +1,4 @@
op {
  graph_op_name: "Const" #TODO(karllessard) escape that reserved name
  visibility: HIDDEN
}
4  tensorflow/core/api_def/java_api/api_def_Switch.pbtxt  Normal file
@@ -0,0 +1,4 @@
op {
  graph_op_name: "Switch" #TODO(karllessard) escape that reserved name
  visibility: HIDDEN
}
@@ -106,24 +106,24 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) {
        EXPECT_EQ(1, shape.dim(1).size());
        if (node->name() == y->name()) {
#ifdef INTEL_MKL
          // if MKL is used, it goes through various additional
          // graph rewrite pass. In TF, everytime a graph pass
          // if MKL is used, it goes through various additional
          // graph rewrite pass. In TF, everytime a graph pass
          // happens, "constant" nodes are allocated
          // and deallocated. Each allocation calls the
          // (FindChunkPtr of BFCAllocator),
          // which increments the value of AllocationId.
          // Thus AllocationId becomes more than 3 and 4 if
          // MKL is used. Now they are 9 and 10 for MKL.
          EXPECT_EQ(19, cm->AllocationId(node, 0));
          // which increments the value of AllocationId.
          // Thus AllocationId becomes more than TF if MKL
          // is used. Now IDs for MKL are 8 more than TF.
          EXPECT_EQ(29, cm->AllocationId(node, 0));
#else
          EXPECT_EQ(21, cm->AllocationId(node, 0));
#endif
#endif
        } else {
#ifdef INTEL_MKL
          EXPECT_EQ(20, cm->AllocationId(node, 0));
          EXPECT_EQ(30, cm->AllocationId(node, 0));
#else
          EXPECT_EQ(22, cm->AllocationId(node, 0));
#endif
#endif
        }
      }
      EXPECT_LE(0, cm->MaxExecutionTime(node));
@@ -17,6 +17,13 @@ limitations under the License.

#include "tensorflow/core/common_runtime/mkl_cpu_allocator.h"

#ifdef _WIN32
// Declare function to avoid unresolved symbol in VS
i_malloc_t i_malloc;
i_calloc_t i_calloc;
i_realloc_t i_realloc;
i_free_t i_free;
#endif
namespace tensorflow {

constexpr const char* MklCPUAllocator::kMaxLimitStr;
@ -143,6 +143,7 @@ tf_cuda_library(
|
||||
":debug_node_key",
|
||||
":debug_service_proto_cc",
|
||||
":debugger_event_metadata_proto_cc",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
@ -150,7 +151,6 @@ tf_cuda_library(
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:proto_text",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@grpc//:grpc++",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -166,11 +166,11 @@ tf_cuda_library(
|
||||
":debug_io_utils",
|
||||
":debug_service_proto_cc",
|
||||
":debugger_event_metadata_proto_cc",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@grpc//:grpc++",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -628,6 +628,7 @@ tf_cuda_cc_test(
|
||||
":master",
|
||||
":remote_device",
|
||||
":worker_interface",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
@ -649,7 +650,6 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/core/kernels:dense_update_ops",
|
||||
"//tensorflow/core/kernels:identity_op",
|
||||
"//tensorflow/core/kernels:variable_ops",
|
||||
"@grpc//:grpc++",
|
||||
],
|
||||
)
|
||||
|
||||
@ -667,6 +667,7 @@ tf_cuda_cc_test(
|
||||
":master",
|
||||
":remote_device",
|
||||
":worker_interface",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
@ -682,7 +683,6 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_testlib",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
|
||||
"@grpc//:grpc++",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -48,6 +48,8 @@ cc_library(
|
||||
"eager_service_impl.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow:grpc",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
@ -67,8 +69,6 @@ cc_library(
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
"//tensorflow/core/distributed_runtime/eager:remote_tensor_handle",
|
||||
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
|
||||
"@grpc",
|
||||
"@grpc//:grpc++",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -41,8 +41,8 @@ cc_library(
|
||||
srcs = ["grpc_util.cc"],
|
||||
hdrs = ["grpc_util.h"],
|
||||
deps = [
|
||||
"@grpc",
|
||||
"@grpc//:grpc++",
|
||||
"//tensorflow:grpc",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:lib",
|
||||
# Required to be able to overload TensorResponse parsing.
|
||||
"//tensorflow/core/distributed_runtime:tensor_coding",
|
||||
@ -55,8 +55,8 @@ cc_library(
|
||||
hdrs = ["grpc_client_cq_tag.h"],
|
||||
deps = [
|
||||
":grpc_util",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:lib",
|
||||
"@grpc//:grpc++",
|
||||
],
|
||||
)
|
||||
|
||||
@ -67,10 +67,10 @@ cc_library(
|
||||
deps = [
|
||||
":grpc_client_cq_tag",
|
||||
":grpc_util",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/distributed_runtime:call_options",
|
||||
"//tensorflow/core/distributed_runtime:tensor_coding",
|
||||
"@grpc//:grpc++",
|
||||
],
|
||||
)
|
||||
|
||||
@ -83,6 +83,7 @@ cc_library(
|
||||
":grpc_state",
|
||||
":grpc_util",
|
||||
":grpc_worker_service_impl",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
@ -90,7 +91,6 @@ cc_library(
|
||||
"//tensorflow/core/distributed_runtime:tensor_coding",
|
||||
"//tensorflow/core/distributed_runtime:worker_cache_logger",
|
||||
"//tensorflow/core/distributed_runtime:worker_interface",
|
||||
"@grpc//:grpc++",
|
||||
],
|
||||
)
|
||||
|
||||
@ -100,10 +100,10 @@ cc_library(
|
||||
hdrs = ["grpc_channel.h"],
|
||||
deps = [
|
||||
":grpc_util",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"@grpc//:grpc++",
|
||||
],
|
||||
)
|
||||
|
||||
@ -112,13 +112,13 @@ cc_library(
|
||||
srcs = ["grpc_tensor_coding.cc"],
|
||||
hdrs = ["grpc_tensor_coding.h"],
|
||||
deps = [
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:worker_proto_cc",
|
||||
"@grpc//:grpc++",
|
||||
],
|
||||
)
|
||||
|
||||
@ -127,9 +127,9 @@ cc_library(
|
||||
srcs = [],
|
||||
hdrs = ["grpc_call.h"],
|
||||
deps = [
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"@grpc//:grpc++",
|
||||
],
|
||||
)
|
||||
|
||||
@ -167,6 +167,7 @@ tf_cuda_library(
|
||||
":grpc_tensor_coding",
|
||||
":grpc_util",
|
||||
":grpc_worker_service_impl",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
@ -180,7 +181,6 @@ tf_cuda_library(
|
||||
"//tensorflow/core/distributed_runtime:worker_cache",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
"//tensorflow/core/distributed_runtime:worker_session",
|
||||
"@grpc//:grpc++",
|
||||
],
|
||||
)
|
||||
|
||||
@ -190,9 +190,9 @@ cc_library(
|
||||
hdrs = ["grpc_worker_service_impl.h"],
|
||||
deps = [
|
||||
":grpc_util",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:worker_proto_cc",
|
||||
"//tensorflow/core/distributed_runtime:tensor_coding",
|
||||
"@grpc//:grpc++",
|
||||
],
|
||||
)
|
||||
|
||||
@ -220,12 +220,12 @@ cc_library(
|
||||
":async_service_interface",
|
||||
":grpc_call",
|
||||
":grpc_util",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:master_proto_cc",
|
||||
"//tensorflow/core:master_service_proto_cc",
|
||||
"//tensorflow/core/distributed_runtime:master",
|
||||
"@grpc//:grpc++",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -259,6 +259,8 @@ cc_library(
|
||||
":grpc_worker_cache",
|
||||
":grpc_worker_service",
|
||||
":rpc_rendezvous_mgr",
|
||||
"//tensorflow:grpc",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
@ -277,8 +279,6 @@ cc_library(
|
||||
"//tensorflow/core/distributed_runtime:worker_cache_wrapper",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_service_impl",
|
||||
"@grpc",
|
||||
"@grpc//:grpc++",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -299,13 +299,13 @@ tf_cc_binary(
|
||||
],
|
||||
deps = [
|
||||
":grpc_server_lib",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/distributed_runtime:server_lib",
|
||||
"//tensorflow/core/kernels:data_flow",
|
||||
"@grpc//:grpc++",
|
||||
],
|
||||
)
|
||||
|
||||
@ -317,6 +317,7 @@ tf_cc_binary(
|
||||
],
|
||||
deps = [
|
||||
":grpc_server_lib",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
@ -330,7 +331,6 @@ tf_cc_binary(
|
||||
"//tensorflow/core/kernels:matmul_op",
|
||||
"//tensorflow/core/kernels:reduction_ops",
|
||||
"//tensorflow/core/kernels:variable_ops",
|
||||
"@grpc//:grpc++",
|
||||
],
|
||||
)
|
||||
|
||||
@ -415,6 +415,7 @@ tf_cc_test(
|
||||
deps = [
|
||||
":grpc_tensor_coding",
|
||||
":grpc_testlib",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
@ -424,7 +425,6 @@ tf_cc_test(
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core:worker_proto_cc",
|
||||
"@grpc//:grpc++",
|
||||
],
|
||||
)
|
||||
|
||||
@ -434,11 +434,11 @@ tf_cc_test(
|
||||
srcs = ["grpc_util_test.cc"],
|
||||
deps = [
|
||||
":grpc_util",
|
||||
"//tensorflow:grpc",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:worker_proto_cc",
|
||||
"@grpc",
|
||||
"@grpc//:grpc++",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -11,8 +11,8 @@ cc_library(
|
||||
srcs = ["grpc_eager_service.cc"],
|
||||
hdrs = ["grpc_eager_service.h"],
|
||||
deps = [
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:eager_service_proto_cc",
|
||||
"@grpc//:grpc++",
|
||||
],
|
||||
)
|
||||
|
||||
@ -21,6 +21,7 @@ cc_library(
|
||||
srcs = ["grpc_eager_client.cc"],
|
||||
hdrs = ["grpc_eager_client.h"],
|
||||
deps = [
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:eager_service_proto_cc",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/distributed_runtime/eager:eager_client",
|
||||
@ -29,7 +30,6 @@ cc_library(
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_state",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
||||
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_service",
|
||||
"@grpc//:grpc++",
|
||||
],
|
||||
)
|
||||
|
||||
@ -39,6 +39,7 @@ cc_library(
|
||||
hdrs = ["grpc_eager_service_impl.h"],
|
||||
deps = [
|
||||
":grpc_eager_service",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:ptr_util",
|
||||
"//tensorflow/core/distributed_runtime/eager:eager_service_impl",
|
||||
@ -47,6 +48,6 @@ cc_library(
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
|
||||
"@grpc//:grpc++",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
|
||||
],
|
||||
)
|
||||
|
@@ -289,6 +289,12 @@ Status GrpcServer::Init(
                nullptr);
}

Status GrpcServer::Init(
    ServiceInitFunction service_func,
    const RendezvousMgrCreationFunction& rendezvous_mgr_func) {
  return Init(std::move(service_func), rendezvous_mgr_func, nullptr, nullptr);
}

Status GrpcServer::Init() { return Init(nullptr, nullptr, nullptr, nullptr); }

Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,

@@ -97,6 +97,9 @@ class GrpcServer : public ServerInterface {
              const RendezvousMgrCreationFunction& rendezvous_mgr_func,
              const CollectiveMgrCreationFunction& collective_mgr_func);

  Status Init(ServiceInitFunction service_func,
              const RendezvousMgrCreationFunction& rendezvous_mgr_func);

  Status Init();

  // A subclass can override this method to support secure credentials.
@ -1901,6 +1901,11 @@ BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000);
|
||||
|
||||
#else // INTEL_MKL_ML
|
||||
|
||||
// NOTE: Unit tests in this file rely on a topological sorted graph for
|
||||
// printing. But since sibling nodes of a node in the topologically sorted graph
|
||||
// can be printed in different orders, tests may fail if the order in which
|
||||
// sibling nodes are visited is changed.
|
||||
|
||||
namespace {
|
||||
|
||||
const char kCPUDevice[] = "/job:a/replica:0/task:0/device:CPU:0";
|
||||
@ -2572,9 +2577,9 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_Mkl) {
|
||||
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
|
||||
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
|
||||
"F(_MklConv2D);G(Const);H(_MklConcat);I(Zeta)|A->E;A->I;"
|
||||
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
|
||||
"B->E:1;C->F;C:control->DMT/_0:control;C:control->DMT/_1:control;"
|
||||
"D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
|
||||
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
|
||||
"B->E:1;C->F;C:control->DMT/_2:control;C:control->DMT/_3:control;"
|
||||
"D->F:1;DMT/_0->E:2;DMT/_1->E:3;DMT/_2->F:2;DMT/_3->F:3;"
|
||||
"DMT/_4->H:3;E->H:1;E:2->H:4;F->H:2;F:2->H:5;G->H;"
|
||||
"G:control->DMT/_4:control;H->I:1");
|
||||
}
|
||||
@ -2681,9 +2686,9 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) {
|
||||
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
|
||||
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
|
||||
"F(_MklConv2D);G(Const);H(_MklConcatV2);I(Zeta)|A->E;A->I;"
|
||||
"A:control->DMT/_2:control;A:control->DMT/_3:control;B->E:1;C->F;"
|
||||
"C:control->DMT/_0:control;C:control->DMT/_1:control;"
|
||||
"D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
|
||||
"A:control->DMT/_0:control;A:control->DMT/_1:control;B->E:1;C->F;"
|
||||
"C:control->DMT/_2:control;C:control->DMT/_3:control;"
|
||||
"D->F:1;DMT/_0->E:2;DMT/_1->E:3;DMT/_2->F:2;DMT/_3->F:3;"
|
||||
"DMT/_4->H:5;E->H;E:2->H:3;E:control->DMT/_4:control;F->H:1;"
|
||||
"F:2->H:4;G->H:2;H->I:1");
|
||||
}
|
||||
@ -3060,8 +3065,8 @@ TEST_F(MklLayoutPassTest, LRN_Negative3) {
|
||||
"C:control->DMT/_1:control;C:control->DMT/_2:control;"
|
||||
"C:control->DMT/_3:control;C:control->DMT/_4:control;"
|
||||
"C:control->DMT/_5:control;C:control->DMT/_6:control;"
|
||||
"D->E:1;D->F:2;DMT/_0->B:1;DMT/_1->F:3;DMT/_2->F:7;DMT/_3->F:4;"
|
||||
"DMT/_4->F:6;DMT/_5->E:4;DMT/_6->E:5;E->G;F->G:1");
|
||||
"D->E:1;D->F:2;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;DMT/_3->F:3;"
|
||||
"DMT/_4->F:7;DMT/_5->F:4;DMT/_6->F:6;E->G;F->G:1");
|
||||
}
|
||||
|
||||
/* Test MaxPool->MaxPoolGrad replacement by workspace+rewrite nodes. */
|
||||
|
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {
@@ -32,16 +33,24 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override {
    int64 window_size = 0;
    int64 stride = 1;
    int64 stride = 0;
    OP_REQUIRES_OK(
        ctx, ParseScalarArgument<int64>(ctx, "window_size", &window_size));
    OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "stride", &stride));
    OP_REQUIRES(
        ctx, window_size > 0,
        errors::InvalidArgument("Window size must be greater than zero."));
    OP_REQUIRES(
        ctx, stride > 0 && stride < window_size,
        errors::InvalidArgument("Stride must be in [1, window_size)."));
    OP_REQUIRES(ctx, stride > 0,
                errors::InvalidArgument("Stride must be greater than zero."));
    if (stride == window_size) {
      LOG(WARNING) << "stride: " << stride
                   << " is equal to window_size: " << window_size
                   << ", to use `batch` instead.";
    } else if (stride > window_size) {
      LOG(WARNING) << "stride: " << stride
                   << " is greater than window_size: " << window_size
                   << ", you will lose some data.";
    }

    *output = new Dataset(ctx, window_size, stride, input);
  }
@@ -124,12 +133,15 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
          return Status::OK();
        }
        batch_elements.reserve(window_size);
        const bool first_call = cache_.empty();
        if (first_call) {
          cache_.reserve(window_size);
        } else {
          // Reuse cache in the previous iteration.
          cache_.swap(batch_elements);
        // Use cache if stride < window_size.
        if (stride < window_size) {
          const bool first_call = cache_.empty();
          if (first_call) {
            cache_.reserve(window_size);
          } else {
            // Reuse cache in the previous iteration.
            cache_.swap(batch_elements);
          }
        }
        // Fill up with new elements.
        *end_of_sequence = false;
@@ -149,9 +161,22 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
          DCHECK(*end_of_sequence);
          return Status::OK();
        }
        // Cache the data used for the next iteration.
        for (size_t i = stride; i < window_size; ++i) {
          cache_.emplace_back(batch_elements[i]);

        if (stride < window_size) {
          // Cache the data used for the next iteration.
          for (size_t i = stride; i < window_size; ++i) {
            cache_.emplace_back(batch_elements[i]);
          }
        } else if (stride > window_size) {
          // Drop the data before the next iteration.
          std::vector<Tensor> batch_element_tuple;
          for (size_t i = window_size; i < stride && !*end_of_sequence; ++i) {
            TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &batch_element_tuple,
                                                    end_of_sequence));
            if (*end_of_sequence) {
              input_impl_.reset();
            }
          }
        }
      }

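The kernel change above relaxes the old requirement that stride be strictly smaller than window_size: any positive stride is now accepted, the window tail is cached only when windows overlap, and input elements between windows are read and dropped when the stride exceeds the window. A small, self-contained C++ sketch of the resulting window layout makes the three cases easier to see; it is illustrative only and does not use the TensorFlow classes above.

#include <initializer_list>
#include <iostream>
#include <vector>

// Lists the element indices covered by each full window of `window_size`,
// where consecutive windows start `stride` elements apart.
std::vector<std::vector<int>> SlidingWindows(int num_elements, int window_size,
                                             int stride) {
  std::vector<std::vector<int>> windows;
  for (int start = 0; start + window_size <= num_elements; start += stride) {
    std::vector<int> window;
    for (int i = 0; i < window_size; ++i) window.push_back(start + i);
    windows.push_back(window);
  }
  return windows;
}

int main() {
  // stride < window_size: windows overlap, so the tail of one window can be
  // cached and reused as the head of the next (the cache_ path above).
  // stride == window_size: behaves like plain batching.
  // stride > window_size: the elements between windows are read and dropped.
  for (int stride : {1, 3, 4}) {
    std::cout << "stride " << stride << ":";
    for (const auto& w :
         SlidingWindows(/*num_elements=*/8, /*window_size=*/3, stride)) {
      std::cout << " [" << w.front() << ".." << w.back() << "]";
    }
    std::cout << "\n";
  }
  return 0;
}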
@ -59,7 +59,8 @@ namespace tensorflow {
|
||||
|
||||
#ifndef INTEL_MKL_ML
|
||||
|
||||
struct ConvFwdDimensions {
|
||||
// This structure aggregates multiple inputs to Conv2DFwd* methods.
|
||||
struct MklConvFwdParams {
|
||||
memory::dims src_dims;
|
||||
memory::dims filter_dims;
|
||||
memory::dims bias_dims;
|
||||
@ -69,48 +70,56 @@ struct ConvFwdDimensions {
|
||||
memory::dims padding_left;
|
||||
memory::dims padding_right;
|
||||
|
||||
ConvFwdDimensions(memory::dims src_dims,
|
||||
memory::dims filter_dims, memory::dims bias_dims,
|
||||
memory::dims dst_dims, memory::dims strides,
|
||||
memory::dims dilations, memory::dims padding_left,
|
||||
memory::dims padding_right) :
|
||||
src_dims(src_dims), filter_dims(filter_dims),
|
||||
bias_dims(bias_dims), dst_dims(dst_dims),
|
||||
strides(strides), dilations(dilations),
|
||||
padding_left(padding_left), padding_right(padding_right) {
|
||||
}
|
||||
MklConvFwdParams(memory::dims src_dims, memory::dims filter_dims,
|
||||
memory::dims bias_dims, memory::dims dst_dims,
|
||||
memory::dims strides, memory::dims dilations,
|
||||
memory::dims padding_left, memory::dims padding_right)
|
||||
: src_dims(src_dims),
|
||||
filter_dims(filter_dims),
|
||||
bias_dims(bias_dims),
|
||||
dst_dims(dst_dims),
|
||||
strides(strides),
|
||||
dilations(dilations),
|
||||
padding_left(padding_left),
|
||||
padding_right(padding_right) {}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Conv2DFwd : public DnnOp {
|
||||
class MklConv2DFwdPrimitive : public MklPrimitive {
|
||||
public:
|
||||
explicit Conv2DFwd(const ConvFwdDimensions& convFwdDims) {
|
||||
fwd_stream_.reset(new stream(stream::kind::eager));
|
||||
explicit MklConv2DFwdPrimitive(const MklConvFwdParams& convFwdDims)
|
||||
: cpu_engine_(engine::cpu, 0) {
|
||||
context_.fwd_stream.reset(new stream(stream::kind::eager));
|
||||
// create conv primitive
|
||||
if (conv_fwd_ == nullptr) {
|
||||
if (context_.conv_fwd == nullptr) {
|
||||
Setup(convFwdDims);
|
||||
}
|
||||
}
|
||||
|
||||
~Conv2DFwd() {}
|
||||
~MklConv2DFwdPrimitive() {}
|
||||
|
||||
// Convolution forward execute with bias
|
||||
// src_data: input data buffer of src
|
||||
// filter_data: input data buffer of filter (weights)
|
||||
// bias_data: input data buffer of bias
|
||||
// dst_data: output data buffer of dst
|
||||
void Execute(T* src_data, T* filter_data, T* bias_data, T* dst_data) {
|
||||
src_mem_->set_data_handle(static_cast<void*>(src_data));
|
||||
filter_mem_->set_data_handle(static_cast<void*>(filter_data));
|
||||
bias_mem_->set_data_handle(static_cast<void*>(bias_data));
|
||||
dst_mem_->set_data_handle(static_cast<void*>(dst_data));
|
||||
fwd_stream_->submit(fwd_primitives_);
|
||||
void Execute(const T* src_data, const T* filter_data, const T* bias_data,
|
||||
const T* dst_data) {
|
||||
context_.src_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(src_data)));
|
||||
context_.filter_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(filter_data)));
|
||||
context_.bias_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(bias_data)));
|
||||
context_.dst_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(dst_data)));
|
||||
context_.fwd_stream->submit(context_.fwd_primitives);
|
||||
|
||||
// after exec, set data handle back
|
||||
src_mem_->set_data_handle(DummyData);
|
||||
filter_mem_->set_data_handle(DummyData);
|
||||
bias_mem_->set_data_handle(DummyData);
|
||||
dst_mem_->set_data_handle(DummyData);
|
||||
context_.src_mem->set_data_handle(DummyData);
|
||||
context_.filter_mem->set_data_handle(DummyData);
|
||||
context_.bias_mem->set_data_handle(DummyData);
|
||||
context_.dst_mem->set_data_handle(DummyData);
|
||||
|
||||
return;
|
||||
}
|
||||
@ -119,139 +128,177 @@ class Conv2DFwd : public DnnOp {
|
||||
// src_data: input data buffer of src
|
||||
// filter_data: input data buffer of filter (weights)
|
||||
// dst_data: output data buffer of dst
|
||||
void Execute(T* src_data, T* filter_data, T* dst_data) {
|
||||
src_mem_->set_data_handle(static_cast<void*>(src_data));
|
||||
filter_mem_->set_data_handle(static_cast<void*>(filter_data));
|
||||
dst_mem_->set_data_handle(static_cast<void*>(dst_data));
|
||||
fwd_stream_->submit(fwd_primitives_);
|
||||
void Execute(const T* src_data, const T* filter_data, const T* dst_data) {
|
||||
context_.src_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(src_data)));
|
||||
context_.filter_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(filter_data)));
|
||||
context_.dst_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(dst_data)));
|
||||
context_.fwd_stream->submit(context_.fwd_primitives);
|
||||
|
||||
// after exec, set data handle back
|
||||
src_mem_->set_data_handle(DummyData);
|
||||
filter_mem_->set_data_handle(DummyData);
|
||||
dst_mem_->set_data_handle(DummyData);
|
||||
|
||||
return;
|
||||
// after execution, set data handle back
|
||||
context_.src_mem->set_data_handle(DummyData);
|
||||
context_.filter_mem->set_data_handle(DummyData);
|
||||
context_.dst_mem->set_data_handle(DummyData);
|
||||
}
|
||||
|
||||
// expected memory format for this primitive instance
|
||||
memory::format src_fmt_;
|
||||
memory::format filter_fmt_;
|
||||
memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }
|
||||
|
||||
// convolution primitive
|
||||
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> fwd_pd_;
|
||||
std::shared_ptr<mkldnn::primitive> conv_fwd_;
|
||||
memory::format GetFilterMemoryFormat() const { return context_.filter_fmt; }
|
||||
|
||||
std::shared_ptr<mkldnn::convolution_forward::primitive_desc>
|
||||
GetPrimitiveDesc() const {
|
||||
return context_.fwd_pd;
|
||||
}
|
||||
|
||||
private:
|
||||
void Setup(const ConvFwdDimensions& convFwdDims) {
|
||||
// Primitive reuse context for Conv2D Fwd op
|
||||
struct ConvFwdContext {
|
||||
// expected memory format for this primitive instance
|
||||
memory::format src_fmt;
|
||||
memory::format filter_fmt;
|
||||
|
||||
// MKLDNN memory
|
||||
std::shared_ptr<mkldnn::memory> src_mem;
|
||||
std::shared_ptr<mkldnn::memory> filter_mem;
|
||||
std::shared_ptr<mkldnn::memory> bias_mem;
|
||||
std::shared_ptr<mkldnn::memory> dst_mem;
|
||||
|
||||
// desc & prmitive desc
|
||||
std::shared_ptr<mkldnn::convolution_forward::desc> fwd_desc;
|
||||
|
||||
// memory desc
|
||||
std::shared_ptr<mkldnn::memory::desc> src_md;
|
||||
std::shared_ptr<mkldnn::memory::desc> filter_md;
|
||||
std::shared_ptr<mkldnn::memory::desc> bias_md;
|
||||
std::shared_ptr<mkldnn::memory::desc> dst_md;
|
||||
|
||||
// convolution primitive
|
||||
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> fwd_pd;
|
||||
std::shared_ptr<mkldnn::primitive> conv_fwd;
|
||||
|
||||
std::shared_ptr<mkldnn::stream> fwd_stream;
|
||||
std::vector<mkldnn::primitive> fwd_primitives;
|
||||
|
||||
ConvFwdContext()
|
||||
: src_fmt(memory::format::any),
|
||||
filter_fmt(memory::format::any),
|
||||
src_mem(nullptr),
|
||||
filter_mem(nullptr),
|
||||
bias_mem(nullptr),
|
||||
dst_mem(nullptr),
|
||||
fwd_desc(nullptr),
|
||||
src_md(nullptr),
|
||||
filter_md(nullptr),
|
||||
bias_md(nullptr),
|
||||
fwd_pd(nullptr),
|
||||
conv_fwd(nullptr),
|
||||
fwd_stream(nullptr) {}
|
||||
};
|
||||
|
||||
void Setup(const MklConvFwdParams& convFwdDims) {
|
||||
// create memory descriptors for convolution data w/ no specified format
|
||||
src_md_.reset(new memory::desc({convFwdDims.src_dims},
|
||||
MklDnnType<T>(), memory::format::any));
|
||||
context_.src_md.reset(new memory::desc(
|
||||
{convFwdDims.src_dims}, MklDnnType<T>(), memory::format::any));
|
||||
|
||||
filter_md_.reset(new memory::desc({convFwdDims.filter_dims},
|
||||
MklDnnType<T>(), memory::format::any));
|
||||
context_.filter_md.reset(new memory::desc(
|
||||
{convFwdDims.filter_dims}, MklDnnType<T>(), memory::format::any));
|
||||
|
||||
dst_md_.reset(new memory::desc({convFwdDims.dst_dims},
|
||||
MklDnnType<T>(), memory::format::any));
|
||||
context_.dst_md.reset(new memory::desc(
|
||||
{convFwdDims.dst_dims}, MklDnnType<T>(), memory::format::any));
|
||||
|
||||
if (!convFwdDims.bias_dims.empty())
|
||||
bias_md_.reset(new memory::desc({convFwdDims.bias_dims},
|
||||
MklDnnType<T>(), memory::format::any));
|
||||
context_.bias_md.reset(new memory::desc(
|
||||
{convFwdDims.bias_dims}, MklDnnType<T>(), memory::format::any));
|
||||
|
||||
// create a convolution
|
||||
if (!convFwdDims.bias_dims.empty()) {
|
||||
fwd_desc_.reset(new convolution_forward::desc(prop_kind::forward,
|
||||
convolution_direct, *src_md_, *filter_md_, *bias_md_, *dst_md_,
|
||||
context_.fwd_desc.reset(new convolution_forward::desc(
|
||||
prop_kind::forward, convolution_direct, *context_.src_md,
|
||||
*context_.filter_md, *context_.bias_md, *context_.dst_md,
|
||||
convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
|
||||
convFwdDims.padding_right, padding_kind::zero));
|
||||
} else {
|
||||
fwd_desc_.reset(new convolution_forward::desc(prop_kind::forward,
|
||||
convolution_direct, *src_md_, *filter_md_, *dst_md_,
|
||||
convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
|
||||
context_.fwd_desc.reset(new convolution_forward::desc(
|
||||
prop_kind::forward, convolution_direct, *context_.src_md,
|
||||
*context_.filter_md, *context_.dst_md, convFwdDims.strides,
|
||||
convFwdDims.dilations, convFwdDims.padding_left,
|
||||
convFwdDims.padding_right, padding_kind::zero));
|
||||
}
|
||||
|
||||
fwd_pd_.reset(new convolution_forward::primitive_desc(
|
||||
*fwd_desc_, cpu_engine_));
|
||||
context_.fwd_pd.reset(new convolution_forward::primitive_desc(
|
||||
*context_.fwd_desc, cpu_engine_));
|
||||
|
||||
// store the expected memory format
|
||||
src_fmt_ = static_cast<mkldnn::memory::format>(
|
||||
fwd_pd_.get()->src_primitive_desc().desc().data.format);
|
||||
context_.src_fmt = static_cast<mkldnn::memory::format>(
|
||||
context_.fwd_pd.get()->src_primitive_desc().desc().data.format);
|
||||
|
||||
filter_fmt_ = static_cast<mkldnn::memory::format>(
|
||||
fwd_pd_.get()->weights_primitive_desc().desc().data.format);
|
||||
context_.filter_fmt = static_cast<mkldnn::memory::format>(
|
||||
context_.fwd_pd.get()->weights_primitive_desc().desc().data.format);
|
||||
|
||||
// create memory primitive based on dummy data
|
||||
src_mem_.reset(new memory(fwd_pd_.get()->src_primitive_desc(), DummyData));
|
||||
filter_mem_.reset(new memory(fwd_pd_.get()->weights_primitive_desc(),
|
||||
DummyData));
|
||||
dst_mem_.reset(new memory(fwd_pd_.get()->dst_primitive_desc(), DummyData));
|
||||
context_.src_mem.reset(
|
||||
new memory(context_.fwd_pd.get()->src_primitive_desc(), DummyData));
|
||||
context_.filter_mem.reset(
|
||||
new memory(context_.fwd_pd.get()->weights_primitive_desc(), DummyData));
|
||||
context_.dst_mem.reset(
|
||||
new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));
|
||||
|
||||
// create convolution primitive and add it to net
|
||||
if (!convFwdDims.bias_dims.empty()) {
|
||||
bias_mem_.reset(new memory({{{convFwdDims.bias_dims}, MklDnnType<T>(),
|
||||
memory::format::x}, cpu_engine_}, DummyData));
|
||||
conv_fwd_.reset(new convolution_forward(*fwd_pd_, *src_mem_,
|
||||
*filter_mem_, *bias_mem_, *dst_mem_));
|
||||
context_.bias_mem.reset(new memory(
|
||||
{{{convFwdDims.bias_dims}, MklDnnType<T>(), memory::format::x},
|
||||
cpu_engine_},
|
||||
DummyData));
|
||||
context_.conv_fwd.reset(new convolution_forward(
|
||||
*context_.fwd_pd, *context_.src_mem, *context_.filter_mem,
|
||||
*context_.bias_mem, *context_.dst_mem));
|
||||
} else {
|
||||
conv_fwd_.reset(new convolution_forward(*fwd_pd_, *src_mem_,
|
||||
*filter_mem_, *dst_mem_));
|
||||
context_.conv_fwd.reset(
|
||||
new convolution_forward(*context_.fwd_pd, *context_.src_mem,
|
||||
*context_.filter_mem, *context_.dst_mem));
|
||||
}
|
||||
|
||||
fwd_primitives_.push_back(*conv_fwd_);
|
||||
context_.fwd_primitives.push_back(*context_.conv_fwd);
|
||||
return;
|
||||
}
|
||||
|
||||
// MKLDNN memory
|
||||
std::shared_ptr<mkldnn::memory> src_mem_;
|
||||
std::shared_ptr<mkldnn::memory> filter_mem_;
|
||||
std::shared_ptr<mkldnn::memory> bias_mem_;
|
||||
std::shared_ptr<mkldnn::memory> dst_mem_;
|
||||
|
||||
std::shared_ptr<mkldnn::stream> fwd_stream_;
|
||||
std::vector<mkldnn::primitive> fwd_primitives_;
|
||||
|
||||
// desc & prmitive desc
|
||||
std::shared_ptr<mkldnn::convolution_forward::desc> fwd_desc_;
|
||||
|
||||
// memory desc
|
||||
std::shared_ptr<mkldnn::memory::desc> src_md_;
|
||||
std::shared_ptr<mkldnn::memory::desc> filter_md_;
|
||||
std::shared_ptr<mkldnn::memory::desc> bias_md_;
|
||||
std::shared_ptr<mkldnn::memory::desc> dst_md_;
|
||||
|
||||
engine cpu_engine_ = engine(engine::cpu, 0);
|
||||
struct ConvFwdContext context_;
|
||||
engine cpu_engine_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Conv2DFwdFactory : public DnnOpFactory<T> {
|
||||
class MklConv2DFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
||||
public:
|
||||
static Conv2DFwd<T>* Get(const ConvFwdDimensions& convFwdDims) {
|
||||
Conv2DFwd<T>* conv2d_fwd = nullptr;
|
||||
static MklConv2DFwdPrimitive<T>* Get(const MklConvFwdParams& convFwdDims) {
|
||||
MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr;
|
||||
|
||||
// try to find a suitable one in pool
|
||||
conv2d_fwd = dynamic_cast<Conv2DFwd<T>*> (
|
||||
Conv2DFwdFactory<T>::GetInstance().GetConv2DFwd(convFwdDims));
|
||||
// try to find a suitable one in pool
|
||||
conv2d_fwd = dynamic_cast<MklConv2DFwdPrimitive<T>*>(
|
||||
MklConv2DFwdPrimitiveFactory<T>::GetInstance().GetConv2DFwd(
|
||||
convFwdDims));
|
||||
|
||||
if (conv2d_fwd == nullptr) {
|
||||
conv2d_fwd = new Conv2DFwd<T>(convFwdDims);
|
||||
Conv2DFwdFactory<T>::GetInstance().SetConv2DFwd(
|
||||
convFwdDims, conv2d_fwd);
|
||||
}
|
||||
return conv2d_fwd;
|
||||
if (conv2d_fwd == nullptr) {
|
||||
conv2d_fwd = new MklConv2DFwdPrimitive<T>(convFwdDims);
|
||||
MklConv2DFwdPrimitiveFactory<T>::GetInstance().SetConv2DFwd(convFwdDims,
|
||||
conv2d_fwd);
|
||||
}
|
||||
return conv2d_fwd;
|
||||
}
|
||||
|
||||
private:
|
||||
Conv2DFwdFactory() {}
|
||||
~Conv2DFwdFactory() {}
|
||||
MklConv2DFwdPrimitiveFactory() {}
|
||||
~MklConv2DFwdPrimitiveFactory() {}
|
||||
|
||||
static const int kDilationH = 0, kDilationW = 1;
|
||||
|
||||
static Conv2DFwdFactory& GetInstance() {
|
||||
static Conv2DFwdFactory instance_;
|
||||
static MklConv2DFwdPrimitiveFactory& GetInstance() {
|
||||
static MklConv2DFwdPrimitiveFactory instance_;
|
||||
return instance_;
|
||||
}
|
||||
|
||||
static std::string CreateKey(const ConvFwdDimensions& convFwdDims) {
|
||||
static std::string CreateKey(const MklConvFwdParams& convFwdDims) {
|
||||
std::string prefix = "conv2d_fwd_";
|
||||
FactoryKeyCreator key_creator;
|
||||
key_creator.AddAsKey(prefix);
|
||||
@ -266,12 +313,12 @@ class Conv2DFwdFactory : public DnnOpFactory<T> {
|
||||
return key_creator.GetKey();
|
||||
}
|
||||
|
||||
DnnOp* GetConv2DFwd(const ConvFwdDimensions& convFwdDims) {
|
||||
MklPrimitive* GetConv2DFwd(const MklConvFwdParams& convFwdDims) {
|
||||
std::string key = CreateKey(convFwdDims);
|
||||
return this->GetOp(key);
|
||||
}
|
||||
|
||||
void SetConv2DFwd(const ConvFwdDimensions& convFwdDims, DnnOp *op) {
|
||||
void SetConv2DFwd(const MklConvFwdParams& convFwdDims, MklPrimitive* op) {
|
||||
std::string key = CreateKey(convFwdDims);
|
||||
this->SetOp(key, op);
|
||||
}
|
||||
@ -762,7 +809,6 @@ class MklConv2DOp : public OpKernel {
|
||||
|
||||
MklDnnData<T> src(&cpu_engine);
|
||||
MklDnnData<T> filter(&cpu_engine);
|
||||
MklDnnData<T> dst(&cpu_engine); // output
|
||||
|
||||
memory::dims src_dims, filter_dims, padding_left, padding_right,
|
||||
dilations, strides;
|
||||
@ -812,7 +858,6 @@ class MklConv2DOp : public OpKernel {
|
||||
auto src_md = src_mkl_shape.IsMklTensor()
|
||||
? src_mkl_shape.GetMklLayout()
|
||||
: memory::desc(src_dims, MklDnnType<T>(), tf_fmt);
|
||||
src.SetUsrMem(src_md, &src_tensor);
|
||||
|
||||
// Although filter shape (filter_dims) required is in MKL-DNN order,
|
||||
// the layout is Tensorflow's layout (HWIO).
|
||||
@ -820,29 +865,30 @@ class MklConv2DOp : public OpKernel {
|
||||
? filter_mkl_shape.GetMklLayout()
|
||||
: memory::desc(filter_dims, MklDnnType<T>(),
|
||||
memory::format::hwio);
|
||||
filter.SetUsrMem(filter_md, &filter_tensor);
|
||||
|
||||
// MKLDNN dilation starts from 0.
|
||||
dilations[kDilationH] -= 1;
|
||||
dilations[kDilationW] -= 1;
|
||||
|
||||
// get a conv2d fwd from primitive pool
|
||||
Conv2DFwd<T> *conv2d_fwd = nullptr;
|
||||
MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr;
|
||||
if (biasEnabled) {
|
||||
memory::dims bias_dims = {};
|
||||
conv_utl.GetBiasSizeInMklOrder(kInputIndex_Bias, &bias_dims);
|
||||
ConvFwdDimensions convFwdDims(src_dims, filter_dims, bias_dims,
|
||||
dst_dims_mkl_order, strides, dilations, padding_left, padding_right);
|
||||
conv2d_fwd = Conv2DFwdFactory<T>::Get(convFwdDims);
|
||||
MklConvFwdParams convFwdDims(src_dims, filter_dims, bias_dims,
|
||||
dst_dims_mkl_order, strides, dilations,
|
||||
padding_left, padding_right);
|
||||
conv2d_fwd = MklConv2DFwdPrimitiveFactory<T>::Get(convFwdDims);
|
||||
} else {
|
||||
ConvFwdDimensions convFwdDims(src_dims, filter_dims, NONE_DIMS,
|
||||
dst_dims_mkl_order, strides, dilations, padding_left, padding_right);
|
||||
conv2d_fwd = Conv2DFwdFactory<T>::Get(convFwdDims);
|
||||
MklConvFwdParams convFwdDims(src_dims, filter_dims, NONE_DIMS,
|
||||
dst_dims_mkl_order, strides, dilations,
|
||||
padding_left, padding_right);
|
||||
conv2d_fwd = MklConv2DFwdPrimitiveFactory<T>::Get(convFwdDims);
|
||||
}
|
||||
|
||||
// allocate output tensors output_tensor and filter_out_tensor
|
||||
std::shared_ptr<mkldnn::convolution_forward::primitive_desc>
|
||||
conv_fwd_pd = conv2d_fwd->fwd_pd_;
|
||||
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_fwd_pd =
|
||||
conv2d_fwd->GetPrimitiveDesc();
|
||||
AllocateOutputTensor(context, *conv_fwd_pd,
|
||||
dst_dims_mkl_order, tf_fmt, &dst_tensor);
|
||||
Tensor* filter_out_tensor = nullptr;
|
||||
@ -854,20 +900,28 @@ class MklConv2DOp : public OpKernel {
|
||||
|
||||
// check whether src/filter need reorder
|
||||
std::vector<primitive> net;
|
||||
if (src_md.data.format != conv2d_fwd->src_fmt_)
|
||||
src.CheckReorderToOpMem(
|
||||
conv_fwd_pd.get()->src_primitive_desc(), &net);
|
||||
T* src_data = nullptr;
|
||||
if (src_md.data.format != conv2d_fwd->GetSrcMemoryFormat()) {
|
||||
src.SetUsrMem(src_md, &src_tensor);
|
||||
src.CheckReorderToOpMem(conv_fwd_pd.get()->src_primitive_desc(), &net);
|
||||
src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
|
||||
} else {
|
||||
src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
|
||||
}
|
||||
T* filter_data = nullptr;
|
||||
if (filter_md.data.format != conv2d_fwd->GetFilterMemoryFormat()) {
|
||||
filter.SetUsrMem(filter_md, &filter_tensor);
|
||||
filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_primitive_desc(),
|
||||
filter.GetTensorBuffer(filter_out_tensor),
|
||||
&net);
|
||||
filter_data = static_cast<T*>(filter.GetOpMem().get_data_handle());
|
||||
} else {
|
||||
filter_data =
|
||||
static_cast<T*>(const_cast<T*>(filter_tensor.flat<T>().data()));
|
||||
}
|
||||
|
||||
if (filter_md.data.format != conv2d_fwd->filter_fmt_)
|
||||
filter.CheckReorderToOpMem(
|
||||
conv_fwd_pd.get()->weights_primitive_desc(),
|
||||
filter.GetTensorBuffer(filter_out_tensor), &net);
|
||||
stream(stream::kind::eager).submit(net).wait();
|
||||
|
||||
T* src_data = static_cast<T*>(
|
||||
src.GetOpMem().get_data_handle());
|
||||
T* filter_data = static_cast<T*>(
|
||||
filter.GetOpMem().get_data_handle());
|
||||
|
||||
// execute convolution
|
||||
if (biasEnabled) {
|
||||
|
@@ -295,7 +295,7 @@ __global__ void ColumnReduceMax16ColumnsKernel(

  // 1D array necessary due to bug in CUDA 9 compiler.
  // TODO(nluehr) revert to 2D array when compiler is ready.
  // This is the mimic the following, but without any constructors:
  // This is to mimic the following, but without any constructors:
  //   __shared__ storage_type<value_type> partial_sums[32 * 33];
  __shared__ __align__(
      alignof(value_type)) char partial_sums_raw[32 * 33 * sizeof(value_type)];
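The corrected comment above documents a workaround for a CUDA 9 compiler bug: the kernel wants a shared array of a non-trivial value type, but declaring it directly would require constructors to run, so it reserves raw, suitably aligned char storage and reinterprets it. The same idea in plain, host-side C++ is sketched below; the Accumulator type and variable names are illustrative, and the kernel applies the pattern to CUDA __shared__ memory instead of a static buffer.

#include <cstddef>

// Illustrative value type standing in for storage_type<value_type>: it has a
// constructor, which is exactly what the raw-storage trick avoids running.
struct Accumulator {
  float value;
  Accumulator() : value(0.0f) {}
};

// Raw, suitably aligned storage: no constructors run when this is created.
alignas(Accumulator) static char partial_sums_raw[32 * 33 * sizeof(Accumulator)];

// Reinterpret the raw bytes as the typed array the reduction code indexes.
// Each slot must be written before it is read, just as the kernel initializes
// its partial sums before use.
static Accumulator* partial_sums =
    reinterpret_cast<Accumulator*>(partial_sums_raw);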
@@ -16,6 +16,12 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
#define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_

// This file requires the following include because it uses CudaAtomicMax:
// #include "tensorflow/core/util/cuda_kernel_helper.h"

// Unfortunately we can't add the #include, since it breaks compilation for
// non-GPU targets. This only breaks in clang, because it's more strict for
// template code and CudaAtomicMax is used in template context.

// This file requires the following include because it uses CudaAtomicMax:
// #include "tensorflow/core/util/cuda_kernel_helper.h"
@@ -614,7 +614,13 @@ REGISTER_OP("ApproximateEqual")
    .SetIsCommutative()
    .Attr("T: numbertype")
    .Attr("tolerance: float = 0.00001")
    .SetShapeFn(shape_inference::UnchangedShape);
    .SetShapeFn([](InferenceContext* c) {
      // The inputs 'x' and 'y' must have the same shape.
      ShapeHandle data_x = c->input(0);
      ShapeHandle data_y = c->input(1);
      TF_RETURN_IF_ERROR(c->Merge(data_x, data_y, &data_x));
      return shape_inference::UnchangedShape(c);
    });

// --------------------------------------------------------------------------

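The new shape function rejects inputs of different shapes at graph-construction time instead of leaving the mismatch to be discovered when the op runs. The op itself is an element-wise tolerance comparison; a standalone C++ sketch of that semantics is shown below, with names that are illustrative rather than TensorFlow code.

#include <cmath>
#include <cstdio>
#include <vector>

// Element-wise version of the ApproximateEqual semantics: both inputs must
// have the same shape (here, the same length), and each pair of elements is
// compared against an absolute tolerance.
std::vector<bool> ApproximatelyEqual(const std::vector<float>& x,
                                     const std::vector<float>& y,
                                     float tolerance = 1e-5f) {
  if (x.size() != y.size()) {
    // Mirrors what the shape function above now rejects during graph
    // construction.
    std::fprintf(stderr, "shape mismatch: %zu vs %zu\n", x.size(), y.size());
    return {};
  }
  std::vector<bool> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    out[i] = std::fabs(x[i] - y[i]) < tolerance;
  }
  return out;
}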
@@ -137,8 +137,8 @@ Status EncodeJwtClaim(StringPiece client_email, StringPiece scope,
  const auto expiration_timestamp_sec =
      request_timestamp_sec + kRequestedTokenLifetimeSec;

  root["iat"] = request_timestamp_sec;
  root["exp"] = expiration_timestamp_sec;
  root["iat"] = Json::Value::UInt64(request_timestamp_sec);
  root["exp"] = Json::Value::UInt64(expiration_timestamp_sec);

  // Step 2: represent the JSON as a string.
  string claim = root.toStyledString();

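Routing the iat/exp timestamps through Json::Value::UInt64 removes the overload ambiguity jsoncpp would otherwise face when a 64-bit integer is assigned to a Json::Value. A hedged sketch of the same construction with jsoncpp follows; the helper name and the claim fields other than iat/exp are illustrative of a typical OAuth JWT claim, not taken from the TensorFlow code.

#include <json/json.h>
#include <cstdint>
#include <string>

// Illustrative: build a JWT claim body with explicit 64-bit timestamp types,
// matching the change above.
std::string BuildJwtClaim(const std::string& client_email,
                          const std::string& scope, uint64_t now_sec,
                          uint64_t lifetime_sec) {
  Json::Value root;
  root["iss"] = client_email;   // issuer (illustrative field)
  root["scope"] = scope;        // requested scope (illustrative field)
  root["iat"] = Json::Value::UInt64(now_sec);
  root["exp"] = Json::Value::UInt64(now_sec + lifetime_sec);
  return root.toStyledString();
}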
@@ -202,7 +202,10 @@ def cc_proto_library(
    )

  if use_grpc_plugin:
    cc_libs += ["//external:grpc_lib"]
    cc_libs += select({
        "//tensorflow:linux_s390x": ["//external:grpc_lib_unsecure"],
        "//conditions:default": ["//external:grpc_lib"],
    })

  if default_header:
    header_only_name = name

@@ -171,5 +171,10 @@ int64 AvailableRam() {
  return INT64_MAX;
}

int NumHyperthreadsPerCore() {
  static const int ht_per_core = tensorflow::port::CPUIDNumSMT();
  return (ht_per_core > 0) ? ht_per_core : 1;
}

}  // namespace port
}  // namespace tensorflow

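The new port::NumHyperthreadsPerCore() falls back to 1 when CPUIDNumSMT() cannot determine the SMT width, so it is always safe to divide by. One typical use, sketched below under the assumption that the tensorflow/core/platform headers declaring NumSchedulableCPUs() are also included, is estimating the number of physical cores; the helper name is illustrative and not part of this change.

// Illustrative sketch: derive an estimate of physical core count from logical
// CPUs and the SMT width reported by the helpers above.
int EstimatePhysicalCores() {
  const int logical_cpus = tensorflow::port::NumSchedulableCPUs();
  const int smt_width = tensorflow::port::NumHyperthreadsPerCore();  // >= 1
  return logical_cpus > smt_width ? logical_cpus / smt_width : 1;
}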
@ -47,9 +47,9 @@ Json::Value ChromeTraceFormatter::CreateEvent(const string& ph,
|
||||
event["ph"] = Json::Value(ph);
|
||||
event["cat"] = Json::Value(category);
|
||||
event["name"] = Json::Value(name);
|
||||
event["pid"] = Json::Value(pid);
|
||||
event["tid"] = Json::Value(tid);
|
||||
event["ts"] = Json::Value(ts);
|
||||
event["pid"] = Json::Int64(pid);
|
||||
event["tid"] = Json::Int64(tid);
|
||||
event["ts"] = Json::Int64(ts);
|
||||
return event;
|
||||
}
|
||||
|
||||
@ -57,7 +57,7 @@ void ChromeTraceFormatter::EmitPID(const string& name, int64 pid) {
|
||||
Json::Value event(Json::objectValue);
|
||||
event["name"] = Json::Value("process_name");
|
||||
event["ph"] = Json::Value("M");
|
||||
event["pid"] = Json::Value(pid);
|
||||
event["pid"] = Json::Int64(pid);
|
||||
Json::Value args(Json::objectValue);
|
||||
args["name"] = Json::Value(name);
|
||||
event["args"] = args;
|
||||
@ -68,7 +68,7 @@ void ChromeTraceFormatter::EmitRegion(int64 ts, int64 duration, int64 pid,
|
||||
int64 tid, const string& category,
|
||||
const string& name, Json::Value args) {
|
||||
Json::Value event = CreateEvent("X", category, name, pid, tid, ts);
|
||||
event["dur"] = Json::Value(duration);
|
||||
event["dur"] = Json::Int64(duration);
|
||||
event["args"] = std::move(args);
|
||||
metadata_.push_back(event);
|
||||
}
|
||||
@ -76,14 +76,14 @@ void ChromeTraceFormatter::EmitRegion(int64 ts, int64 duration, int64 pid,
|
||||
void ChromeTraceFormatter::EmitFlowStart(const string& name, int64 ts,
|
||||
int64 pid, int64 tid, int64 flow_id) {
|
||||
Json::Value event = CreateEvent("s", "DataFlow", name, pid, tid, ts);
|
||||
event["id"] = flow_id;
|
||||
event["id"] = Json::Int64(flow_id);
|
||||
events_.push_back(event);
|
||||
}
|
||||
|
||||
void ChromeTraceFormatter::EmitFlowEnd(const string& name, int64 ts, int64 pid,
|
||||
int64 tid, int64 flow_id) {
|
||||
Json::Value event = CreateEvent("t", "DataFlow", name, pid, tid, ts);
|
||||
event["id"] = flow_id;
|
||||
event["id"] = Json::Int64(flow_id);
|
||||
events_.push_back(event);
|
||||
}
|
||||
|
||||
@ -93,7 +93,7 @@ void ChromeTraceFormatter::EmitCounter(
|
||||
const std::map<int64, std::vector<string>>& tensor_mem) {
|
||||
Json::Value event = CreateEvent("C", category, "Allocated Bytes", pid, 0, ts);
|
||||
Json::Value args(Json::objectValue);
|
||||
args["Allocator Bytes in Use"] = Json::Value(bytes);
|
||||
args["Allocator Bytes in Use"] = Json::Int64(bytes);
|
||||
event["args"] = args;
|
||||
events_.push_back(event);
|
||||
|
||||
|
@ -1814,11 +1814,11 @@ class MklDnnData {
|
||||
}
|
||||
};
|
||||
|
||||
/// Base class for operations with reuse of DNN primitives
|
||||
/// Base class for operations with reuse of primitives
|
||||
///
|
||||
class DnnOp {
|
||||
class MklPrimitive {
|
||||
public:
|
||||
virtual ~DnnOp() {}
|
||||
virtual ~MklPrimitive() {}
|
||||
|
||||
// Dummy data. Its size, hard-coded as 256 here, does
|
||||
// not matter since MKL should never operate on this buffer.
|
||||
@ -1826,33 +1826,33 @@ class DnnOp {
|
||||
};
|
||||
|
||||
const mkldnn::memory::dims NONE_DIMS = {};
|
||||
// This constant is used to declare dummy buffer (size), for MKL primitives
|
||||
template <typename T>
|
||||
class DnnOpFactory {
|
||||
public:
|
||||
DnnOpFactory() {}
|
||||
~DnnOpFactory() {}
|
||||
|
||||
DnnOp* GetOp(const std::string& key) {
|
||||
auto stream_iter = DnnOpFactory<T>::GetHashMap().find(key);
|
||||
if (stream_iter == DnnOpFactory<T>::GetHashMap().end()) {
|
||||
template <typename T>
|
||||
class MklPrimitiveFactory {
|
||||
public:
|
||||
MklPrimitiveFactory() {}
|
||||
~MklPrimitiveFactory() {}
|
||||
|
||||
MklPrimitive* GetOp(const std::string& key) {
|
||||
auto stream_iter = MklPrimitiveFactory<T>::GetHashMap().find(key);
|
||||
if (stream_iter == MklPrimitiveFactory<T>::GetHashMap().end()) {
|
||||
return nullptr;
|
||||
} else {
|
||||
return stream_iter->second;
|
||||
}
|
||||
}
|
||||
|
||||
void SetOp(const std::string& key, DnnOp* op) {
|
||||
auto stream_iter = DnnOpFactory<T>::GetHashMap().find(key);
|
||||
void SetOp(const std::string& key, MklPrimitive* op) {
|
||||
auto stream_iter = MklPrimitiveFactory<T>::GetHashMap().find(key);
|
||||
|
||||
CHECK(stream_iter == DnnOpFactory<T>::GetHashMap().end());
|
||||
CHECK(stream_iter == MklPrimitiveFactory<T>::GetHashMap().end());
|
||||
|
||||
DnnOpFactory<T>::GetHashMap()[key] = op;
|
||||
MklPrimitiveFactory<T>::GetHashMap()[key] = op;
|
||||
}
|
||||
|
||||
private:
|
||||
static inline std::unordered_map<std::string, DnnOp*> &GetHashMap() {
|
||||
static thread_local std::unordered_map<std::string, DnnOp*> map_;
|
||||
static inline std::unordered_map<std::string, MklPrimitive*>& GetHashMap() {
|
||||
static thread_local std::unordered_map<std::string, MklPrimitive*> map_;
|
||||
return map_;
|
||||
}
|
||||
};
|
||||
|
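The renamed MklPrimitive / MklPrimitiveFactory pair above implements a simple memoization scheme: configured MKL-DNN primitives are keyed by a string built from their parameters (via FactoryKeyCreator) and stored in a thread-local map, so each thread can reuse an already set-up primitive across op invocations instead of rebuilding it. A stripped-down, self-contained C++ sketch of the same caching pattern follows; all names here are illustrative stand-ins, not the TensorFlow classes.

#include <memory>
#include <string>
#include <unordered_map>

// Stand-in for MklPrimitive: something expensive to set up once.
class Primitive {
 public:
  virtual ~Primitive() {}
};

// Stand-in for MklPrimitiveFactory: caches primitives by a parameter key in a
// thread-local map, so every thread reuses its own configured primitives.
class PrimitiveFactory {
 public:
  Primitive* GetOp(const std::string& key) {
    auto& map = GetHashMap();
    auto it = map.find(key);
    return it == map.end() ? nullptr : it->second.get();
  }

  void SetOp(const std::string& key, std::unique_ptr<Primitive> op) {
    GetHashMap()[key] = std::move(op);
  }

 private:
  static std::unordered_map<std::string, std::unique_ptr<Primitive>>&
  GetHashMap() {
    static thread_local std::unordered_map<std::string,
                                           std::unique_ptr<Primitive>> map;
    return map;
  }
};

The real factory additionally CHECK-fails if the same key is inserted twice and builds the key from the convolution dimensions, strides, dilations, and paddings shown in the diff above.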
29  tensorflow/docs_src/get_started/index.md  Normal file
@@ -0,0 +1,29 @@
# Get Started

If you are new to machine learning, we recommend taking the following online
course prior to diving into TensorFlow documentation:

*   [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/),
    which introduces machine learning concepts and encourages experimentation
    with existing TensorFlow code.

TensorFlow is a tool for machine learning. While it contains a wide range of
functionality, TensorFlow is mainly designed for deep neural network models.

The easiest way to get started with TensorFlow is by using Eager Execution.

*   @{$get_started/eager}, is for anyone new to machine learning or TensorFlow.

TensorFlow provides many APIs. The remainder of this section focuses on the
Estimator API which provide scalable, high-performance models. See the
@{$estimators} guide.

For more advanced users:

*   The @{$low_level_intro$Low Level Introduction} demonstrates how to use
    TensorFlow outside of the Estimator framework, for debugging and
    experimentation.
*   The @{$guide$Programmer's Guide} details major
    TensorFlow components.
*   The @{$tutorials$Tutorials} provide walkthroughs of a variety of
    TensorFlow models.
@@ -17,7 +17,7 @@ how to use the graphical user interface (GUI) of tfdbg, i.e., the
Note: The TensorFlow debugger uses a
[curses](https://en.wikipedia.org/wiki/Curses_\(programming_library\))-based text
user interface. On Mac OS X, the `ncurses` library is required and can be
installed with `brew install homebrew/dupes/ncurses`. On Windows, curses isn't as
installed with `brew install ncurses`. On Windows, curses isn't as
well supported, so a [readline](https://en.wikipedia.org/wiki/GNU_Readline)-based
interface can be used with tfdbg by installing `pyreadline` with `pip`. If you
use Anaconda3, you can install it with a command such as
245  tensorflow/go/attrs.go  Normal file
@ -0,0 +1,245 @@
|
||||
/*
|
||||
Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package tensorflow
|
||||
|
||||
// #include <stdlib.h>
|
||||
// #include "tensorflow/c/c_api.h"
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// makeCShape converts a shape specified in C.int64_t into a Shape.
|
||||
func makeCShape(shape []C.int64_t) Shape {
|
||||
s := Shape{dims: make([]int64, len(shape))}
|
||||
for i, n := range shape {
|
||||
s.dims[i] = int64(n)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Attr returns the value of an attribute on op. It returns an error if the
|
||||
// attribute does not exist.
|
||||
func (op *Operation) Attr(name string) (interface{}, error) {
|
||||
cname := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cname))
|
||||
|
||||
status := newStatus()
|
||||
meta := C.TF_OperationGetAttrMetadata(op.c, cname, status.c)
|
||||
if err := status.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if meta.is_list == 1 {
|
||||
return listAttribute(op, cname, meta)
|
||||
}
|
||||
return scalarAttribute(op, cname, meta)
|
||||
}
|
||||
|
||||
func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interface{}, error) {
|
||||
status := newStatus()
|
||||
|
||||
switch meta._type {
|
||||
case C.TF_ATTR_STRING:
|
||||
if meta.list_size == 0 {
|
||||
return []string(nil), nil
|
||||
}
|
||||
values := make([]unsafe.Pointer, meta.list_size)
|
||||
lengths := make([]C.size_t, meta.list_size)
|
||||
// Add one element in case total_size is zero.
|
||||
storage := make([]C.char, meta.total_size+1)
|
||||
		C.TF_OperationGetAttrStringList(op.c, cname, &values[0], &lengths[0], C.int(meta.list_size), unsafe.Pointer(&storage[0]), C.size_t(meta.total_size), status.c)
		if err := status.Err(); err != nil {
			return nil, err
		}
		list := make([]string, meta.list_size)
		for i, val := range values {
			length := lengths[i]
			list[i] = C.GoStringN((*C.char)(val), C.int(length))
		}
		return list, nil

	case C.TF_ATTR_INT:
		if meta.list_size == 0 {
			return []int64(nil), nil
		}
		list := make([]C.int64_t, meta.list_size)
		C.TF_OperationGetAttrIntList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
		if err := status.Err(); err != nil {
			return nil, err
		}
		vals := make([]int64, meta.list_size)
		for i, val := range list {
			vals[i] = int64(val)
		}
		return vals, nil

	case C.TF_ATTR_FLOAT:
		if meta.list_size == 0 {
			return []float32(nil), nil
		}
		list := make([]C.float, meta.list_size)
		C.TF_OperationGetAttrFloatList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
		if err := status.Err(); err != nil {
			return nil, err
		}
		vals := make([]float32, meta.list_size)
		for i, val := range list {
			vals[i] = float32(val)
		}
		return vals, nil

	case C.TF_ATTR_BOOL:
		if meta.list_size == 0 {
			return []bool(nil), nil
		}
		list := make([]C.uchar, meta.list_size)
		C.TF_OperationGetAttrBoolList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
		if err := status.Err(); err != nil {
			return nil, err
		}
		vals := make([]bool, meta.list_size)
		for i, val := range list {
			vals[i] = val == 1
		}
		return vals, nil

	case C.TF_ATTR_TYPE:
		if meta.list_size == 0 {
			return []DataType(nil), nil
		}
		list := make([]C.TF_DataType, meta.list_size)
		C.TF_OperationGetAttrTypeList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
		if err := status.Err(); err != nil {
			return nil, err
		}
		vals := make([]DataType, meta.list_size)
		for i, val := range list {
			vals[i] = DataType(val)
		}
		return vals, nil

	case C.TF_ATTR_TENSOR:
		if meta.list_size == 0 {
			return []*Tensor(nil), nil
		}
		list := make([]*C.TF_Tensor, meta.list_size)
		C.TF_OperationGetAttrTensorList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
		if err := status.Err(); err != nil {
			return nil, err
		}
		vals := make([]*Tensor, meta.list_size)
		for i, t := range list {
			vals[i] = newTensorFromC(t)
		}
		return vals, nil

	case C.TF_ATTR_SHAPE:
		if meta.list_size == 0 {
			return []Shape(nil), nil
		}
		dims := make([]*C.int64_t, meta.list_size)
		numDims := make([]C.int, meta.list_size)
		// Add one element in case total_size is zero.
		storage := make([]C.int64_t, meta.total_size+1)
		C.TF_OperationGetAttrShapeList(op.c, cname, &dims[0], &numDims[0], C.int(meta.list_size), &storage[0], C.int(meta.total_size), status.c)
		if err := status.Err(); err != nil {
			return nil, err
		}
		list := make([]Shape, meta.list_size)
		for i, dim := range dims {
			numDim := numDims[i]
			// If the number of dimensions is unknown, default to empty shape.
			if numDim < 0 {
				continue
			}
			// A []C.int64_t slice backed by C memory.
			// See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
			slice := (*[1 << 30]C.int64_t)(unsafe.Pointer(dim))[:numDim:numDim]
			list[i] = makeCShape(slice)
		}
		return list, nil

	default:
		return nil, fmt.Errorf("list type %v not supported", meta._type)
	}
}

func scalarAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interface{}, error) {
	status := newStatus()

	switch meta._type {
	case C.TF_ATTR_STRING:
		if meta.total_size == 0 {
			return "", nil
		}
		v := make([]C.char, meta.total_size)
		C.TF_OperationGetAttrString(op.c, cname, unsafe.Pointer(&v[0]), C.size_t(meta.total_size), status.c)
		if err := status.Err(); err != nil {
			return nil, err
		}
		return C.GoStringN(&v[0], C.int(meta.total_size)), nil

	case C.TF_ATTR_INT:
		var v C.int64_t
		C.TF_OperationGetAttrInt(op.c, cname, &v, status.c)
		return int64(v), status.Err()

	case C.TF_ATTR_FLOAT:
		var v C.float
		C.TF_OperationGetAttrFloat(op.c, cname, &v, status.c)
		return float32(v), status.Err()

	case C.TF_ATTR_BOOL:
		var v C.uchar
		C.TF_OperationGetAttrBool(op.c, cname, &v, status.c)
		return v == 1, status.Err()

	case C.TF_ATTR_TYPE:
		var v C.TF_DataType
		C.TF_OperationGetAttrType(op.c, cname, &v, status.c)
		return DataType(v), status.Err()

	case C.TF_ATTR_TENSOR:
		var v *C.TF_Tensor
		C.TF_OperationGetAttrTensor(op.c, cname, &v, status.c)
		if err := status.Err(); err != nil {
			return nil, err
		}
		return newTensorFromC(v), nil

	case C.TF_ATTR_SHAPE:
		numDims := meta.total_size
		// If number of dims is unknown return empty shape to indicate that.
		if numDims < 0 {
			return Shape{}, nil
		}
		if numDims == 0 {
			return ScalarShape(), nil
		}
		dims := make([]C.int64_t, numDims)
		C.TF_OperationGetAttrShape(op.c, cname, (*C.int64_t)(unsafe.Pointer(&dims[0])), C.int(numDims), status.c)
		if err := status.Err(); err != nil {
			return nil, err
		}
		return makeCShape(dims), nil

	default:
		return nil, fmt.Errorf("type %v not supported", meta._type)
	}
}
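As an aside, a minimal usage sketch (not part of this diff) of how these helpers are reached through the public Operation.Attr accessor; the operation name and helper function here are illustrative assumptions, written as if inside the same package:

	// Illustrative only: read the "dtype" attribute back from a named operation.
	func printDType(g *Graph) error {
		op := g.Operation("Const") // assumes an operation with this name exists in g
		if op == nil {
			return fmt.Errorf("operation not found")
		}
		v, err := op.Attr("dtype") // dispatches to scalarAttribute/listAttribute above
		if err != nil {
			return err
		}
		fmt.Println(v.(DataType))
		return nil
	}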
193
tensorflow/go/attrs_test.go
Normal file
@ -0,0 +1,193 @@
/*
Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package tensorflow

import (
	"fmt"
	"reflect"
	"testing"
)

func TestOperationAttrs(t *testing.T) {
	g := NewGraph()

	i := 0
	makeConst := func(v interface{}) Output {
		op, err := Const(g, fmt.Sprintf("const/%d/%+v", i, v), v)
		i++
		if err != nil {
			t.Fatal(err)
		}
		return op
	}

	makeTensor := func(v interface{}) *Tensor {
		tensor, err := NewTensor(v)
		if err != nil {
			t.Fatal(err)
		}
		return tensor
	}

	cases := []OpSpec{
		{
			Name: "type",
			Type: "Placeholder",
			Attrs: map[string]interface{}{
				"dtype": Float,
			},
		},
		{
			Name: "list(float)",
			Type: "Bucketize",
			Input: []Input{
				makeConst([]float32{1, 2, 3, 4}),
			},
			Attrs: map[string]interface{}{
				"boundaries": []float32{0, 1, 2, 3, 4, 5},
			},
		},
		{
			Name: "list(float) empty",
			Type: "Bucketize",
			Input: []Input{
				makeConst([]float32{}),
			},
			Attrs: map[string]interface{}{
				"boundaries": []float32(nil),
			},
		},
		/* TODO(ashankar): debug this issue and add it back later.
		{
			Name: "list(type),list(shape)",
			Type: "InfeedEnqueueTuple",
			Input: []Input{
				OutputList([]Output{
					makeConst(float32(1)),
					makeConst([][]int32{{2}}),
				}),
			},
			Attrs: map[string]interface{}{
				"dtypes": []DataType{Float, Int32},
				"shapes": []Shape{ScalarShape(), MakeShape(1, 1)},
			},
		},
		{
			Name: "list(type),list(shape) empty",
			Type: "InfeedEnqueueTuple",
			Input: []Input{
				OutputList([]Output{
					makeConst([][]int32{{2}}),
				}),
			},
			Attrs: map[string]interface{}{
				"dtypes": []DataType{Int32},
				"shapes": []Shape(nil),
			},
		},
		{
			Name: "list(type) empty,string empty,int",
			Type: "_XlaSendFromHost",
			Input: []Input{
				OutputList([]Output{}),
				makeConst(""),
			},
			Attrs: map[string]interface{}{
				"Tinputs":        []DataType(nil),
				"key":            "",
				"device_ordinal": int64(0),
			},
		},
		*/
		{
			Name: "list(int),int",
			Type: "StringToHashBucketStrong",
			Input: []Input{
				makeConst(""),
			},
			Attrs: map[string]interface{}{
				"num_buckets": int64(2),
				"key":         []int64{1, 2},
			},
		},
		{
			Name: "list(int) empty,int",
			Type: "StringToHashBucketStrong",
			Input: []Input{
				makeConst(""),
			},
			Attrs: map[string]interface{}{
				"num_buckets": int64(2),
				"key":         ([]int64)(nil),
			},
		},
		{
			Name: "list(string),type",
			Type: "TensorSummary",
			Input: []Input{
				makeConst(""),
			},
			Attrs: map[string]interface{}{
				"T":      String,
				"labels": []string{"foo", "bar"},
			},
		},
		{
			Name: "list(string) empty,type",
			Type: "TensorSummary",
			Input: []Input{
				makeConst(""),
			},
			Attrs: map[string]interface{}{
				"T":      String,
				"labels": ([]string)(nil),
			},
		},
		{
			Name: "tensor",
			Type: "Const",
			Attrs: map[string]interface{}{
				"dtype": String,
				"value": makeTensor("foo"),
			},
		},
	}

	for i, spec := range cases {
		op, err := g.AddOperation(spec)
		if err != nil {
			t.Fatal(err)
		}
		for key, want := range spec.Attrs {
			out, err := op.Attr(key)
			if err != nil {
				t.Fatal(err)
			}
			if !reflect.DeepEqual(out, want) {
				t.Fatalf("%d. %q: Got %#v, wanted %#v", i, key, out, want)
			}
			wantT, ok := want.(*Tensor)
			if ok {
				wantVal := wantT.Value()
				outVal := out.(*Tensor).Value()
				if !reflect.DeepEqual(outVal, wantVal) {
					t.Fatalf("%d. %q: Got %#v, wanted %#v", i, key, outVal, wantVal)
				}
			}
		}
	}
}
@ -11210,7 +11210,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted
// SampleDistortedBoundingBoxAreaRange sets the optional area_range attribute to value.
//
// value: The cropped area of the image must contain a fraction of the
// supplied image within in this range.
// supplied image within this range.
// If not specified, defaults to <f:0.05 f:1 >
func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr {
	return func(m optionalAttr) {
@ -17969,9 +17969,10 @@ func SparseFillEmptyRowsGrad(scope *Scope, reverse_index_map tf.Output, grad_val
}

// Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)`
//
// if < 0, `scale * features` otherwise.
//
// Assumes weights to have zero mean and variance 1.0 / fan_in.
//
// See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
func Selu(scope *Scope, features tf.Output) (activations tf.Output) {
	if scope.Err() != nil {
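As a side note, a small standalone sketch (not generated wrapper code) of the SELU formula the comment above describes; the alpha and scale constants below are the values from the cited paper and are an assumption here:

	package main

	import (
		"fmt"
		"math"
	)

	// selu computes scale*alpha*(exp(x)-1) for x < 0 and scale*x otherwise.
	func selu(x float64) float64 {
		const (
			alpha = 1.6732632423543772
			scale = 1.0507009873554805
		)
		if x < 0 {
			return scale * alpha * (math.Exp(x) - 1)
		}
		return scale * x
	}

	func main() {
		fmt.Println(selu(-1), selu(0), selu(2)) // approximately -1.111, 0, 2.101
	}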
@ -21655,7 +21656,7 @@ func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr {
// generated sequentially as '*tag*/image/0', '*tag*/image/1', etc.
//
// The `bad_color` argument is the color to use in the generated images for
// non-finite input values. It is a `unit8` 1-D tensor of length `channels`.
// non-finite input values. It is a `uint8` 1-D tensor of length `channels`.
// Each element must be in the range `[0, 255]` (It represents the value of a
// pixel in the output image). Non-finite values in the input tensor are
// replaced by this tensor in the output image. The default value is the color
@ -24048,7 +24049,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort
// SampleDistortedBoundingBoxV2AreaRange sets the optional area_range attribute to value.
//
// value: The cropped area of the image must contain a fraction of the
// supplied image within in this range.
// supplied image within this range.
// If not specified, defaults to <f:0.05 f:1 >
func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr {
	return func(m optionalAttr) {
@ -65,6 +65,11 @@ func (op *Operation) Output(i int) Output {
	return Output{op, i}
}

// NumInputs returns the number of inputs of op.
func (op *Operation) NumInputs() int {
	return int(C.TF_OperationNumInputs(op.c))
}

// Output represents one of the outputs of an operation in the graph. Has a
// DataType (and eventually a Shape). May be passed as an input argument to a
// function for adding operations to a graph, or to a Session's Run() method to
@ -123,6 +128,67 @@ func (p Output) c() C.TF_Output {

func (p Output) canBeAnInput() {}

// Consumers returns the inputs that consume this output.
func (p Output) Consumers() []Consumer {
	max := int(C.TF_OperationOutputNumConsumers(p.c()))
	if max == 0 {
		return nil
	}
	inputs := make([]C.TF_Input, max)
	n := C.TF_OperationOutputConsumers(p.c(), (*C.TF_Input)(unsafe.Pointer(&inputs[0])), C.int(max))
	inputs = inputs[:int(n)]

	var consumers []Consumer
	for _, consumer := range inputs {
		consumers = append(consumers, Consumer{
			Index: int(consumer.index),
			Op: &Operation{
				c: consumer.oper,
				g: p.Op.g,
			},
		})
	}

	return consumers
}

// Consumer identifies a specific input of an operation that consumes the output
// of another operation.
type Consumer struct {
	// Op is the Operation that is consuming the output of another operation.
	Op *Operation

	// Index is the index of the input within Op that the output of another
	// operation is connected to.
	Index int
}

func (p Consumer) c() C.TF_Input {
	if p.Op == nil {
		// Attempt to provide a more useful panic message than "nil
		// pointer dereference".
		panic("nil-Operation. Consumer objects should only be created by a call to Output.Consumers")
	}
	return C.TF_Input{oper: p.Op.c, index: C.int(p.Index)}
}

// DataType returns the type of the input.
func (p Consumer) DataType() DataType {
	return DataType(C.TF_OperationInputType(p.c()))
}

// Producer returns the Output that is connected to this Consumer.
func (p Consumer) Producer() Output {
	output := C.TF_OperationInput(p.c())
	return Output{
		Op: &Operation{
			c: output.oper,
			g: p.Op.g,
		},
		Index: int(output.index),
	}
}

// Input is the interface for specifying inputs to an operation being added to
// a Graph.
//
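A brief usage sketch of the Consumers/Consumer API added above (illustrative, not part of this diff); it assumes the same Placeholder and Neg test helpers exercised by the tests below:

	// Illustrative only: walk the edge x -> Neg(x) in both directions.
	func consumersSketch() error {
		g := NewGraph()
		x, err := Placeholder(g, "x", Float)
		if err != nil {
			return err
		}
		if _, err := Neg(g, "neg", x); err != nil {
			return err
		}
		for _, c := range x.Consumers() {
			// c.Op and c.Index identify the consuming input; Producer follows the edge back to x.
			fmt.Printf("%s consumes x at input %d (dtype %v)\n", c.Op.Name(), c.Index, c.DataType())
			_ = c.Producer()
		}
		return nil
	}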
@ -166,6 +166,68 @@ func TestOutputDataTypeAndShape(t *testing.T) {
	}
}

func TestOperationInputs(t *testing.T) {
	g := NewGraph()
	x, err := Placeholder(g, "x", Float)
	if err != nil {
		t.Fatal(err)
	}
	y, err := Placeholder(g, "y", Float)
	if err != nil {
		t.Fatal(err)
	}
	add, err := Add(g, "add", x, y)
	if err != nil {
		t.Fatal(err)
	}
	addOp := add.Op

	if out := addOp.NumInputs(); out != 2 {
		t.Fatalf("Got %d inputs, wanted 2", out)
	}
}

func TestOperationConsumers(t *testing.T) {
	g := NewGraph()
	x, err := Placeholder(g, "x", Float)
	if err != nil {
		t.Fatal(err)
	}
	a, err := Neg(g, "a", x)
	if err != nil {
		t.Fatal(err)
	}
	b, err := Neg(g, "b", x)
	if err != nil {
		t.Fatal(err)
	}

	consumers := []*Operation{a.Op, b.Op}

	xConsumers := x.Consumers()
	if out := len(xConsumers); out != 2 {
		t.Fatalf("Got %d consumers, wanted 2", out)
	}

	for i, consumer := range xConsumers {
		got := consumer.Op.Name()
		want := consumers[i].Name()
		if got != want {
			t.Fatalf("%d. Got op name %q, wanted %q", i, got, want)
		}

		got = consumer.Producer().Op.Name()
		want = x.Op.Name()
		if got != want {
			t.Fatalf("%d. Got op name %q, wanted %q", i, got, want)
		}
	}

	if len(b.Consumers()) != 0 {
		t.Fatalf("expected %+v to have no consumers", b)
	}
}

func forceGC() {
	var mem runtime.MemStats
	runtime.ReadMemStats(&mem)
@ -56,6 +56,10 @@ java_library(
    srcs = glob(["src/gen/java/org/tensorflow/processor/**/*.java"]),
    javacopts = JAVACOPTS,
    resources = glob(["src/gen/resources/META-INF/services/javax.annotation.processing.Processor"]),
    deps = [
        "@com_google_guava",
        "@com_squareup_javapoet",
    ],
)

filegroup(
@ -70,6 +74,7 @@ tf_java_op_gen_srcjar(
    name = "java_op_gen_sources",
    api_def_srcs = [
        "//tensorflow/core/api_def:base_api_def",
        "//tensorflow/core/api_def:java_api_def",
    ],
    base_package = "org.tensorflow.op",
    gen_tool = ":java_op_gen_tool",
6
tensorflow/java/maven/.gitignore
vendored
@ -11,4 +11,10 @@ tensorflow/src
tensorflow/target
proto/src
proto/target
hadoop/src
hadoop/target
spark-connector/src
spark-connector/target
spark-connector/dependency-reduced-pom.xml
spark-connector/spark-warehouse
pom.xml.versionsBackup
@ -53,6 +53,12 @@ There are seven artifacts and thus `pom.xml`s involved in this release:
7. [`parentpom`](https://maven.apache.org/pom/index.html): Common settings
   shared by all of the above.

8. `hadoop`: The TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop.
   The source code for this package is available in the [TensorFlow Ecosystem](https://github.com/tensorflow/ecosystem/tree/master/hadoop)

9. `spark-connector`: A Scala library for loading and storing TensorFlow TFRecord
   using Apache Spark DataFrames. The source code for this package is available
   in the [TensorFlow Ecosystem](https://github.com/tensorflow/ecosystem/tree/master/spark/spark-tensorflow-connector)

## Updating the release

24
tensorflow/java/maven/hadoop/pom.xml
Normal file
@ -0,0 +1,24 @@
<project
    xmlns="http://maven.apache.org/POM/4.0.0"
    xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
    xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  <!-- Placeholder pom which is replaced by TensorFlow ecosystem Hadoop pom during build -->
  <modelVersion>4.0.0</modelVersion>
  <description>TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop</description>
  <artifactId>hadoop</artifactId>
  <packaging>jar</packaging>

  <scm>
    <url>https://github.com/tensorflow/ecosystem.git</url>
    <connection>git@github.com:tensorflow/ecosystem.git</connection>
    <developerConnection>scm:git:https://github.com/tensorflow/ecosystem.git</developerConnection>
  </scm>

  <url>https://github.com/tensorflow/ecosystem/</url>
  <parent>
    <groupId>org.tensorflow</groupId>
    <artifactId>parentpom</artifactId>
    <version>1.9.0-rc0</version>
    <relativePath>../</relativePath>
  </parent>
</project>
@ -32,6 +32,8 @@
    <module>libtensorflow_jni_gpu</module>
    <module>tensorflow</module>
    <module>proto</module>
    <module>hadoop</module>
    <module>spark-connector</module>
  </modules>

  <!-- Two profiles are used:
@ -19,6 +19,7 @@


RELEASE_URL_PREFIX="https://storage.googleapis.com/tensorflow/libtensorflow"
TF_ECOSYSTEM_URL="https://github.com/tensorflow/ecosystem.git"

# By default we deploy to both ossrh and bintray. These two
# environment variables can be set to skip either repository.
@ -44,7 +45,9 @@ clean() {
  # (though if run inside a clean docker container, there won't be any dirty
  # artifacts lying around)
  mvn -q clean
  rm -rf libtensorflow_jni/src libtensorflow_jni/target libtensorflow_jni_gpu/src libtensorflow_jni_gpu/target libtensorflow/src libtensorflow/target tensorflow-android/target
  rm -rf libtensorflow_jni/src libtensorflow_jni/target libtensorflow_jni_gpu/src libtensorflow_jni_gpu/target \
    libtensorflow/src libtensorflow/target tensorflow-android/target proto/src proto/target \
    hadoop/src hadoop/target spark-connector/src spark-connector/target
}

update_version_in_pom() {
@ -183,6 +186,43 @@ generate_java_protos() {
  rm -rf "${DIR}/proto/tmp"
}


# Download the TensorFlow ecosystem source from git.
# The pom files from this repo do not inherit from the parent pom so the maven version
# is updated for each module.
download_tf_ecosystem() {
  ECOSYSTEM_DIR="/tmp/tensorflow-ecosystem"
  HADOOP_DIR="${DIR}/hadoop"
  SPARK_DIR="${DIR}/spark-connector"

  # Clean any previous attempts
  rm -rf "${ECOSYSTEM_DIR}"

  # Clone the TensorFlow ecosystem project
  mkdir -p "${ECOSYSTEM_DIR}"
  cd "${ECOSYSTEM_DIR}"
  git clone "${TF_ECOSYSTEM_URL}"
  cd ecosystem
  git checkout r${TF_VERSION}

  # Copy the TensorFlow Hadoop source
  cp -r "${ECOSYSTEM_DIR}/ecosystem/hadoop/src" "${HADOOP_DIR}"
  cp "${ECOSYSTEM_DIR}/ecosystem/hadoop/pom.xml" "${HADOOP_DIR}"
  cd "${HADOOP_DIR}"
  update_version_in_pom

  # Copy the TensorFlow Spark connector source
  cp -r "${ECOSYSTEM_DIR}/ecosystem/spark/spark-tensorflow-connector/src" "${SPARK_DIR}"
  cp "${ECOSYSTEM_DIR}/ecosystem/spark/spark-tensorflow-connector/pom.xml" "${SPARK_DIR}"
  cd "${SPARK_DIR}"
  update_version_in_pom

  # Cleanup
  rm -rf "${ECOSYSTEM_DIR}"

  cd "${DIR}"
}

# Deploy artifacts using a specific profile.
# Arguments:
#   profile - name of selected profile.
@ -240,7 +280,8 @@ cd "${DIR}"
# Comment lines out appropriately if debugging/tinkering with the release
# process.
# gnupg2 is required for signing
apt-get -qq update && apt-get -qqq install -y gnupg2
apt-get -qq update && apt-get -qqq install -y gnupg2 git

clean
update_version_in_pom
download_libtensorflow
@ -248,6 +289,8 @@ download_libtensorflow_jni
download_libtensorflow_jni_gpu
update_tensorflow_android
generate_java_protos
download_tf_ecosystem

# Build the release artifacts
mvn verify
# Push artifacts to repository
24
tensorflow/java/maven/spark-connector/pom.xml
Normal file
@ -0,0 +1,24 @@
<project
    xmlns="http://maven.apache.org/POM/4.0.0"
    xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
    xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  <!-- Placeholder pom which is replaced by TensorFlow ecosystem Spark pom during build -->
  <modelVersion>4.0.0</modelVersion>
  <description>TensorFlow TFRecord connector for Apache Spark DataFrames</description>
  <artifactId>spark-connector</artifactId>
  <packaging>jar</packaging>

  <scm>
    <url>https://github.com/tensorflow/ecosystem.git</url>
    <connection>git@github.com:tensorflow/ecosystem.git</connection>
    <developerConnection>scm:git:https://github.com/tensorflow/ecosystem.git</developerConnection>
  </scm>

  <url>https://github.com/tensorflow/ecosystem/</url>
  <parent>
    <groupId>org.tensorflow</groupId>
    <artifactId>parentpom</artifactId>
    <version>1.9.0-rc0</version>
    <relativePath>../</relativePath>
  </parent>
</project>
@ -35,7 +35,7 @@ namespace tensorflow {
namespace java {
namespace {

const char* kLicense =
constexpr const char kLicense[] =
    "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n"
    "\n"
    "Licensed under the Apache License, Version 2.0 (the \"License\");\n"
@ -391,9 +391,12 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint,
  }
  if (!op.hidden()) {
    // expose the op in the Ops Graph API only if it is visible
    op_class.add_annotation(
        Annotation::Create("Operator", "org.tensorflow.op.annotation")
            .attributes("group = \"" + endpoint.package() + "\""));
    Annotation oper_annot =
        Annotation::Create("Operator", "org.tensorflow.op.annotation");
    if (endpoint.package() != kDefaultEndpointPackage) {
      oper_annot.attributes("group = \"" + endpoint.package() + "\"");
    }
    op_class.add_annotation(oper_annot);
  }
  // create op class file
  const string op_dir_name = io::JoinPath(
@ -27,6 +27,8 @@ limitations under the License.
namespace tensorflow {
namespace java {

constexpr const char kDefaultEndpointPackage[] = "core";

class EndpointSpec {
 public:
  // A specification for an operation endpoint
Some files were not shown because too many files have changed in this diff.