Merge changes from github.

Change: 153426348
Patrick Nguyen 2017-04-17 20:41:44 -08:00 committed by TensorFlower Gardener
parent cca1b71352
commit 69a4cf80a1
39 changed files with 1122 additions and 120 deletions


@ -1,14 +1,32 @@
NOTE: Issues that are not bugs or feature requests will be closed. Please ask usage questions on StackOverflow.
Please go to Stack Overflow for help and support. http://stackoverflow.com/questions/tagged/tensorflow
If you open a GitHub issue, here is our policy:
### You must complete this information or else your issue will be closed
1. It must be a bug or feature request.
2. The form below must be filled out.
**Here's why we have that policy**: TensorFlow developers respond to issues, and we want to focus on work that benefits the whole community, e.g. fixing bugs and adding features. Support questions help only one person, while GitHub notifies thousands of people when an issue is filed. We want them to see you communicating an interesting problem, rather than being redirected to Stack Overflow.
------------------------
Describe the problem clearly here. Be sure to convey why it's a bug in TensorFlow or a feature request.
### System Information
- *Have I written custom code (as opposed to using a stock example script provided in TensorFlow)?*:
- *OS Platform and Distribution (e.g., Linux Ubuntu 16.04)*:
- *TensorFlow installed from (source or binary)?*:
- *TensorFlow version*:
- *TensorFlow version* (use command below):
- *Bazel version (if compiling from source)*:
- *CUDA/cuDNN version*:
- *GPU Model and Memory*:
- *Exact command to reproduce*:
You can collect some of this information using our environment capture script https://github.com/tensorflow/tensorflow/blob/master/tools/
You can collect the TensorFlow version with
```sh
python -c "import tensorflow as tf; print (tf.GIT_VERSION, tf.VERSION)"
```
### Describe the problem clearly
### Source Code / Logs

tensorflow/contrib/BUILD Normal file → Executable file

@ -28,6 +28,7 @@ py_library(
"//tensorflow/contrib/grid_rnn:grid_rnn_py",
"//tensorflow/contrib/hooks",
"//tensorflow/contrib/image:image_py",
"//tensorflow/contrib/image:single_image_random_dot_stereograms_py",
"//tensorflow/contrib/imperative",
"//tensorflow/contrib/input_pipeline:input_pipeline_py",
"//tensorflow/contrib/integrate:integrate_py",


@ -29,6 +29,7 @@ option(tensorflow_BUILD_ALL_KERNELS "Build all OpKernels" ON)
option(tensorflow_BUILD_CONTRIB_KERNELS "Build OpKernels from tensorflow/contrib/..." ON)
option(tensorflow_BUILD_CC_TESTS "Build cc unit tests " OFF)
option(tensorflow_BUILD_PYTHON_TESTS "Build python unit tests " OFF)
option(tensorflow_BUILD_SHARED_LIB "Build TensorFlow as a shared library" OFF)
option(tensorflow_OPTIMIZE_FOR_NATIVE_ARCH "Enable compiler optimizations for the native processor architecture (if available)" ON)
option(tensorflow_WIN_CPU_SIMD_OPTIONS "Enables CPU SIMD instructions")
@ -198,7 +199,7 @@ if (tensorflow_ENABLE_GPU)
# add cudnn
include_directories(${CUDNN_HOME})
set(CUDA_LIBRARIES ${CUDA_LIBRARIES} ${CUDA_CUDA_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_CUFFT_LIBRARIES}
${CUDA_curand_LIBRARY} ${CUDA_cupti_LIBRARY} ${CUDNN_HOME}/lib/x64/cudnn.lib)
${CUDA_curand_LIBRARY} ${CUDA_cupti_LIBRARY} ${CUDA_cusolver_LIBRARY} ${CUDNN_HOME}/lib/x64/cudnn.lib)
# create cuda_config.h
FILE(WRITE ${tensorflow_source_dir}/third_party/gpus/cuda/cuda_config.h
@ -219,6 +220,7 @@ if (tensorflow_ENABLE_GPU)
${CUDA_TOOLKIT_TARGET_DIR}/include/cublas_v2.h ${CUDNN_HOME}/include/cudnn.h
${CUDA_TOOLKIT_TARGET_DIR}/include/cufft.h ${CUDA_TOOLKIT_TARGET_DIR}/include/curand.h
${CUDA_TOOLKIT_TARGET_DIR}/include/cuda_runtime_api.h
${CUDA_TOOLKIT_TARGET_DIR}/include/cusolverDn.h
DESTINATION ${tensorflow_source_dir}/third_party/gpus/cuda/include
)
include_directories(${tensorflow_source_dir}/third_party/gpus)
@ -244,7 +246,9 @@ include(tf_core_kernels.cmake)
if(tensorflow_ENABLE_GRPC_SUPPORT)
include(tf_core_distributed_runtime.cmake)
endif()
# We include tf_cc_ops first, because tf_c depends on tf_cc.
include(tf_cc_ops.cmake)
include(tf_c.cmake)
if(tensorflow_BUILD_CC_EXAMPLE)
include(tf_tutorials.cmake)
include(tf_label_image_example.cmake)
@ -254,6 +258,9 @@ if(tensorflow_BUILD_PYTHON_BINDINGS)
include(tensorboard)
include(tf_python.cmake)
endif()
if (tensorflow_BUILD_CC_TESTS OR tensorflow_BUILD_PYTHON_TESTS)
if(tensorflow_BUILD_SHARED_LIB)
include(tf_shared_lib.cmake)
endif()
if(tensorflow_BUILD_CC_TESTS OR tensorflow_BUILD_PYTHON_TESTS)
include(tf_tests.cmake)
endif()


@ -0,0 +1,28 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
########################################################
# tf_c_framework library
########################################################
set(tf_c_srcs
"${tensorflow_source_dir}/tensorflow/c/c_api.cc"
"${tensorflow_source_dir}/tensorflow/c/c_api.h"
"${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.cc"
"${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.h"
"${tensorflow_source_dir}/tensorflow/c/tf_status_helper.cc"
"${tensorflow_source_dir}/tensorflow/c/tf_status_helper.h"
)
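# tf_c is built as an OBJECT library: it produces no archive of its own, and
# its object files are folded into the final binaries via
# $<TARGET_OBJECTS:tf_c> in tf_python.cmake and tf_shared_lib.cmake.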
add_library(tf_c OBJECT ${tf_c_srcs})
add_dependencies(tf_c tf_cc_framework tf_core_lib tf_protos_cc)


@ -19,6 +19,7 @@ set(tf_cc_framework_srcs
"${tensorflow_source_dir}/tensorflow/cc/framework/ops.h"
"${tensorflow_source_dir}/tensorflow/cc/framework/ops.cc"
"${tensorflow_source_dir}/tensorflow/cc/framework/scope.h"
"${tensorflow_source_dir}/tensorflow/cc/framework/scope_internal.h"
"${tensorflow_source_dir}/tensorflow/cc/framework/scope.cc"
)


@ -116,6 +116,10 @@ if(WIN32)
"${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/lstm_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/gru_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/lstm_ops.cc"
# temporarily disable nccl (nccl itself needs to be ported to windows first)
"${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc"
"${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc"
)
list(REMOVE_ITEM tf_core_kernels_srcs ${tf_core_kernels_windows_exclude_srcs})
endif(WIN32)


@ -686,19 +686,7 @@ set (pywrap_tensorflow_internal_src
"${tensorflow_source_dir}/tensorflow/python/lib/io/py_record_writer.cc"
"${tensorflow_source_dir}/tensorflow/python/util/kernel_registry.h"
"${tensorflow_source_dir}/tensorflow/python/util/kernel_registry.cc"
"${tensorflow_source_dir}/tensorflow/c/c_api.cc"
"${tensorflow_source_dir}/tensorflow/c/c_api.h"
"${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.cc"
"${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.h"
"${tensorflow_source_dir}/tensorflow/c/tf_status_helper.cc"
"${tensorflow_source_dir}/tensorflow/c/tf_status_helper.h"
"${tensorflow_source_dir}/tensorflow/cc/framework/gradients.h"
"${tensorflow_source_dir}/tensorflow/cc/framework/gradients.cc"
"${tensorflow_source_dir}/tensorflow/cc/framework/grad_op_registry.h"
"${tensorflow_source_dir}/tensorflow/cc/framework/grad_op_registry.cc"
"${tensorflow_source_dir}/tensorflow/cc/framework/ops.h"
"${tensorflow_source_dir}/tensorflow/cc/framework/ops.cc"
"${tensorflow_source_dir}/tensorflow/cc/framework/scope_internal.h"
"${tensorflow_source_dir}/tensorflow/cc/framework/scope.cc"
"${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.cc"
)
@ -715,9 +703,11 @@ if(WIN32)
#
add_library(pywrap_tensorflow_internal_static STATIC
${pywrap_tensorflow_internal_src}
$<TARGET_OBJECTS:tf_c>
$<TARGET_OBJECTS:tf_core_lib>
$<TARGET_OBJECTS:tf_core_cpu>
$<TARGET_OBJECTS:tf_core_framework>
$<TARGET_OBJECTS:tf_cc>
$<TARGET_OBJECTS:tf_cc_ops>
$<TARGET_OBJECTS:tf_core_ops>
$<TARGET_OBJECTS:tf_core_direct_session>
@ -727,33 +717,43 @@ if(WIN32)
$<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_core_kernels_cpu_only>>
$<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_stream_executor>>
)
target_include_directories(pywrap_tensorflow_internal_static PUBLIC
${PYTHON_INCLUDE_DIR}
${NUMPY_INCLUDE_DIR}
)
target_link_libraries(pywrap_tensorflow_internal_static
tf_protos_cc
tf_python_protos_cc
#target_link_libraries(pywrap_tensorflow_internal_static
# tf_protos_cc
# tf_python_protos_cc
#)
add_dependencies(pywrap_tensorflow_internal_static tf_protos_cc tf_python_protos_cc)
set(pywrap_tensorflow_internal_static_dependencies
$<TARGET_FILE:pywrap_tensorflow_internal_static>
$<TARGET_FILE:tf_protos_cc>
$<TARGET_FILE:tf_python_protos_cc>
)
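# Note: CMake expands the list above to a single semicolon-separated string,
# which create_def_file.py splits on ';' (see its --input argument).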
set(pywrap_tensorflow_deffile "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/pywrap_tensorflow.def")
set_source_files_properties(${pywrap_tensorflow_deffile} PROPERTIES GENERATED TRUE)
add_custom_command(TARGET pywrap_tensorflow_internal_static POST_BUILD
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/tools/create_def_file.py
--input $<TARGET_FILE:pywrap_tensorflow_internal_static>
--output ${pywrap_tensorflow_deffile}
--input "${pywrap_tensorflow_internal_static_dependencies}"
--output "${pywrap_tensorflow_deffile}"
--target _pywrap_tensorflow_internal.pyd
)
endif(WIN32)
# pywrap_tensorflow_internal is a shared library containing all of the
# TensorFlow runtime and the standard ops and kernels. These are installed into
# tf_python/tensorflow/python/.
add_library(pywrap_tensorflow_internal SHARED
${pywrap_tensorflow_internal_src}
$<TARGET_OBJECTS:tf_c>
$<TARGET_OBJECTS:tf_core_lib>
$<TARGET_OBJECTS:tf_core_cpu>
$<TARGET_OBJECTS:tf_core_framework>
$<TARGET_OBJECTS:tf_cc>
$<TARGET_OBJECTS:tf_cc_ops>
$<TARGET_OBJECTS:tf_core_ops>
$<TARGET_OBJECTS:tf_core_direct_session>
@ -773,7 +773,8 @@ target_include_directories(pywrap_tensorflow_internal PUBLIC
${PYTHON_INCLUDE_DIR}
${NUMPY_INCLUDE_DIR}
)
target_link_libraries(pywrap_tensorflow_internal
target_link_libraries(pywrap_tensorflow_internal PRIVATE
${tf_core_gpu_kernels_lib}
${tensorflow_EXTERNAL_LIBRARIES}
tf_protos_cc


@ -0,0 +1,87 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
if(WIN32)
# Windows: build a static library with the same objects as tensorflow.dll.
# This can be used to build a standalone exe, and it also lets us find all
# symbols that need to be exported from the dll, which is needed to provide
# the tensorflow c/c++ api in tensorflow.dll.
# From the static library we create the def file with all symbols that need to
# be exported from tensorflow.dll. Because there is a limit of 64K symbols
# that can be exported, we filter the symbols with a python script down to the
# namespaces we need.
#
add_library(tensorflow_static STATIC
$<TARGET_OBJECTS:tf_c>
$<TARGET_OBJECTS:tf_cc>
$<TARGET_OBJECTS:tf_cc_framework>
$<TARGET_OBJECTS:tf_cc_ops>
$<TARGET_OBJECTS:tf_core_lib>
$<TARGET_OBJECTS:tf_core_cpu>
$<TARGET_OBJECTS:tf_core_framework>
$<TARGET_OBJECTS:tf_core_ops>
$<TARGET_OBJECTS:tf_core_direct_session>
$<TARGET_OBJECTS:tf_tools_transform_graph_lib>
$<$<BOOL:${tensorflow_ENABLE_GRPC_SUPPORT}>:$<TARGET_OBJECTS:tf_core_distributed_runtime>>
$<TARGET_OBJECTS:tf_core_kernels>
$<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_core_kernels_cpu_only>>
$<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_stream_executor>>
)
add_dependencies(tensorflow_static tf_protos_cc)
set(tensorflow_static_dependencies
$<TARGET_FILE:tensorflow_static>
$<TARGET_FILE:tf_protos_cc>
)
set(tensorflow_deffile "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/tensorflow.def")
set_source_files_properties(${tensorflow_deffile} PROPERTIES GENERATED TRUE)
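# Listing the generated .def file as a source of the SHARED library below
# makes the MSVC linker use it to decide which symbols tensorflow.dll exports.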
add_custom_command(TARGET tensorflow_static POST_BUILD
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/tools/create_def_file.py
--input "${tensorflow_static_dependencies}"
--output "${tensorflow_deffile}"
--target tensorflow.dll
)
endif(WIN32)
# tensorflow is a shared library containing all of the
# TensorFlow runtime and the standard ops and kernels.
add_library(tensorflow SHARED
$<TARGET_OBJECTS:tf_c>
$<TARGET_OBJECTS:tf_cc>
$<TARGET_OBJECTS:tf_cc_framework>
$<TARGET_OBJECTS:tf_cc_ops>
$<TARGET_OBJECTS:tf_core_lib>
$<TARGET_OBJECTS:tf_core_cpu>
$<TARGET_OBJECTS:tf_core_framework>
$<TARGET_OBJECTS:tf_core_ops>
$<TARGET_OBJECTS:tf_core_direct_session>
$<TARGET_OBJECTS:tf_tools_transform_graph_lib>
$<$<BOOL:${tensorflow_ENABLE_GRPC_SUPPORT}>:$<TARGET_OBJECTS:tf_core_distributed_runtime>>
$<TARGET_OBJECTS:tf_core_kernels>
$<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_core_kernels_cpu_only>>
$<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_stream_executor>>
${tensorflow_deffile}
)
target_link_libraries(tensorflow PRIVATE
${tf_core_gpu_kernels_lib}
${tensorflow_EXTERNAL_LIBRARIES}
tf_protos_cc
)
if(WIN32)
add_dependencies(tensorflow tensorflow_static)
endif(WIN32)


@ -73,10 +73,13 @@ add_executable(${transform_graph}
$<TARGET_OBJECTS:tf_core_direct_session>
$<TARGET_OBJECTS:tf_tools_transform_graph_lib>
$<TARGET_OBJECTS:tf_core_kernels>
$<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_core_kernels_cpu_only>>
$<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_stream_executor>>
)
target_link_libraries(${transform_graph} PUBLIC
tf_protos_cc
${tf_core_gpu_kernels_lib}
${tensorflow_EXTERNAL_LIBRARIES}
)
@ -92,10 +95,13 @@ add_executable(${summarize_graph}
$<TARGET_OBJECTS:tf_core_direct_session>
$<TARGET_OBJECTS:tf_tools_transform_graph_lib>
$<TARGET_OBJECTS:tf_core_kernels>
$<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_core_kernels_cpu_only>>
$<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_stream_executor>>
)
target_link_libraries(${summarize_graph} PUBLIC
tf_protos_cc
${tf_core_gpu_kernels_lib}
${tensorflow_EXTERNAL_LIBRARIES}
)
@ -111,10 +117,13 @@ add_executable(${compare_graphs}
$<TARGET_OBJECTS:tf_core_direct_session>
$<TARGET_OBJECTS:tf_tools_transform_graph_lib>
$<TARGET_OBJECTS:tf_core_kernels>
$<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_core_kernels_cpu_only>>
$<$<BOOL:${tensorflow_ENABLE_GPU}>:$<TARGET_OBJECTS:tf_stream_executor>>
)
target_link_libraries(${compare_graphs} PUBLIC
tf_protos_cc
${tf_core_gpu_kernels_lib}
${tensorflow_EXTERNAL_LIBRARIES}
)


@ -47,8 +47,16 @@ DUMPBIN = "dumpbin.exe"
EXCLUDE_RE = re.compile(r"deleting destructor|::internal::")
# Include if matched before exclude
INCLUDEPRE_RE = re.compile(r"tensorflow::internal::LogMessage|"
r"tensorflow::internal::CheckOpMessageBuilder")
INCLUDEPRE_RE = re.compile(r"google::protobuf::internal::ExplicitlyConstructed|"
r"tensorflow::internal::LogMessage|"
r"tensorflow::internal::LogString|"
r"tensorflow::internal::CheckOpMessageBuilder|"
r"tensorflow::internal::PickUnusedPortOrDie|"
r"tensorflow::internal::ValidateDevice|"
r"tensorflow::ops::internal::Enter|"
r"tensorflow::strings::internal::AppendPieces|"
r"tensorflow::strings::internal::CatPieces|"
r"tensorflow::io::internal::JoinPathImpl")
# Include if matched after exclude
INCLUDE_RE = re.compile(r"^(TF_\w*)$|"
@ -56,12 +64,27 @@ INCLUDE_RE = re.compile(r"^(TF_\w*)$|"
r"functor::|"
r"perftools::gputools")
# We want to identify data members explicitly in the DEF file, so that no one
# can implicitly link against the DLL if they use one of the variables exported
# from the DLL and the header they use does not decorate the symbol with
# __declspec(dllimport). It is easier to detect what a data symbol does
# NOT look like, so we do that with the regex below.
DATA_EXCLUDE_RE = re.compile(r"[)(]|"
r"vftable|"
r"vbtable|"
r"vcall|"
r"RTTI|"
r"protobuf::internal::ExplicitlyConstructed")
def get_args():
"""Parse command line."""
filename_list = lambda x: x.split(";")
parser = argparse.ArgumentParser()
parser.add_argument("--input", help="input library", required=True)
parser.add_argument("--input", type=filename_list,
help="paths to input libraries separated by semicolons",
required=True)
parser.add_argument("--output", help="output deffile", required=True)
parser.add_argument("--target", help="name of the target", required=True)
args = parser.parse_args()
return args
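# Hypothetical invocation matching the CMake rules above (paths illustrative):
#   create_def_file.py --input "pywrap.lib;tf_protos_cc.lib" \
#       --output tensorflow.def --target tensorflow.dll
# args.input then parses to ["pywrap.lib", "tf_protos_cc.lib"].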
@ -70,25 +93,26 @@ def main():
"""main."""
args = get_args()
# Pipe dumpbin to extract all linkable symbols from a lib.
# Pipe dumpbin to extract all linkable symbols from libs.
# Good symbols are collected in candidates and also written to
# a temp file.
candidates = []
tmpfile = tempfile.NamedTemporaryFile(mode="w", delete=False)
proc = subprocess.Popen([DUMPBIN, "/nologo", "/linkermember:1", args.input],
stdout=subprocess.PIPE)
for line in io.TextIOWrapper(proc.stdout, encoding="utf-8"):
cols = line.split()
if len(cols) < 2:
continue
sym = cols[1]
tmpfile.file.write(sym + "\n")
candidates.append(sym)
for lib_path in args.input:
proc = subprocess.Popen([DUMPBIN, "/nologo", "/linkermember:1", lib_path],
stdout=subprocess.PIPE)
for line in io.TextIOWrapper(proc.stdout, encoding="utf-8"):
cols = line.split()
if len(cols) < 2:
continue
sym = cols[1]
tmpfile.file.write(sym + "\n")
candidates.append(sym)
exit_code = proc.wait()
if exit_code != 0:
print("{} failed, exit={}".format(DUMPBIN, exit_code))
return exit_code
tmpfile.file.close()
exit_code = proc.wait()
if exit_code != 0:
print("{} failed, exit={}".format(DUMPBIN, exit_code))
return exit_code
# Run the symbols through undname to get their undecorated name
# so we can filter on something readable.
@ -96,9 +120,8 @@ def main():
# track dupes
taken = set()
# Header for the def file. Since the tensorflow.dll is actually called
# _pywrap_tensorflow.pyd in the python wheel, hint that in the def file.
def_fp.write("LIBRARY _pywrap_tensorflow_internal.pyd\n")
# Header for the def file.
def_fp.write("LIBRARY " + args.target + "\n")
def_fp.write("EXPORTS\n")
def_fp.write("\t ??1OpDef@tensorflow@@UEAA@XZ\n")
@ -118,8 +141,17 @@ def main():
continue
if not INCLUDE_RE.search(line):
continue
def_fp.write("\t" + decorated + "\n")
if "deleting destructor" in line:
# Some of the symbols covered by INCLUDEPRE_RE export deleting
# destructor symbols, which is a bad idea.
# So we filter out such symbols here.
continue
if DATA_EXCLUDE_RE.search(line):
def_fp.write("\t" + decorated + "\n")
else:
def_fp.write("\t" + decorated + " DATA\n")
taken.add(decorated)
exit_code = proc.wait()
if exit_code != 0:


@ -148,6 +148,15 @@ class BernoulliTest(test.TestCase):
p: [0.2, 0.3, 0.4]
}), [[0.2, 0.7, 0.4]])
def testPmfInvalid(self):
p = [0.1, 0.2, 0.7]
with self.test_session():
dist = bernoulli.Bernoulli(probs=p, validate_args=True)
with self.assertRaisesOpError("must be non-negative."):
dist.prob([1, 1, -1]).eval()
with self.assertRaisesOpError("is not less than or equal to 1."):
dist.prob([2, 0, 1]).eval()
def testPmfWithP(self):
p = [[0.2, 0.4], [0.3, 0.6]]
self._testPmf(probs=p)


@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
@ -120,6 +121,7 @@ class Bernoulli(distribution.Distribution):
return math_ops.cast(sample, self.dtype)
def _log_prob(self, event):
event = self._maybe_assert_valid_sample(event)
# TODO(jaana): The current sigmoid_cross_entropy_with_logits has
# inconsistent behavior for logits = inf/-inf.
event = math_ops.cast(event, self.logits.dtype)
@ -160,6 +162,17 @@ class Bernoulli(distribution.Distribution):
"""Returns `1` if `prob > 0.5` and `0` otherwise."""
return math_ops.cast(self.probs > 0.5, self.dtype)
def _maybe_assert_valid_sample(self, event, check_integer=True):
if not self.validate_args:
return event
event = distribution_util.embed_check_nonnegative_discrete(
event, check_integer=check_integer)
return control_flow_ops.with_dependencies([
check_ops.assert_less_equal(
event, array_ops.ones_like(event),
message="event is not less than or equal to 1."),
], event)
class BernoulliWithSigmoidProbs(Bernoulli):
"""Bernoulli with `probs = nn.sigmoid(logits)`."""

tensorflow/contrib/image/BUILD Normal file → Executable file

@ -87,6 +87,7 @@ cuda_py_test(
srcs = ["python/kernel_tests/image_ops_test.py"],
additional_deps = [
":image_py",
":single_image_random_dot_stereograms_py",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
@ -96,6 +97,33 @@ cuda_py_test(
],
)
tf_custom_op_library(
name = "python/ops/_single_image_random_dot_stereograms.so",
srcs = [
"kernels/single_image_random_dot_stereograms_ops.cc",
"ops/single_image_random_dot_stereograms_ops.cc",
],
)
tf_gen_op_libs(
op_lib_names = ["single_image_random_dot_stereograms_ops"],
)
tf_gen_op_wrapper_py(
name = "single_image_random_dot_stereograms_ops",
deps = [":single_image_random_dot_stereograms_ops_op_lib"],
)
py_library(
name = "single_image_random_dot_stereograms_py",
srcs = glob(["python/ops/single*.py"]) + ["__init__.py"],
data = [":python/ops/_single_image_random_dot_stereograms.so"],
srcs_version = "PY2AND3",
deps = [
":single_image_random_dot_stereograms_ops",
],
)
filegroup(
name = "all_files",
srcs = glob(

tensorflow/contrib/image/__init__.py Normal file → Executable file

@ -25,6 +25,7 @@ projective transforms (including rotation) are supported.
@@compose_transforms
@@rotate
@@transform
@@single_image_random_dot_stereograms
"""
from __future__ import absolute_import
from __future__ import division
@ -35,6 +36,7 @@ from tensorflow.contrib.image.python.ops.image_ops import angles_to_projective_t
from tensorflow.contrib.image.python.ops.image_ops import compose_transforms
from tensorflow.contrib.image.python.ops.image_ops import rotate
from tensorflow.contrib.image.python.ops.image_ops import transform
from tensorflow.contrib.image.python.ops.single_image_random_dot_stereograms import single_image_random_dot_stereograms
from tensorflow.python.util.all_util import remove_undocumented


@ -0,0 +1,424 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
using shape_inference::InferenceContext;
template <typename T>
class SingleImageRandomDotStereogramsOp : public OpKernel {
private:
int E2Epixels; // Pixels from eye to eye = eye_to_eye_inches * DPI
int input_Xvalue; // X value of input Z values (width)
int input_Yvalue; // Y value of input Z values (height)
int output_Ximage; // X value of output image (width)
int output_Yimage; // Y value of output image (height)
int output_Cimage; // color value of output image (color, 1 or 3) (3 not
// implemented)
int data_box_left; // X starting value for DATA window
int data_box_top; // Y starting value for DATA window
int data_box_width; // width of scan line
int data_box_height; // height of image
int converge_dot_box_end; // Row that convergence dots end on
uint8* outputImage; // Output Image flat as a buffer (Tensor Connection)
double* ZBuffer; // For internal use, allow for MASK, etc. later; actual Z
// used for Stereogram, XxY (X is the row index, y is col
// index like a screen)
// 0 (far) -> 1.0(near) range
bool hidden_surface_removal;
int convergence_dots_size;
int dots_per_inch;
float eye_separation;
float mu;
bool normalize;
float normalize_max;
float normalize_min;
float border_level;
int number_colors;
::tensorflow::TensorShapeProto output_image_shape;
::tensorflow::TensorShapeProto output_data_window;
uint8 Cblack = (uint8)0;
uint8 Cwhite = (uint8)255;
int indexMode = 0; // 0 - truncate XY, 1 - round XY, 2 - Interpolate XY (not
// implemented yet, keep default of 0)
int interp_x, interp_y; // 1 - yes, 0 - no interpolation directions (not
// implemented yet)
bool debugging = false;
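// Stereo separation in pixels for a normalized depth z in [0, 1]
// (0 = far plane, 1 = near plane), the standard SIRDS geometry:
//   s(z) = E2Epixels * (1 - mu * z) / (2 - mu * z)
// so the far plane separates by E2Epixels / 2 and nearer points by less.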
inline int separation(double z) {
return (std::round((1 - mu * z) * E2Epixels / (2 - mu * z)));
}
inline int get_far_width() { return (separation(0.0)); }
inline int get_near_width() { return (separation(1.0)); }
public:
explicit SingleImageRandomDotStereogramsOp(OpKernelConstruction* context)
: OpKernel(context) { // Constructor
OP_REQUIRES_OK(context, context->GetAttr("hidden_surface_removal",
&hidden_surface_removal));
OP_REQUIRES_OK(context, context->GetAttr("convergence_dots_size",
&convergence_dots_size));
OP_REQUIRES_OK(context, context->GetAttr("dots_per_inch", &dots_per_inch));
OP_REQUIRES_OK(context,
context->GetAttr("eye_separation", &eye_separation));
OP_REQUIRES_OK(context, context->GetAttr("mu", &mu));
OP_REQUIRES_OK(context, context->GetAttr("normalize", &normalize));
OP_REQUIRES_OK(context, context->GetAttr("normalize_max", &normalize_max));
OP_REQUIRES_OK(context, context->GetAttr("normalize_min", &normalize_min));
OP_REQUIRES_OK(context, context->GetAttr("border_level", &border_level));
OP_REQUIRES_OK(context, context->GetAttr("number_colors", &number_colors));
OP_REQUIRES_OK(context,
context->GetAttr("output_image_shape", &output_image_shape));
OP_REQUIRES_OK(context,
context->GetAttr("output_data_window", &output_data_window));
E2Epixels =
eye_separation * dots_per_inch; // Initialize pixels from eye to eye
}
~SingleImageRandomDotStereogramsOp() { // Destructor
}
void Compute(OpKernelContext* context) override {
const Tensor& input_tensor = context->input(0);
input_Xvalue = input_tensor.shape().dim_size(
1); // X value is the number of columns of the input matrix
input_Yvalue =
input_tensor.shape().dim_size(0); // Y value is the number of rows
output_Ximage = output_image_shape.dim(0).size();
output_Yimage = output_image_shape.dim(1).size();
output_Cimage = output_image_shape.dim(2).size();
if (number_colors > 256) // Go to full color image
output_Cimage = 3;
int data_Xwindow = output_data_window.dim(0).size();
int data_Ywindow = output_data_window.dim(1).size();
int deltaX_border_image = output_Ximage - data_Xwindow;
int deltaY_border_image = output_Yimage - data_Ywindow;
if (convergence_dots_size >
0) // 3 frame sections in Y direction due to DOTS
{
deltaY_border_image =
deltaY_border_image -
convergence_dots_size; // Take off space for Convergence Dots
deltaY_border_image = std::max(0, deltaY_border_image);
data_box_top = deltaY_border_image / 3;
if (deltaY_border_image >= 0) {
converge_dot_box_end = output_Yimage - 1 - data_box_top;
} else {
converge_dot_box_end = output_Yimage - 1;
}
} else // Otherwise only 2, no convergence dot
{
data_box_top = deltaY_border_image / 2; // Center DATA in Y dimension
converge_dot_box_end = output_Yimage - 1;
}
data_box_left = deltaX_border_image / 2; // Center DATA in X dimension
data_box_width = data_Xwindow; // width of scan line
data_box_height = data_Ywindow; // height of image
const T* inputZ = input_tensor.flat<T>().data(); // Flatten input Z buffer
BuildZBuffer(inputZ);
// Allocate the output image tensor.
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(
context,
context->allocate_output(
0, TensorShape({output_Yimage, output_Ximage, output_Cimage}),
&output_tensor));
outputImage = output_tensor->flat<uint8>().data();
generate_stereogram();
delete[] ZBuffer;
}
//***************************************************************************
//***************************************************************************
// Move input into standard Z format to reduce complexity of algorithm
//
void BuildZBuffer(const T* Z, bool log = false) {
double MaxValue = 1.0;
double MinValue = 0.0;
ZBuffer = new double[input_Xvalue * input_Yvalue]; // Used to compute
// final Z values before
// rendering to output
if (normalize) {
// Init Min/Max to first value
if (normalize_max < normalize_min) // Autoscale if MIN>MAX
{
MaxValue = (double)*Z;
MinValue = (double)*Z;
for (int y = 0; y < input_Yvalue; ++y)
for (int x = 0; x < input_Xvalue; ++x) {
double value = getZfromInputImage(Z, x, y);
if (value > MaxValue) MaxValue = value;
if (value < MinValue) MinValue = value;
}
} else {
MaxValue = normalize_max;
MinValue = normalize_min;
}
}
for (int y = 0; y < input_Yvalue; ++y)
for (int x = 0; x < input_Xvalue; ++x) {
double value = getZfromInputImage(Z, x, y);
if (normalize) {
value = (value - MinValue) / (MaxValue - MinValue);
}
if (value > 1.0) value = 1.0;
if (value < 0.0) value = 0.0;
*(ZBuffer + (input_Xvalue * y + x)) = value;
}
}
//***************************************************************************
//***************************************************************************
double getZfromInputImage(const T* Z, int x, int y) {
double return_val;
return_val = (double)*(Z + input_Xvalue * y + x); // Get value
return return_val;
}
//***************************************************************************
//***************************************************************************
// All normalized, no checking required
// Possible Projection issue if DATA is bigger or smaller than Input
// Modes include:
// Truncate value (Default)
// Round-off value
// Interpolate between values
//
double getZfromZbuffer(double x, double y) {
int xi, yi;
switch (indexMode) {
case 0: // Truncate
xi = int(x);
yi = int(y);
return (*(ZBuffer + (xi + input_Xvalue * yi)));
break;
case 1: // Round-off
xi = std::round(x);
yi = std::round(y);
return (*(ZBuffer + (xi + input_Xvalue * yi)));
break;
case 2: // Interpolate (Not implemented yet, will need 4 points
// [x,y],[x+1,y],[x,y+1],[x+1,y+1], then interpolate)
xi = int(x);
yi = int(y);
return (*(ZBuffer + (xi + input_Xvalue * yi)));
break;
default: // Round-off is the default
xi = int(x + 0.5);
yi = int(y + 0.5);
return (*(ZBuffer + (xi + input_Xvalue * yi)));
break;
}
}
//***************************************************************************
//***************************************************************************
int getOutputImageIndex(int x, int y,
int channel) { // No error checking for some
// optimization, calling routine
// required to make sure there is no
// violation
return ((output_Ximage * output_Cimage) * y + x * output_Cimage + channel);
}
//***************************************************************************
//***************************************************************************
double getZFromOutputPixel(int x, int y) {
double xofz, yofz, returnval;
// Convert pixel units to Z units, do this as "double"
xofz =
(double)input_Xvalue * (x - data_box_left) / ((double)data_box_width);
yofz =
(double)input_Yvalue * (y - data_box_top) / ((double)data_box_height);
if ((xofz < 0) || (yofz < 0) || (yofz >= input_Yvalue) ||
(xofz >= input_Xvalue)) { // Top or left side border hit, or right
// side or bottom border hit
// Send BORDER Z value
return (border_level);
}
{ // in data set Z; interpolate if needed
double gz;
gz = getZfromZbuffer(xofz, yofz);
returnval = gz;
}
return (returnval);
}
//***************************************************************************
//***************************************************************************
void generate_stereogram() {
int s, left, right, visible, t, l;
double zt, gz;
// Scan line
uint8* pix; // Scan row color for each pixel
int* same; // Used to determine if Pixel needs to be the same as another
// pixel in the row
pix = new uint8[output_Ximage * output_Cimage];
same = new int[output_Ximage];
for (int y = 0; y < output_Yimage; ++y) {
// Set no dependencies on any pixels, tie each one back to itself
for (int x = 0; x < output_Ximage; ++x) same[x] = x;
for (int x = 0; x < output_Ximage; ++x) {
gz = getZFromOutputPixel(x, y);
s = separation(gz);
left = x - s / 2;
right = left + s;
if ((left >= 0) && (right < output_Ximage)) {
t = 1;
visible = 1;
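// Hidden-surface removal: step outward from x (t = 1, 2, ...) and test
// whether a surface at the line-of-sight depth zt would block this point
// from either eye; stop once zt reaches the near plane (zt >= 1).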
if (hidden_surface_removal) do {
zt = gz + 2 * (2 - mu * gz) * t / (mu * E2Epixels);
visible = (getZFromOutputPixel(x - t, y) < zt) &&
(getZFromOutputPixel(x + t, y) < zt);
++t;
} while ((visible) && (zt < 1));
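// A visible pair constrains pixels `left` and `right` to share a color;
// merge this with any earlier constraints chained through same[].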
if (visible) {
l = same[left];
while ((l != left) && (l != right))
if (l < right) {
left = l;
l = same[left];
} else {
same[left] = right;
left = right;
l = same[left];
right = l;
}
same[left] = right;
}
}
}
// Set colors for scan row, use channels and number_colors
for (int x = output_Ximage - 1; x >= 0; x--) {
for (int channel = 0; channel < output_Cimage; ++channel) {
if (same[x] == x) { // Pick a random color
if (number_colors == 2) {
if ((rand() % 2) == 0) {
pix[x * output_Cimage + channel] = Cblack;
} else {
pix[x * output_Cimage + channel] = Cwhite;
}
} else {
pix[x * output_Cimage + channel] = rand() % 256;
}
} else
pix[x * output_Cimage + channel] =
pix[same[x] * output_Cimage + channel];
setpixel(x, y, channel, pix[x * output_Cimage + channel]);
}
}
}
draw_convergence_dots();
delete[] pix;
delete[] same;
}
//***************************************************************************
//***************************************************************************
void draw_convergence_dots() {
int x1, x2; // center position for convergence dots
if (convergence_dots_size == 0) // No dot, return
return;
x1 = output_Ximage / 2 - get_far_width() / 2;
x2 = output_Ximage / 2 + get_far_width() / 2;
for (int lloop = 0; lloop < convergence_dots_size; ++lloop)
for (int wloop = 0; wloop < convergence_dots_size; ++wloop)
for (int channel = 0; channel < output_Cimage; ++channel) {
setpixel(x1 - (convergence_dots_size / 2) + wloop,
converge_dot_box_end - lloop, channel, Cblack);
setpixel(x2 - (convergence_dots_size / 2) + wloop,
converge_dot_box_end - lloop, channel, Cblack);
}
}
//***************************************************************************
//***************************************************************************
void setpixel(int x, int y, int channel, uint8 color) {
*(outputImage + getOutputImageIndex(x, y, channel)) = color;
}
};
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("SingleImageRandomDotStereograms") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T"), \
SingleImageRandomDotStereogramsOp<T>);
REGISTER_KERNEL(int32);
REGISTER_KERNEL(int64);
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);
#undef REGISTER_KERNEL
} // end namespace tensorflow


@ -0,0 +1,93 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
using shape_inference::InferenceContext;
REGISTER_OP("SingleImageRandomDotStereograms")
.Attr("T: {double,float,int64,int32}")
.Input("depth_values: T")
.Output("image: uint8")
.Attr("hidden_surface_removal: bool = true")
.Attr("convergence_dots_size: int = 8")
.Attr("dots_per_inch: int = 72")
.Attr("eye_separation: float = 2.5")
.Attr("mu: float = .3333")
.Attr("normalize: bool = true")
.Attr("normalize_max: float = -100.0")
.Attr("normalize_min: float = 100.0")
.Attr("border_level: float = 0.0")
.Attr("number_colors: int = 256")
.Attr(
"output_image_shape: shape = { dim {size:1024} dim {size: 768} dim "
"{size: 1}}")
.Attr("output_data_window: shape = { dim {size:1022} dim {size: 757}}")
.Doc(R"doc(
Outputs a single image random dot stereogram for export via encode_PNG/JPG OP.
Given the 2-D tensor 'depth_values' with encoded Z values, this operation will
encode 3-D data into a 2-D image. The output of this Op is suitable for the
encode_PNG/JPG ops. Be careful with image compression as this may corrupt the
encoded 3-D data within the image.
This Op is based upon:
'http://www.learningace.com/doc/4331582/b6ab058d1e206d68ab60e4e1ead2fe6e/sirds-paper'
Example use which outputs a SIRDS image as picture_out.png:
```python
img=[[1,2,3,3,2,1],
[1,2,3,4,5,2],
[1,2,3,4,5,3],
[1,2,3,4,5,4],
[6,5,4,4,5,5]]
session = tf.InteractiveSession()
sirds = single_image_random_dot_stereograms(img,convergence_dots_size=8,number_colors=256,normalize=True)
out = sirds.eval()
png = tf.image.encode_png(out).eval()
with open('picture_out.png', 'wb') as f:
f.write(png)
```
depth_values: Z values of data to encode into 'output_data_window' window,
lower values are further away {0.0 floor(far), 1.0 ceiling(near) after normalization}, must be 2-D tensor
hidden_surface_removal: Activate hidden surface removal
convergence_dots_size: Black dot size in pixels to help view converge image, drawn on bottom of image
dots_per_inch: Output device in dots/inch
eye_separation: Separation between eyes in inches
mu: Depth of field, Fraction of viewing distance (eg. 1/3 = .3333)
normalize: Normalize input data to [0.0, 1.0]
normalize_max: Fix MAX value for Normalization - if < MIN, autoscale
normalize_min: Fix MIN value for Normalization - if > MAX, autoscale
border_level: Value of border depth 0.0 {far} to 1.0 {near}
number_colors: 2 (Black & White), 256 (grayscale), and numbers > 256 (Full Color) are all that are supported currently
output_image_shape: Output size of returned image in X,Y, Channels 1-grayscale, 3 color (1024, 768, 1),
channels will be updated to 3 if 'number_colors' > 256
output_data_window: Size of "DATA" window, must be equal to or smaller than 'output_image_shape', will be centered
and use 'convergence_dots_size' for best fit to avoid overlap if possible
image: A tensor of size 'output_image_shape' with the encoded 'depth_values'
)doc");
} // namespace tensorflow


@ -0,0 +1,125 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python layer for image_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.util import loader
from tensorflow.python.framework import ops
from tensorflow.python.platform import resource_loader
_sirds_ops = loader.load_op_library(
resource_loader.get_path_to_datafile(
"_single_image_random_dot_stereograms.so"))
def single_image_random_dot_stereograms(
depth_values,
hidden_surface_removal=None,
convergence_dots_size=None,
dots_per_inch=None,
eye_separation=None, mu=None,
normalize=None, normalize_max=None,
normalize_min=None,
border_level=None,
number_colors=None,
output_image_shape=None,
output_data_window=None):
"""Output a RandomDotStereogram Tensor for export via encode_PNG/JPG OP.
Given the 2-D tensor 'depth_values' with encoded Z values, this operation
will encode 3-D data into a 2-D image. The output of this Op is suitable
for the encode_PNG/JPG ops. Be careful with image compression as this may
corrupt the encoded 3-D data within the image.
Based upon [this paper](http://www.learningace.com/doc/4331582/b6ab058d1e206d68ab60e4e1ead2fe6e/sirds-paper).
This outputs a SIRDS image as picture_out.png:
```python
img=[[1,2,3,3,2,1],
[1,2,3,4,5,2],
[1,2,3,4,5,3],
[1,2,3,4,5,4],
[6,5,4,4,5,5]]
session = tf.InteractiveSession()
sirds = single_image_random_dot_stereograms(
img,
convergence_dots_size=8,
number_colors=256,normalize=True)
out = sirds.eval()
png = tf.image.encode_png(out).eval()
with open('picture_out.png', 'wb') as f:
f.write(png)
```
Args:
depth_values: A `Tensor`. Must be one of the following types:
`float64`, `float32`, `int64`, `int32`. Z values of data to encode
into 'output_data_window' window; lower values are further away {0.0
floor (far), 1.0 ceiling (near) after normalization}, must be a 2-D tensor
hidden_surface_removal: An optional `bool`. Defaults to `True`.
Activate hidden surface removal
convergence_dots_size: An optional `int`. Defaults to `8`.
Black dot size in pixels to help view converge image, drawn on bottom
of the image
dots_per_inch: An optional `int`. Defaults to `72`.
Output device in dots/inch
eye_separation: An optional `float`. Defaults to `2.5`.
Separation between eyes in inches
mu: An optional `float`. Defaults to `0.3333`.
Depth of field, Fraction of viewing distance (eg. 1/3 = 0.3333)
normalize: An optional `bool`. Defaults to `True`.
Normalize input data to [0.0, 1.0]
normalize_max: An optional `float`. Defaults to `-100`.
Fix MAX value for Normalization (0.0) - if < MIN, autoscale
normalize_min: An optional `float`. Defaults to `100`.
Fix MIN value for Normalization (0.0) - if > MAX, autoscale
border_level: An optional `float`. Defaults to `0`.
Value of border depth, 0.0 {far} to 1.0 {near}
number_colors: An optional `int`. Defaults to `256`. 2 (Black &
White), 256 (grayscale), and Numbers > 256 (Full Color) are
supported
output_image_shape: An optional `tf.TensorShape` or list of `ints`.
Defaults to shape `[1024, 768, 1]`. Defines output shape of returned
image in '[X,Y, Channels]' 1-grayscale, 3 color; channels will be
updated to 3 if number_colors > 256
output_data_window: An optional `tf.TensorShape` or list of `ints`.
Defaults to `[1022, 757]`. Size of "DATA" window, must be equal to or
smaller than `output_image_shape`, will be centered and use
`convergence_dots_size` for best fit to avoid overlap if possible
Returns:
A `Tensor` of type `uint8` of shape 'output_image_shape' with encoded
'depth_values'
"""
result = _sirds_ops.single_image_random_dot_stereograms(
depth_values=depth_values,
hidden_surface_removal=hidden_surface_removal,
convergence_dots_size=convergence_dots_size,
dots_per_inch=dots_per_inch,
eye_separation=eye_separation, mu=mu,
normalize=normalize,
normalize_max=normalize_max,
normalize_min=normalize_min,
border_level=border_level,
number_colors=number_colors,
output_image_shape=output_image_shape,
output_data_window=output_data_window)
return result
ops.NotDifferentiable("SingleImageRandomDotStereograms")


@ -25,4 +25,5 @@ from tensorflow.python.estimator.inputs.queues.feeding_functions import _enqueue
from tensorflow.python.estimator.inputs.queues.feeding_functions import _GeneratorFeedFn
from tensorflow.python.estimator.inputs.queues.feeding_functions import _OrderedDictNumpyFeedFn
from tensorflow.python.estimator.inputs.queues.feeding_functions import _PandasFeedFn
from tensorflow.python.estimator.inputs.queues.feeding_functions import _GeneratorFeedFn
# pylint: enable=unused-import


@ -611,7 +611,7 @@ def _create_model_fn_ops(features,
if (mode != model_fn.ModeKeys.INFER) and (labels is not None):
weight_tensor = _weight_tensor(features, weight_column_name)
loss, weighted_average_loss = loss_fn(labels, logits, weight_tensor)
logging_ops.scalar_summary(
summary.scalar(
_summary_key(head_name, mkey.LOSS), weighted_average_loss)
if mode == model_fn.ModeKeys.TRAIN:


@ -124,7 +124,7 @@ class PoissonHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn,
logits=logits)
self._assert_output_alternatives(model_fn_ops)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["regression_head/loss"])
_assert_no_variables(self)
loss = self._log_poisson_loss(logits, labels)
_assert_metrics(self, loss, {"loss": loss}, model_fn_ops)
@ -150,7 +150,7 @@ class RegressionHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn,
logits=((1.,), (1.,), (3.,)))
self._assert_output_alternatives(model_fn_ops)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["regression_head/loss"])
_assert_no_variables(self)
_assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)
@ -180,7 +180,7 @@ class RegressionHeadTest(test.TestCase):
_assert_variables(
self, expected_global=w, expected_model=w, expected_trainable=w)
variables.global_variables_initializer().run()
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["regression_head/loss"])
_assert_metrics(self, 2. / 3, {"loss": 2. / 3}, model_fn_ops)
def testRegressionWithLogitsAndLogitsInput(self):
@ -208,7 +208,7 @@ class RegressionHeadTest(test.TestCase):
self._assert_output_alternatives(model_fn_ops)
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["regression_head/loss"])
_assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)
def testRegressionWithLabelName(self):
@ -223,7 +223,7 @@ class RegressionHeadTest(test.TestCase):
logits=((1.,), (1.,), (3.,)))
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["regression_head/loss"])
_assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)
def testRegressionWithWeights(self):
@ -238,7 +238,7 @@ class RegressionHeadTest(test.TestCase):
logits=((1.,), (1.,), (3.,)))
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["regression_head/loss"])
_assert_metrics(self, 2. / len(weights), {"loss": 2. / np.sum(weights)},
model_fn_ops)
@ -261,7 +261,7 @@ class RegressionHeadTest(test.TestCase):
expected_trainable=("regression_head/centered_bias_weight:0",))
variables.global_variables_initializer().run()
_assert_summary_tags(
self, ["loss", "regression_head/centered_bias/bias_0"])
self, ["regression_head/loss", "regression_head/centered_bias/bias_0"])
_assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)
def testRegressionErrorInSparseTensorLabels(self):
@ -330,7 +330,7 @@ class MultiLabelHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["multi_label_head/loss"])
expected_loss = .89985204
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@ -347,7 +347,7 @@ class MultiLabelHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn, logits=logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["multi_label_head/loss"])
expected_loss = 1.00320443
_assert_metrics(self, expected_loss, {
"accuracy": 0.,
@ -387,7 +387,7 @@ class MultiLabelHeadTest(test.TestCase):
_assert_variables(
self, expected_global=w, expected_model=w, expected_trainable=w)
variables.global_variables_initializer().run()
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["multi_label_head/loss"])
expected_loss = .69314718
_assert_metrics(self, expected_loss, {
"accuracy": 2. / 3,
@ -432,7 +432,7 @@ class MultiLabelHeadTest(test.TestCase):
self._assert_output_alternatives(model_fn_ops)
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["multi_label_head/loss"])
expected_loss = .89985204
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@ -451,7 +451,7 @@ class MultiLabelHeadTest(test.TestCase):
self._assert_output_alternatives(model_fn_ops)
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["multi_label_head/loss"])
expected_loss = 1.377779
expected_eval_metrics = {
"accuracy": 1. / 3,
@ -519,7 +519,7 @@ class MultiLabelHeadTest(test.TestCase):
head_lib.no_op_train_fn, logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["multi_label_head/loss"])
expected_loss = .89985204
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@ -539,7 +539,7 @@ class MultiLabelHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["multi_label_head/loss"])
_assert_metrics(self, .089985214,
self._expected_eval_metrics(2.69956), model_fn_ops)
@ -559,7 +559,7 @@ class MultiLabelHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["multi_label_head/loss"])
_assert_metrics(self, 0.089985214,
self._expected_eval_metrics(0.089985214), model_fn_ops)
@ -583,7 +583,7 @@ class MultiLabelHeadTest(test.TestCase):
expected_trainable=("multi_label_head/centered_bias_weight:0",))
variables.global_variables_initializer().run()
_assert_summary_tags(self, (
"loss",
"multi_label_head/loss",
"multi_label_head/centered_bias/bias_0",
"multi_label_head/centered_bias/bias_1",
"multi_label_head/centered_bias/bias_2"
@ -608,7 +608,7 @@ class MultiLabelHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn,
logits=self._logits)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["multi_label_head/loss"])
expected_loss = .89985204
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@ -674,7 +674,7 @@ class BinaryClassificationHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["binary_logistic_head/loss"])
expected_loss = .81326175
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@ -702,7 +702,7 @@ class BinaryClassificationHeadTest(test.TestCase):
_assert_variables(
self, expected_global=w, expected_model=w, expected_trainable=w)
variables.global_variables_initializer().run()
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["binary_logistic_head/loss"])
expected_loss = .69314718
label_mean = np.mean(self._labels)
_assert_metrics(self, expected_loss, {
@ -738,7 +738,7 @@ class BinaryClassificationHeadTest(test.TestCase):
self._assert_output_alternatives(model_fn_ops)
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["binary_logistic_head/loss"])
expected_loss = .81326175
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@ -817,7 +817,7 @@ class BinaryClassificationHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["binary_logistic_head/loss"])
expected_loss = .81326175
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@ -838,7 +838,7 @@ class BinaryClassificationHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["binary_logistic_head/loss"])
expected_total_loss = .31326166
_assert_metrics(
self,
@ -871,7 +871,7 @@ class BinaryClassificationHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["binary_logistic_head/loss"])
# logloss: z:label, x:logit
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
# expected_loss is (total_weighted_loss)/1 since there is 1 nonzero
@ -911,7 +911,8 @@ class BinaryClassificationHeadTest(test.TestCase):
expected_trainable=("binary_logistic_head/centered_bias_weight:0",))
variables.global_variables_initializer().run()
_assert_summary_tags(
self, ["loss", "binary_logistic_head/centered_bias/bias_0"])
self, ["binary_logistic_head/loss",
"binary_logistic_head/centered_bias/bias_0"])
expected_loss = .81326175
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@ -960,7 +961,7 @@ class MultiClassHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["multi_class_head/loss"])
expected_loss = 1.5514447
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@ -999,7 +1000,7 @@ class MultiClassHeadTest(test.TestCase):
_assert_variables(
self, expected_global=w, expected_model=w, expected_trainable=w)
variables.global_variables_initializer().run()
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["multi_class_head/loss"])
expected_loss = 1.0986123
_assert_metrics(self, expected_loss, {
"accuracy": 0.,
@ -1050,7 +1051,7 @@ class MultiClassHeadTest(test.TestCase):
expected_trainable=("multi_class_head/centered_bias_weight:0",))
variables.global_variables_initializer().run()
_assert_summary_tags(self,
["loss",
["multi_class_head/loss",
"multi_class_head/centered_bias/bias_0",
"multi_class_head/centered_bias/bias_1",
"multi_class_head/centered_bias/bias_2"])
@ -1068,7 +1069,7 @@ class MultiClassHeadTest(test.TestCase):
self._assert_output_alternatives(model_fn_ops)
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["multi_class_head/loss"])
expected_loss = 1.5514447
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@ -1087,7 +1088,7 @@ class MultiClassHeadTest(test.TestCase):
self._assert_output_alternatives(model_fn_ops)
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["multi_class_head/loss"])
expected_loss = 3.1698461
expected_eval_metrics = {
"accuracy": 0.,
@ -1126,7 +1127,7 @@ class MultiClassHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["multi_class_head/loss"])
expected_loss = 1.5514447
_assert_metrics(self, expected_loss * weight,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@ -1150,7 +1151,7 @@ class MultiClassHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["multi_class_head/loss"])
expected_loss = 1.5514447 * weight
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@ -1257,7 +1258,7 @@ class MultiClassHeadTest(test.TestCase):
data_flow_ops.tables_initializer().run()
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["multi_class_head/loss"])
expected_loss = 1.5514447
expected_eval_metrics = {
"accuracy": 0.,
@ -1283,7 +1284,7 @@ class MultiClassHeadTest(test.TestCase):
data_flow_ops.tables_initializer().run()
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["multi_class_head/loss"])
expected_loss = 0.5514447
expected_eval_metrics = {
"accuracy": 1.,
@ -1322,7 +1323,7 @@ class BinarySvmHeadTest(test.TestCase):
logits=self._predictions)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["binary_svm_head/loss"])
expected_loss = np.average(self._expected_losses)
_assert_metrics(self, expected_loss, {
"accuracy": 1.,
@ -1352,7 +1353,7 @@ class BinarySvmHeadTest(test.TestCase):
_assert_variables(
self, expected_global=w, expected_model=w, expected_trainable=w)
variables.global_variables_initializer().run()
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["binary_svm_head/loss"])
expected_loss = 1.
_assert_metrics(self, expected_loss, {
"accuracy": .5,
@ -1384,7 +1385,7 @@ class BinarySvmHeadTest(test.TestCase):
self._assert_output_alternatives(model_fn_ops)
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["binary_svm_head/loss"])
expected_loss = np.average(self._expected_losses)
_assert_metrics(self, expected_loss, {
"accuracy": 1.,
@ -1403,7 +1404,7 @@ class BinarySvmHeadTest(test.TestCase):
logits=self._predictions)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["binary_svm_head/loss"])
expected_loss = np.average(self._expected_losses)
_assert_metrics(self, expected_loss, {
"accuracy": 1.,
@ -1422,7 +1423,7 @@ class BinarySvmHeadTest(test.TestCase):
logits=self._predictions)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
_assert_summary_tags(self, ["binary_svm_head/loss"])
expected_weighted_sum = np.sum(
np.multiply(weights, self._expected_losses))
_assert_metrics(self, expected_weighted_sum / len(weights), {
@ -1450,7 +1451,8 @@ class BinarySvmHeadTest(test.TestCase):
expected_trainable=("binary_svm_head/centered_bias_weight:0",))
variables.global_variables_initializer().run()
_assert_summary_tags(
self, ["loss", "binary_svm_head/centered_bias/bias_0"])
self, ["binary_svm_head/loss",
"binary_svm_head/centered_bias/bias_0"])
expected_loss = np.average(self._expected_losses)
_assert_metrics(self, expected_loss, {
"accuracy": 1.,

View File

@ -89,9 +89,9 @@ class ClusterConfig(object):
```
cluster = {'ps': ['host1:2222', 'host2:2222'],
'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
os.environ['TF_CONFIG'] = json.dumps({
os.environ['TF_CONFIG'] = json.dumps(
{'cluster': cluster,
'task': {'type': 'worker', 'index': 1}}})
'task': {'type': 'worker', 'index': 1}})
config = ClusterConfig()
assert config.master == 'host4:2222'
assert config.task_id == 1
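As a quick sanity check on the corrected snippet (the balanced braces now make the argument a single dict), the literal round-trips through `json`; this sketch is runnable on its own, outside TensorFlow:

```python
import json
import os

cluster = {'ps': ['host1:2222', 'host2:2222'],
           'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
os.environ['TF_CONFIG'] = json.dumps(
    {'cluster': cluster,
     'task': {'type': 'worker', 'index': 1}})

# With the stray braces removed, the value parses back cleanly.
parsed = json.loads(os.environ['TF_CONFIG'])
assert parsed['task']['index'] == 1
assert parsed['cluster']['worker'][1] == 'host4:2222'
```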

View File

@ -19,8 +19,6 @@ limitations under the License.
#include "tensorflow/core/kernels/segment_reduction_ops.h"
#include <stdio.h>
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

View File

@ -69,7 +69,7 @@ class SelfAdjointEigV2Op : public LinearAlgebraOp<Scalar> {
errors::InvalidArgument("Self Adjoint Eigen decomposition was not "
"successful. The input might not be valid."));
outputs->at(0) = eig.eigenvalues();
outputs->at(0) = eig.eigenvalues().template cast<Scalar>();
if (compute_v_) {
outputs->at(1) = eig.eigenvectors();
}
@ -81,7 +81,15 @@ class SelfAdjointEigV2Op : public LinearAlgebraOp<Scalar> {
REGISTER_LINALG_OP("SelfAdjointEigV2", (SelfAdjointEigV2Op<float>), float);
REGISTER_LINALG_OP("SelfAdjointEigV2", (SelfAdjointEigV2Op<double>), double);
REGISTER_LINALG_OP("SelfAdjointEigV2", (SelfAdjointEigV2Op<complex64>),
complex64);
REGISTER_LINALG_OP("SelfAdjointEigV2", (SelfAdjointEigV2Op<complex128>),
complex128);
REGISTER_LINALG_OP("BatchSelfAdjointEigV2", (SelfAdjointEigV2Op<float>), float);
REGISTER_LINALG_OP("BatchSelfAdjointEigV2", (SelfAdjointEigV2Op<double>),
double);
REGISTER_LINALG_OP("BatchSelfAdjointEigV2", (SelfAdjointEigV2Op<complex64>),
complex64);
REGISTER_LINALG_OP("BatchSelfAdjointEigV2", (SelfAdjointEigV2Op<complex128>),
complex128);
} // namespace tensorflow
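With the complex64/complex128 registrations above, the Python op accepts Hermitian inputs. A hedged usage sketch, assuming the TF 1.x `tf.self_adjoint_eig` binding; the eigenvalues of a Hermitian matrix are mathematically real, which is why the kernel casts Eigen's real-valued eigenvalues back to the input's scalar type:

```python
import numpy as np
import tensorflow as tf

# A Hermitian (self-adjoint) matrix: equal to its own conjugate transpose.
a = np.array([[2.0 + 0.0j, 1.0 - 1.0j],
              [1.0 + 1.0j, 3.0 + 0.0j]], dtype=np.complex64)

e, v = tf.self_adjoint_eig(tf.constant(a))
with tf.Session() as sess:
    e_val, v_val = sess.run([e, v])
    print(e_val)  # eigenvalues; imaginary parts are ~0 for Hermitian input
```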

View File

@ -1330,10 +1330,9 @@ this operation will permute `params` accordingly.
`indices` are always validated to be within range. If assigned to GPU,
out-of-bound indices result in safe but unspecified behavior, which may include
raising an error.
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../../images/Gather.png" alt>
<img style="width:100%" src="../../../images/Gather.png" alt>
</div>
)doc");

View File

@ -318,7 +318,7 @@ REGISTER_OP("SelfAdjointEigV2")
.Output("e: T")
.Output("v: T")
.Attr("compute_v: bool = True")
.Attr("T: {double, float}")
.Attr("T: {double, float, complex64, complex128}")
.SetShapeFn(SelfAdjointEigV2ShapeFn)
.Doc(R"doc(
Computes the eigen decomposition of one or more square self-adjoint matrices.

View File

@ -26363,6 +26363,59 @@ op {
summary: "Computes the sum along segments of a tensor."
description: "Read @{$math_ops#segmentation$the section on segmentation} for an explanation of\nsegments.\n\nComputes a tensor such that\n`(output[i] = sum_{j...} data[j...]` where the sum is over tuples `j...` such\nthat `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids`\nneed not be sorted and need not cover all values in the full\nrange of valid values.\n\nIf the sum is empty for a given segment ID `i`, `output[i] = 0`.\n\n`num_segments` should equal the number of distinct segment IDs.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"https://www.tensorflow.org/images/UnsortedSegmentSum.png\" alt>\n</div>"
}
op {
name: "UnsortedSegmentSum"
input_arg {
name: "data"
type_attr: "T"
}
input_arg {
name: "segment_ids"
description: "A tensor whose shape is a prefix of `data.shape`."
type_attr: "Tindices"
}
input_arg {
name: "num_segments"
type: DT_INT32
}
output_arg {
name: "output"
description: "Has same shape as data, except for the first `segment_ids.rank`\ndimensions, which are replaced with a single dimension which has size\n`num_segments`."
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT64
type: DT_INT32
type: DT_UINT8
type: DT_UINT16
type: DT_INT16
type: DT_INT8
type: DT_QINT8
type: DT_QUINT8
type: DT_QINT32
type: DT_HALF
}
}
}
attr {
name: "Tindices"
type: "type"
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
summary: "Computes the max along segments of a tensor."
description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nComputes a tensor such that\n\\\\(output_i = \\sum_j data_j\\\\) where sum is over `j` such\nthat `segment_ids[j] == i`. Unlike `SegmentSum`, `segment_ids`\nneed not be sorted and need not cover all values in the full\n range of valid values.\n\nIf the sum is empty for a given segment ID `i`, `output[i] = 0`.\n\n`num_segments` should equal the number of distinct segment IDs.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/UnsortedSegmentSum.png\" alt>\n</div>"
}
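The segment-sum semantics in that description can be sketched in NumPy (illustrative only, not the registered kernel):

```python
import numpy as np

def unsorted_segment_sum(data, segment_ids, num_segments):
    # output[i] = sum of data[j] over all j with segment_ids[j] == i.
    # segment_ids need not be sorted; empty segments yield 0.
    output = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
    np.add.at(output, segment_ids, data)
    return output

print(unsorted_segment_sum(np.array([1., 2., 3., 4.]),
                           np.array([0, 2, 0, 2]), 3))  # [4. 0. 6.]
```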
op {
name: "Unstage"
output_arg {

View File

@ -182,10 +182,10 @@ g++ -std=c++11 -shared zero_out.cc -o zero_out.so -fPIC -I $TF_INC -O2
On Mac OS X, the additional flag "-undefined dynamic_lookup" is required when
building the `.so` file.
> Note on gcc version 5: gcc5 uses the new C++
> [ABI](https://gcc.gnu.org/gcc-5/changes.html#libstdcxx). The binary pip
> packages available on the TensorFlow website are built with gcc4 that uses
> the older ABI. If you compile your op library with gcc5, add
> Note on `gcc` version `>=5`: gcc uses the new C++
> [ABI](https://gcc.gnu.org/gcc-5/changes.html#libstdcxx) since version `5`. The binary pip
> packages available on the TensorFlow website are built with `gcc4` that uses
> the older ABI. If you compile your op library with `gcc>=5`, add
> `-D_GLIBCXX_USE_CXX11_ABI=0` to the command line to make the library
> compatible with the older ABI.
> Furthermore, if you are using a TensorFlow package built from source, remember to add `--cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"`

View File

@ -298,7 +298,7 @@ invoke the following command:
<pre>$ <b>bazel build --config=opt --config=cuda //tensorflow/tools/pip_package:build_pip_package</b> </pre>
**NOTE on gcc 5 or later:** the binary pip packages available on the TensorFlow website are built with gcc 4, which uses the older ABI. To make your build compatible with the older ABI, you need to add `-cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"` to your `bazel build` command. ABI compatibility allows custom ops built against the TensorFlow pip package to continue to work against your built package.
**NOTE on gcc 5 or later:** the binary pip packages available on the TensorFlow website are built with gcc 4, which uses the older ABI. To make your build compatible with the older ABI, you need to add `--cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"` to your `bazel build` command. ABI compatibility allows custom ops built against the TensorFlow pip package to continue to work against your built package.
<b>Tip:</b> By default, building TensorFlow from sources consumes
a lot of RAM. If RAM is an issue on your system, you may limit RAM usage
@ -367,6 +367,7 @@ of one of the following guides:
* @{$install_linux#CommonInstallationProblems$Installing TensorFlow on Linux}
* @{$install_mac#CommonInstallationProblems$Installing TensorFlow on Mac OS}
* @{$install_windows#CommonInstallationProblems$Installing TensorFlow on Windows}
Beyond the errors documented in those guides, the following table
notes additional errors specific to building TensorFlow. Note that we

View File

@ -91,8 +91,8 @@ eight-bit computations:
```sh
curl http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz -o /tmp/inceptionv3.tgz
tar xzf /tmp/inceptionv3.tgz -C /tmp/
bazel build tensorflow/tools/quantization/tools:quantize_graph
bazel-bin/tensorflow/tools/quantization/tools/quantize_graph \
bazel build tensorflow/tools/quantization:quantize_graph
bazel-bin/tensorflow/tools/quantization/quantize_graph \
--input=/tmp/classify_image_graph_def.pb \
--output_node_names="softmax" --output=/tmp/quantized_graph.pb \
--mode=eightbit

View File

@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example of DNNClassifier for Iris plant dataset, with exponential decay."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

View File

@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This is an example of using convolutional networks over characters for DBpedia dataset to predict class from description of an entity.
"""This is an example of using convolutional networks over characters for
DBpedia dataset to predict class from description of an entity.
This model is similar to one described in this paper:
"Character-level Convolutional Networks for Text Classification"
@ -54,7 +55,7 @@ def char_cnn_model(features, target):
# Apply Convolution filtering on input sequence.
conv1 = tf.contrib.layers.convolution2d(
byte_list, N_FILTERS, FILTER_SHAPE1, padding='VALID')
# Add a RELU for non linearity.
# Add a ReLU for non linearity.
conv1 = tf.nn.relu(conv1)
# Max pooling across output of Convolution+ReLU.
pool1 = tf.nn.max_pool(

View File

@ -14,7 +14,7 @@
# ==============================================================================
"""A simple MNIST classifier which displays summaries in TensorBoard.
This is an unimpressive MNIST model, but it is a good example of using
This is an unimpressive MNIST model, but it is a good example of using
tf.name_scope to make a graph legible in the TensorBoard graph explorer, and of
naming summary tags so that they are grouped meaningfully in TensorBoard.
@ -78,7 +78,7 @@ def train():
def nn_layer(input_tensor, input_dim, output_dim, layer_name, act=tf.nn.relu):
"""Reusable code for making a simple neural net layer.
It does a matrix multiply, bias add, and then uses relu to nonlinearize.
It does a matrix multiply, bias add, and then uses ReLU to nonlinearize.
It also sets up name scoping so that the resultant graph is easy to read,
and adds a number of summary ops.
"""

View File

@ -78,7 +78,7 @@ def _GetSelfAdjointEigTest(dtype_, shape_):
low=-1.0, high=1.0, size=n * n).reshape([n, n]).astype(dtype_)
a += a.T
a = np.tile(a, batch_shape + (1, 1))
if dtype_ == np.float32:
if dtype_ == np.float32 or dtype_ == np.complex64:
atol = 1e-4
else:
atol = 1e-12
@ -150,13 +150,14 @@ def _GetSelfAdjointEigGradTest(dtype_, shape_):
if __name__ == '__main__':
for dtype in np.float32, np.float64:
for dtype in np.float32, np.float64, np.complex64, np.complex128:
for size in 1, 2, 5, 10:
for batch_dims in [(), (3,)] + [(3, 2)] * (max(size, size) < 10):
shape = batch_dims + (size, size)
name = '%s_%s' % (dtype.__name__, '_'.join(map(str, shape)))
setattr(SelfAdjointEigTest, 'testSelfAdjointEig_' + name,
_GetSelfAdjointEigTest(dtype, shape))
setattr(SelfAdjointEigGradTest, 'testSelfAdjointEigGrad_' + name,
_GetSelfAdjointEigGradTest(dtype, shape))
if dtype in [np.float32, np.float64]:
setattr(SelfAdjointEigGradTest, 'testSelfAdjointEigGrad_' + name,
_GetSelfAdjointEigGradTest(dtype, shape))
test.main()

View File

@ -364,6 +364,23 @@ def batch_normalization(inputs,
Sergey Ioffe, Christian Szegedy
Note: the operations which update the `moving_mean` and `moving_variance`
variables will not be added as dependencies of your training operation and so
must be run separately. For example:
```
extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
sess.run([train_op, extra_update_ops], ...)
```
Alternatively, add the operations as a dependency to your training operation
manually, and then just run your training operation as normal:
```
extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(extra_update_ops):
train_op = optimizer.minimize(loss)
...
sess.run([train_op], ...)
```
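Putting the two options together, a minimal end-to-end sketch, assuming the `tf.layers.batch_normalization` wrapper around this function:

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 10])
training = tf.placeholder(tf.bool)
h = tf.layers.batch_normalization(x, training=training)
loss = tf.reduce_mean(tf.square(h))

# Attach the moving-average updates to the train step so they always run.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
```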
Arguments:
inputs: Tensor input.
axis: Integer, the axis that should be normalized (typically the features

View File

@ -512,6 +512,16 @@ def _MaxPoolGrad(op, grad):
data_format=op.get_attr("data_format"))
@ops.RegisterGradient("MaxPoolWithArgmax")
def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad):
return gen_nn_ops._max_pool_grad_with_argmax(op.inputs[0],
grad,
op.outputs[1],
op.get_attr("ksize"),
op.get_attr("strides"),
padding=op.get_attr("padding"))
@ops.RegisterGradient("MaxPoolGrad")
def _MaxPoolGradGrad(op, grad):
return (array_ops.zeros(

View File

@ -1478,14 +1478,14 @@ def _softmax(logits, compute_op, dim=-1, name=None):
InvalidArgumentError: if `logits` is empty or `dim` is beyond the last
dimension of `logits`.
"""
def _swap_axis(logits, dim_index, last_index):
def _swap_axis(logits, dim_index, last_index, name=None):
"""Swaps logits's dim_index and last_index."""
return array_ops.transpose(logits,
array_ops.concat([
math_ops.range(dim_index), [last_index],
math_ops.range(dim_index + 1, last_index),
[dim_index]
], 0))
], 0), name=name)
logits = ops.convert_to_tensor(logits)
@ -1501,8 +1501,8 @@ def _softmax(logits, compute_op, dim=-1, name=None):
if is_last_dim:
input_shape = array_ops.shape(logits)
logits = _flatten_outer_dims(logits)
output = compute_op(logits, name=name)
output = array_ops.reshape(output, input_shape)
output = compute_op(logits)
output = array_ops.reshape(output, input_shape, name=name)
return output
# If dim is not the last dimension, we have to do a reshape and transpose so
@ -1517,11 +1517,11 @@ def _softmax(logits, compute_op, dim=-1, name=None):
logits = _flatten_outer_dims(logits)
# Do the actual softmax on its last dimension.
output = compute_op(logits, name=name)
output = compute_op(logits)
# Transform back the output tensor.
output = array_ops.reshape(output, shape_after_swap)
output = _swap_axis(output, dim, math_ops.subtract(input_rank, 1))
output = _swap_axis(output, dim, math_ops.subtract(input_rank, 1), name=name)
# Make shape inference work since reshape and transpose may erase its static
# shape.
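The transpose-and-flatten trick above is independent of TensorFlow; here is a hedged NumPy sketch of the same idea (not the implementation itself):

```python
import numpy as np

def softmax_along(x, dim):
    # Swap `dim` with the last axis, softmax over the last axis, then swap
    # back -- the same dance _softmax performs with transpose/reshape.
    x = np.swapaxes(x, dim, -1)
    shifted = x - x.max(axis=-1, keepdims=True)  # subtract max for stability
    out = np.exp(shifted)
    out /= out.sum(axis=-1, keepdims=True)
    return np.swapaxes(out, dim, -1)

logits = np.random.randn(2, 3, 4)
assert np.allclose(softmax_along(logits, 1).sum(axis=1), 1.0)
```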

View File

@ -24,6 +24,7 @@ import time
import six
from tensorflow.core.util import event_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.platform import gfile
from tensorflow.python.util import compat
@ -67,14 +68,20 @@ class EventFileWriter(object):
self._event_queue = six.moves.queue.Queue(max_queue)
self._ev_writer = pywrap_tensorflow.EventsWriter(
compat.as_bytes(os.path.join(self._logdir, "events")))
self._flush_secs = flush_secs
self._sentinel_event = self._get_sentinel_event()
if filename_suffix:
self._ev_writer.InitWithSuffix(compat.as_bytes(filename_suffix))
self._closed = False
self._worker = _EventLoggerThread(self._event_queue, self._ev_writer,
flush_secs)
self._flush_secs, self._sentinel_event)
self._worker.start()
def _get_sentinel_event(self):
"""Generate a sentinel event for terminating worker."""
return event_pb2.Event()
def get_logdir(self):
"""Returns the directory where event file will be written."""
return self._logdir
@ -88,6 +95,9 @@ class EventFileWriter(object):
Does nothing if the EventFileWriter was not closed.
"""
if self._closed:
self._worker = _EventLoggerThread(self._event_queue, self._ev_writer,
self._flush_secs, self._sentinel_event)
self._worker.start()
self._closed = False
def add_event(self, event):
@ -113,7 +123,9 @@ class EventFileWriter(object):
Call this method when you do not need the summary writer anymore.
"""
self.add_event(self._sentinel_event)
self.flush()
self._worker.join()
self._ev_writer.Close()
self._closed = True
@ -121,7 +133,7 @@ class EventFileWriter(object):
class _EventLoggerThread(threading.Thread):
"""Thread that logs events."""
def __init__(self, queue, ev_writer, flush_secs):
def __init__(self, queue, ev_writer, flush_secs, sentinel_event):
"""Creates an _EventLoggerThread.
Args:
@ -130,6 +142,8 @@ class _EventLoggerThread(threading.Thread):
the visualizer.
flush_secs: How often, in seconds, to flush the
pending file to disk.
sentinel_event: A sentinel element in queue that tells this thread to
terminate.
"""
threading.Thread.__init__(self)
self.daemon = True
@ -138,10 +152,14 @@ class _EventLoggerThread(threading.Thread):
self._flush_secs = flush_secs
# The first event will be flushed immediately.
self._next_event_flush_time = 0
self._sentinel_event = sentinel_event
def run(self):
while True:
event = self._queue.get()
if event is self._sentinel_event:
self._queue.task_done()
break
try:
self._ev_writer.WriteEvent(event)
# Flush the event writer every so often.
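The sentinel-based shutdown added here is a standard producer/consumer pattern: the writer thread drains the queue until it sees a unique marker object, so `close()` never has to kill the thread mid-write. A self-contained sketch (illustrative, not the EventFileWriter code):

```python
import queue
import threading

SENTINEL = object()  # unique marker; compared with `is`, never written out

def worker(q):
    while True:
        item = q.get()
        if item is SENTINEL:
            q.task_done()
            break
        print("writing", item)  # stand-in for ev_writer.WriteEvent(event)
        q.task_done()

q = queue.Queue()
t = threading.Thread(target=worker, args=(q,), daemon=True)
t.start()
q.put("event-1")
q.put(SENTINEL)  # close(): enqueue the sentinel, then join the worker
q.join()
t.join()
```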

View File

@ -258,6 +258,15 @@ class SummaryWriterTestCase(test.TestCase):
# We should be done.
self.assertRaises(StopIteration, lambda: next(rr))
def testNonBlockingClose(self):
test_dir = self._CleanTestDir("non_blocking_close")
sw = writer.FileWriter(test_dir)
# Sleep 1.2 seconds to make sure event queue is empty.
time.sleep(1.2)
time_before_close = time.time()
sw.close()
self._assertRecent(time_before_close)
# Checks that values returned from session Run() calls are added correctly to
# summaries. These are numpy types so we need to check they fit in the
# protocol buffers correctly.

View File

@ -103,7 +103,7 @@ output layers of the model are. The best source for these is the model training
process, where for a classifier the inputs will be the nodes that receive the
data from the training set, and the output will be the predictions. If you're
unsure, the
[summarize_graph](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/summarize_graph_main.cc)
[`summarize_graph`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/summarize_graph_main.cc)
tool can inspect the model and provide guesses about likely input and output nodes,
as well as other information that's useful for debugging. Here's an example of
how to use it on the [Inception V3
@ -315,7 +315,7 @@ themselves contain commas (for example shape definitions).
The `--inputs` and `--outputs` flags are shared across all transforms, since it's common
to need to know what the ingoing and outgoing nodes in the graph are. You should
make sure you set these correctly before calling the graph transform tool, and
if you're in doubt check with the model's author, or use the `check_graph` tool
if you're in doubt check with the model's author, or use the [`summarize_graph`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms#inspecting-graphs) tool
to examine likely inputs and outputs.
All transforms can be passed the `ignore_errors` flag, with the value set to